Commit 1de7de06 authored by danyao12's avatar danyao12
Browse files

attn bwd kernel prototype1

parent 7409bc5d
......@@ -859,6 +859,21 @@ struct BlockwiseGemmXdlops_v2
"wrong!");
}
__host__ __device__ BlockwiseGemmXdlops_v2(index_t switch_flag,
Tuple4 b_origin = CalculateBThreadOriginDataIndex(),
Tuple4 a_origin = CalculateAThreadOriginDataIndex())
: switch_flag_(switch_flag), a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
}
__host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other)
: a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin)
{
......@@ -1126,6 +1141,7 @@ struct BlockwiseGemmXdlops_v2
B_K1,
B_K1>;
index_t switch_flag_;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment