Commit b03f56db authored by danyao12's avatar danyao12
Browse files

add SetA/BBlockStartWindow in BlockwiseGemmXdlops_v2

parent fced127d
...@@ -859,24 +859,19 @@ struct BlockwiseGemmXdlops_v2 ...@@ -859,24 +859,19 @@ struct BlockwiseGemmXdlops_v2
"wrong!"); "wrong!");
} }
__host__ __device__ BlockwiseGemmXdlops_v2(index_t switch_flag, __host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other)
Tuple4 b_origin = CalculateBThreadOriginDataIndex(), : a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin)
Tuple4 a_origin = CalculateAThreadOriginDataIndex())
: switch_flag_(switch_flag), a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{ {
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(), }
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, __device__ void SetABlockStartWindow(Tuple4 a_origin = CalculateAThreadOriginDataIndex())
"wrong!"); {
a_thread_copy_.SetSrcCoord(a_origin);
} }
__host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other) __device__ void SetBBlockStartWindow(Tuple4 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin)
{ {
b_thread_copy_.SetSrcCoord(b_origin);
} }
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
...@@ -1141,7 +1136,6 @@ struct BlockwiseGemmXdlops_v2 ...@@ -1141,7 +1136,6 @@ struct BlockwiseGemmXdlops_v2
B_K1, B_K1,
B_K1>; B_K1>;
index_t switch_flag_;
AThreadCopy a_thread_copy_; AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_; BThreadCopy b_thread_copy_;
}; };
......
...@@ -1333,8 +1333,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1333,8 +1333,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
typename Gemm1::ABlockwiseCopy{tensor_operation::element_wise::PassThrough{}}; typename Gemm1::ABlockwiseCopy{tensor_operation::element_wise::PassThrough{}};
// dQ: blockwise gemm // dQ: blockwise gemm
auto qgrad_blockwise_gemm = auto qgrad_blockwise_gemm = typename Gemm1::BlockwiseGemm{};
typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0), make_tuple(0, 0, 0, 0)}; qgrad_blockwise_gemm.SetBBlockStartWindow(make_tuple(0, 0, 0, 0));
// dQ: B matrix blockwise copy // dQ: B matrix blockwise copy
auto k_thread_origin = qgrad_blockwise_gemm.CalculateBThreadOriginDataIndex(); auto k_thread_origin = qgrad_blockwise_gemm.CalculateBThreadOriginDataIndex();
...@@ -1458,7 +1458,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1458,7 +1458,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle
tensor_operation::element_wise::PassThrough{}}; tensor_operation::element_wise::PassThrough{}};
// dV: blockwise gemm // dV: blockwise gemm
auto v_slash_k_grad_blockwise_gemm = typename Gemm2::BlockwiseGemm{1, make_tuple(0, 0, 0, 0)}; auto v_slash_k_grad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};
v_slash_k_grad_blockwise_gemm.SetBBlockStartWindow(make_tuple(0, 0, 0, 0));
auto q_slash_ygrad_thread_origin = v_slash_k_grad_blockwise_gemm.CalculateBThreadOriginDataIndex(); auto q_slash_ygrad_thread_origin = v_slash_k_grad_blockwise_gemm.CalculateBThreadOriginDataIndex();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment