Commit b03f56db authored by danyao12
Browse files

add SetA/BBlockStartWindow in BlockwiseGemmXdlops_v2

parent fced127d
......@@ -859,24 +859,19 @@ struct BlockwiseGemmXdlops_v2
"wrong!");
}
__host__ __device__ BlockwiseGemmXdlops_v2(index_t switch_flag,
Tuple4 b_origin = CalculateBThreadOriginDataIndex(),
Tuple4 a_origin = CalculateAThreadOriginDataIndex())
: switch_flag_(switch_flag), a_thread_copy_(a_origin), b_thread_copy_(b_origin)
__host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other)
: a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin)
{
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
}
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
// Resets the source coordinate of the A-side thread copy.
//
// a_origin is forwarded unchanged to a_thread_copy_.SetSrcCoord(). When the
// caller omits it, the default argument evaluates
// CalculateAThreadOriginDataIndex() and that result is used instead.
//
// Device-only (__device__).
// NOTE(review): Tuple4 and SetSrcCoord are declared outside this view --
// confirm the index layout they expect before passing a custom origin.
__device__ void SetABlockStartWindow(Tuple4 a_origin = CalculateAThreadOriginDataIndex())
{
// Only the copy's source coordinate is touched here.
a_thread_copy_.SetSrcCoord(a_origin);
}
__host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other)
: a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin)
// Resets the source coordinate of the B-side thread copy.
//
// b_origin is forwarded unchanged to b_thread_copy_.SetSrcCoord(). When the
// caller omits it, the default argument evaluates
// CalculateBThreadOriginDataIndex() and that result is used instead.
//
// Device-only (__device__).
// NOTE(review): Tuple4 and SetSrcCoord are declared outside this view --
// confirm the index layout they expect before passing a custom origin.
__device__ void SetBBlockStartWindow(Tuple4 b_origin = CalculateBThreadOriginDataIndex())
{
// Only the copy's source coordinate is touched here.
b_thread_copy_.SetSrcCoord(b_origin);
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
......@@ -1141,7 +1136,6 @@ struct BlockwiseGemmXdlops_v2
B_K1,
B_K1>;
index_t switch_flag_;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
......
......@@ -1333,8 +1333,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
typename Gemm1::ABlockwiseCopy{tensor_operation::element_wise::PassThrough{}};
// dQ: blockwise gemm
auto qgrad_blockwise_gemm =
typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0), make_tuple(0, 0, 0, 0)};
auto qgrad_blockwise_gemm = typename Gemm1::BlockwiseGemm{};
qgrad_blockwise_gemm.SetBBlockStartWindow(make_tuple(0, 0, 0, 0));
// dQ: B matrix blockwise copy
auto k_thread_origin = qgrad_blockwise_gemm.CalculateBThreadOriginDataIndex();
......@@ -1458,7 +1458,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
tensor_operation::element_wise::PassThrough{}};
// dV: blockwise gemm
auto v_slash_k_grad_blockwise_gemm = typename Gemm2::BlockwiseGemm{1, make_tuple(0, 0, 0, 0)};
auto v_slash_k_grad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};
v_slash_k_grad_blockwise_gemm.SetBBlockStartWindow(make_tuple(0, 0, 0, 0));
auto q_slash_ygrad_thread_origin = v_slash_k_grad_blockwise_gemm.CalculateBThreadOriginDataIndex();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment