"examples/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "815dce0554793d0788faf4eaacf0c7271c070e95"
Commit 73f0c21b authored by danyao12

Merge branch 'attn-train-develop-qloop' into attn-train-develop-qloop-light

parents 227076ba 4d18cd84
@@ -40,7 +40,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
-    kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1(
+    kernel_grouped_multihead_attention_backward_kloop_xdl_cshuffle_v1(
         const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
         const index_t group_count,
         const AElementwiseOperation a_element_op,
@@ -254,7 +254,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
+struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -266,7 +266,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
     // TODO: implement bias combination
     static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
-    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1;
+    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1;
     struct ProblemDesc
     {
         std::vector<index_t> a_gs_ms_ks_lengths;
@@ -956,7 +956,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
         float ave_time = 0;
         auto launch_kernel = [&](auto has_main_k_block_loop_) {
-            const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1<
+            const auto kernel =
+                kernel_grouped_multihead_attention_backward_kloop_xdl_cshuffle_v1<
                     GridwiseGemm,
                     GroupKernelArg,
                     AElementwiseOperation,
@@ -1209,7 +1210,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1"
+        str << "DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
...
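The hunks above rename four sites in lockstep: the __global__ kernel symbol, the device-op struct, its DeviceOp self-alias, and the string built in GetTypeString(). The string matters because it is how an operator instance identifies itself at runtime (e.g. in logs and profiler listings), so it has to track the C++ type name exactly. A minimal stand-in, not the CK class, illustrating the convention:

#include <cstdio>
#include <string>

// Stand-in for a CK device op (illustrative only): the string returned by
// GetTypeString() mirrors the struct's C++ name, so a rename that missed it
// would make logs and instance listings report a type that no longer exists
// in the source tree.
struct DeviceOpStandIn
{
    static std::string GetTypeString()
    {
        // kept in step with the struct/kernel rename in the hunks above
        return "DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1";
    }
};

int main() { std::printf("%s\n", DeviceOpStandIn::GetTypeString().c_str()); }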
@@ -40,7 +40,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
-    kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2(
+    kernel_grouped_multihead_attention_backward_kloop_xdl_cshuffle_v2(
         const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
         const index_t group_count,
         const AElementwiseOperation a_element_op,
@@ -254,7 +254,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
+struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -266,7 +266,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
     // TODO: implement bias combination
     static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
-    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2;
+    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2;
     struct ProblemDesc
     {
         std::vector<index_t> a_gs_ms_ks_lengths;
@@ -948,7 +948,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
         float ave_time = 0;
         auto launch_kernel = [&](auto has_main_k_block_loop_) {
-            const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2<
+            const auto kernel =
+                kernel_grouped_multihead_attention_backward_kloop_xdl_cshuffle_v2<
                     GridwiseGemm,
                     GroupKernelArg,
                     AElementwiseOperation,
@@ -1200,7 +1201,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2"
+        str << "DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
...
@@ -40,7 +40,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
-    kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1(
+    kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v1(
         const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
         const index_t group_count,
         const AElementwiseOperation a_element_op,
@@ -251,7 +251,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
+struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -263,7 +263,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
     // TODO: implement bias combination
     static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
-    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1;
+    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1;
     struct ProblemDesc
     {
         std::vector<index_t> a_gs_ms_ks_lengths;
@@ -961,7 +961,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
         float ave_time = 0;
         auto launch_kernel = [&](auto has_main_k_block_loop_) {
-            const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1<
+            const auto kernel =
+                kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v1<
                     GridwiseGemm,
                     GroupKernelArg,
                     AElementwiseOperation,
@@ -1207,7 +1208,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1"
+        str << "DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
...
@@ -40,7 +40,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
-    kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2(
+    kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v2(
         const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
         const index_t group_count,
         const AElementwiseOperation a_element_op,
@@ -258,7 +258,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
+struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
     static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -270,7 +270,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
     // TODO: implement bias combination
     static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
-    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2;
+    using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2;
     struct ProblemDesc
     {
         std::vector<index_t> a_gs_ms_ks_lengths;
@@ -968,7 +968,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
         float ave_time = 0;
         auto launch_kernel = [&](auto has_main_k_block_loop_) {
-            const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2<
+            const auto kernel =
+                kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v2<
                     GridwiseGemm,
                     GroupKernelArg,
                     AElementwiseOperation,
@@ -1219,7 +1220,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2"
+        str << "DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
...
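The same split is applied across all four backward device ops: V1 and V2 each now exist in an explicit Kloop and an explicit Qloop flavor, matching the attn-train-develop-qloop branch being merged. The diff shows none of the loop bodies, so the sketch below is only an assumption, consistent with the naming convention, about what the two tags mean for an attention backward pass: which sequence dimension the grid parallelizes over and which one the per-workgroup main loop walks. Tile counts are hypothetical placeholders.

#include <cstdio>

int main()
{
    const int num_q_tiles = 4; // tiles along the query (M) dimension
    const int num_k_tiles = 8; // tiles along the key/value (N) dimension

    // K-loop flavor: a workgroup owns a Q tile and sweeps the K/V tiles.
    for(int q = 0; q < num_q_tiles; ++q)     // parallel across workgroups
        for(int k = 0; k < num_k_tiles; ++k) // sequential per workgroup
            std::printf("kloop: Q tile %d visits K/V tile %d\n", q, k);

    // Q-loop flavor: a workgroup owns a K/V tile and sweeps the Q tiles.
    for(int k = 0; k < num_k_tiles; ++k)     // parallel across workgroups
        for(int q = 0; q < num_q_tiles; ++q) // sequential per workgroup
            std::printf("qloop: K/V tile %d visits Q tile %d\n", k, q);
}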
@@ -39,7 +39,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2(
+    kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1(
         const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
         const index_t group_count,
         const AElementwiseOperation a_element_op,
@@ -250,7 +250,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
+struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
     : public DeviceGroupedMultiheadAttentionForward<NumDimG,
                                                     NumDimM,
                                                     NumDimN,
@@ -290,7 +290,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
     static constexpr index_t NumDimGemm1K = NumDimN;
 #endif
-    using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle;
+    using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1;
     using ProblemDesc = typename DeviceGroupedMultiheadAttentionForward<NumDimG,
                                                                         NumDimM,
                                                                         NumDimN,
@@ -813,7 +813,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
         auto launch_kernel =
             [&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
                 const auto kernel =
-                    kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
+                    kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1<GridwiseGemm,
                                                                      GemmAccDataType,
                                                                      GroupKernelArg,
                                                                      AElementwiseOperation,
@@ -1123,7 +1123,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle"
+        str << "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
...
@@ -256,7 +256,7 @@ template <index_t NumDimG,
           MaskingSpecialization MaskingSpec,
           bool Deterministic,
           LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
+struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
     : public DeviceGroupedMultiheadAttentionForward<NumDimG,
                                                     NumDimM,
                                                     NumDimN,
@@ -296,7 +296,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
     static constexpr index_t NumDimGemm1K = NumDimN;
 #endif
-    using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle;
+    using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2;
     using ProblemDesc = typename DeviceGroupedMultiheadAttentionForward<NumDimG,
                                                                         NumDimM,
                                                                         NumDimN,
@@ -1145,7 +1145,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
         auto str = std::stringstream();
         // clang-format off
-        str << "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle"
+        str << "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
...
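The two forward files get the complementary change: the name DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle, previously used by both files (with the V1 file's kernel even misnumbered _v2), is split into explicit _V1 and _V2 types, so both pipelines can coexist in one build and be chosen at compile time. A self-contained analogue of that selection pattern, with stand-in structs rather than the CK operators:

#include <cstdio>
#include <type_traits>

// Stand-ins for the two forward pipelines; the real CK types take a long
// template parameter list that is omitted here.
struct Fwd_V1 { static const char* name() { return "Forward_Xdl_CShuffle_V1"; } };
struct Fwd_V2 { static const char* name() { return "Forward_Xdl_CShuffle_V2"; } };

// Compile-time selector: only possible once the two versions carry
// distinct names, which is what this commit establishes.
template <bool UseV2>
using Fwd = std::conditional_t<UseV2, Fwd_V2, Fwd_V1>;

int main() { std::printf("selected pipeline: %s\n", Fwd<false>::name()); }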
@@ -2114,6 +2114,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                     sgrad_slice_idx[I3],
                     sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3));
+                block_sync_lds(); // sync before write
                 if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
                 {
                     qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
@@ -2125,8 +2126,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                         gemm2_a_block_buf);
                 }
-                // block_sync_lds(); // sync before write
                 qgrad_gemm_tile_k_blockwise_copy.Run(Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
                                                      k_block_buf,
                                                      Gemm2::b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
...
@@ -2044,6 +2044,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                     sgrad_slice_idx[I3],
                     sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3));
+                block_sync_lds(); // sync before write
                 if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
                 {
                     qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
@@ -2060,7 +2061,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                 qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_n0_k_n1,
                                                                     Gemm2::b_block_slice_copy_step);
-                block_sync_lds(); // sync before write
                 qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm2::b_block_desc_k0_n_k1,
                                                           gemm2_b_block_buf);
...
@@ -2191,6 +2191,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                     sgrad_slice_idx[I3],
                     sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3));
+                block_sync_lds(); // sync before write
                 if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
                 {
                     qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
@@ -2202,8 +2203,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
                         gemm2_a_block_buf);
                 }
-                // block_sync_lds(); // sync before write
                 qgrad_gemm_tile_k_blockwise_copy.Run(Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
                                                      k_block_buf,
                                                      Gemm2::b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
...
@@ -2142,6 +2142,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                     sgrad_slice_idx[I3],
                     sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3));
+                block_sync_lds(); // sync before write
                 if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
                 {
                     qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
@@ -2158,7 +2159,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                 qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_n0_k_n1,
                                                                     Gemm2::b_block_slice_copy_step);
-                block_sync_lds(); // sync before write
                 qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm2::b_block_desc_k0_n_k1,
                                                           gemm2_b_block_buf);
...
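All four gridwise hunks are the same synchronization fix, applied to the V1 and V2 kernels in both copies of the file: block_sync_lds() now runs before the SGrad tile is written into LDS (and the old barrier after the copy, or before RunWrite in V2, is dropped), so no wave can overwrite the shared buffer while another wave is still reading the previous iteration's contents. A minimal, self-contained HIP-style sketch of that write-after-read hazard, using __syncthreads() in place of CK's block_sync_lds() wrapper; all names here are illustrative, not taken from the diff:

#include <hip/hip_runtime.h>

// Illustrative kernel: a shared tile is reused across iterations. Writer
// threads refill it; all threads then read it. Without a barrier *before*
// the write, a fast writer could clobber data that a slower reader is
// still consuming from the previous iteration - the hazard class the
// relocated block_sync_lds() guards against.
__global__ void reuse_tile(const float* in, float* out, int iters)
{
    __shared__ float tile[256];
    float acc = 0.f;
    for(int i = 0; i < iters; ++i)
    {
        __syncthreads(); // "sync before write": all reads of the previous
                         // iteration's tile must have completed
        if(threadIdx.x < 64) // only a subgroup writes, mirroring the
        {                    // gemm2_a_copy_subgroup.IsBelong(...) guard
            for(int j = threadIdx.x; j < 256; j += 64)
                tile[j] = in[i * 256 + j];
        }
        __syncthreads(); // writes visible before anyone reads
        acc += tile[threadIdx.x % 256];
    }
    out[blockIdx.x * blockDim.x + threadIdx.x] = acc;
}

Note that, as in the diff, the relocated barrier sits outside the writer-subgroup guard: block-wide barriers must be reached by every thread in the block, not just the threads that perform the copy.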