Commit 73f0c21b authored by danyao12

Merge branch 'attn-train-develop-qloop' into attn-train-develop-qloop-light

parents 227076ba 4d18cd84
@@ -40,7 +40,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
-kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1(
+kernel_grouped_multihead_attention_backward_kloop_xdl_cshuffle_v1(
 const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
 const index_t group_count,
 const AElementwiseOperation a_element_op,
@@ -254,7 +254,7 @@ template <index_t NumDimG,
 MaskingSpecialization MaskingSpec,
 bool Deterministic,
 LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
+struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
 : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
 static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -266,7 +266,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
 // TODO: implement bias combination
 static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
-using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1;
+using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1;
 struct ProblemDesc
 {
 std::vector<index_t> a_gs_ms_ks_lengths;
@@ -956,16 +956,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
 float ave_time = 0;
 auto launch_kernel = [&](auto has_main_k_block_loop_) {
-const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1<
-GridwiseGemm,
-GroupKernelArg,
-AElementwiseOperation,
-BElementwiseOperation,
-AccElementwiseOperation,
-B1ElementwiseOperation,
-CElementwiseOperation,
-has_main_k_block_loop_,
-Deterministic>;
+const auto kernel =
+kernel_grouped_multihead_attention_backward_kloop_xdl_cshuffle_v1<
+GridwiseGemm,
+GroupKernelArg,
+AElementwiseOperation,
+BElementwiseOperation,
+AccElementwiseOperation,
+B1ElementwiseOperation,
+CElementwiseOperation,
+has_main_k_block_loop_,
+Deterministic>;
 return launch_and_time_kernel(
 stream_config,
@@ -1209,7 +1210,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
 auto str = std::stringstream();
 // clang-format off
-str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1"
+str << "DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1"
 << "<"
 << BlockSize << ", "
 << MPerBlock << ", "
......
@@ -40,7 +40,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
-kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2(
+kernel_grouped_multihead_attention_backward_kloop_xdl_cshuffle_v2(
 const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
 const index_t group_count,
 const AElementwiseOperation a_element_op,
@@ -254,7 +254,7 @@ template <index_t NumDimG,
 MaskingSpecialization MaskingSpec,
 bool Deterministic,
 LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
+struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
 : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
 static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -266,7 +266,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
 // TODO: implement bias combination
 static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
-using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2;
+using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2;
 struct ProblemDesc
 {
 std::vector<index_t> a_gs_ms_ks_lengths;
@@ -948,16 +948,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
 float ave_time = 0;
 auto launch_kernel = [&](auto has_main_k_block_loop_) {
-const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2<
-GridwiseGemm,
-GroupKernelArg,
-AElementwiseOperation,
-BElementwiseOperation,
-AccElementwiseOperation,
-B1ElementwiseOperation,
-CElementwiseOperation,
-has_main_k_block_loop_,
-Deterministic>;
+const auto kernel =
+kernel_grouped_multihead_attention_backward_kloop_xdl_cshuffle_v2<
+GridwiseGemm,
+GroupKernelArg,
+AElementwiseOperation,
+BElementwiseOperation,
+AccElementwiseOperation,
+B1ElementwiseOperation,
+CElementwiseOperation,
+has_main_k_block_loop_,
+Deterministic>;
 return launch_and_time_kernel(
 stream_config,
@@ -1200,7 +1201,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
 auto str = std::stringstream();
 // clang-format off
-str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2"
+str << "DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2"
 << "<"
 << BlockSize << ", "
 << MPerBlock << ", "
......
@@ -40,7 +40,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
-kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1(
+kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v1(
 const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
 const index_t group_count,
 const AElementwiseOperation a_element_op,
@@ -251,7 +251,7 @@ template <index_t NumDimG,
 MaskingSpecialization MaskingSpec,
 bool Deterministic,
 LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
+struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
 : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
 static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -263,7 +263,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
 // TODO: implement bias combination
 static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
-using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1;
+using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1;
 struct ProblemDesc
 {
 std::vector<index_t> a_gs_ms_ks_lengths;
@@ -961,16 +961,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
 float ave_time = 0;
 auto launch_kernel = [&](auto has_main_k_block_loop_) {
-const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v1<
-GridwiseGemm,
-GroupKernelArg,
-AElementwiseOperation,
-BElementwiseOperation,
-AccElementwiseOperation,
-B1ElementwiseOperation,
-CElementwiseOperation,
-has_main_k_block_loop_,
-Deterministic>;
+const auto kernel =
+kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v1<
+GridwiseGemm,
+GroupKernelArg,
+AElementwiseOperation,
+BElementwiseOperation,
+AccElementwiseOperation,
+B1ElementwiseOperation,
+CElementwiseOperation,
+has_main_k_block_loop_,
+Deterministic>;
 return launch_and_time_kernel(
 stream_config,
@@ -1207,7 +1208,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
 auto str = std::stringstream();
 // clang-format off
-str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1"
+str << "DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1"
 << "<"
 << BlockSize << ", "
 << MPerBlock << ", "
......
@@ -40,7 +40,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
 #endif
-kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2(
+kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v2(
 const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
 const index_t group_count,
 const AElementwiseOperation a_element_op,
@@ -258,7 +258,7 @@ template <index_t NumDimG,
 MaskingSpecialization MaskingSpec,
 bool Deterministic,
 LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
+struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
 : public BaseOperator // TODO inherit atten bwd op once API stablizes
 {
 static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -270,7 +270,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
 // TODO: implement bias combination
 static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
-using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2;
+using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2;
 struct ProblemDesc
 {
 std::vector<index_t> a_gs_ms_ks_lengths;
@@ -968,16 +968,17 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
 float ave_time = 0;
 auto launch_kernel = [&](auto has_main_k_block_loop_) {
-const auto kernel = kernel_grouped_multihead_attention_backward_xdl_cshuffle_v2<
-GridwiseGemm,
-GroupKernelArg,
-AElementwiseOperation,
-BElementwiseOperation,
-AccElementwiseOperation,
-B1ElementwiseOperation,
-CElementwiseOperation,
-has_main_k_block_loop_,
-Deterministic>;
+const auto kernel =
+kernel_grouped_multihead_attention_backward_qloop_xdl_cshuffle_v2<
+GridwiseGemm,
+GroupKernelArg,
+AElementwiseOperation,
+BElementwiseOperation,
+AccElementwiseOperation,
+B1ElementwiseOperation,
+CElementwiseOperation,
+has_main_k_block_loop_,
+Deterministic>;
 return launch_and_time_kernel(
 stream_config,
@@ -1219,7 +1220,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
 auto str = std::stringstream();
 // clang-format off
-str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2"
+str << "DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2"
 << "<"
 << BlockSize << ", "
 << MPerBlock << ", "
......
@@ -39,7 +39,7 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2(
+kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1(
 const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
 const index_t group_count,
 const AElementwiseOperation a_element_op,
@@ -250,7 +250,7 @@ template <index_t NumDimG,
 MaskingSpecialization MaskingSpec,
 bool Deterministic,
 LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
+struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
 : public DeviceGroupedMultiheadAttentionForward<NumDimG,
 NumDimM,
 NumDimN,
@@ -290,7 +290,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
 static constexpr index_t NumDimGemm1K = NumDimN;
 #endif
-using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle;
+using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1;
 using ProblemDesc = typename DeviceGroupedMultiheadAttentionForward<NumDimG,
 NumDimM,
 NumDimN,
@@ -813,7 +813,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
 auto launch_kernel =
 [&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
 const auto kernel =
-kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2<GridwiseGemm,
+kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1<GridwiseGemm,
 GemmAccDataType,
 GroupKernelArg,
 AElementwiseOperation,
@@ -1123,7 +1123,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
 auto str = std::stringstream();
 // clang-format off
-str << "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle"
+str << "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1"
 << "<"
 << BlockSize << ", "
 << MPerBlock << ", "
......
@@ -256,7 +256,7 @@ template <index_t NumDimG,
 MaskingSpecialization MaskingSpec,
 bool Deterministic,
 LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
+struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
 : public DeviceGroupedMultiheadAttentionForward<NumDimG,
 NumDimM,
 NumDimN,
@@ -296,7 +296,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
 static constexpr index_t NumDimGemm1K = NumDimN;
 #endif
-using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle;
+using DeviceOp = DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2;
 using ProblemDesc = typename DeviceGroupedMultiheadAttentionForward<NumDimG,
 NumDimM,
 NumDimN,
@@ -1145,7 +1145,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
 auto str = std::stringstream();
 // clang-format off
-str << "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle"
+str << "DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2"
 << "<"
 << BlockSize << ", "
 << MPerBlock << ", "
......
@@ -2114,6 +2114,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
 sgrad_slice_idx[I3],
 sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3));
+block_sync_lds(); // sync before write
 if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
 {
 qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
@@ -2125,8 +2126,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
 gemm2_a_block_buf);
 }
-// block_sync_lds(); // sync before write
 qgrad_gemm_tile_k_blockwise_copy.Run(Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
 k_block_buf,
 Gemm2::b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
......
@@ -2044,6 +2044,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 sgrad_slice_idx[I3],
 sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3));
+block_sync_lds(); // sync before write
 if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
 {
 qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
@@ -2060,7 +2061,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_n0_k_n1,
 Gemm2::b_block_slice_copy_step);
-block_sync_lds(); // sync before write
 qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm2::b_block_desc_k0_n_k1,
 gemm2_b_block_buf);
......
@@ -2191,6 +2191,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
 sgrad_slice_idx[I3],
 sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3));
+block_sync_lds(); // sync before write
 if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
 {
 qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
@@ -2202,8 +2203,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
 gemm2_a_block_buf);
 }
-// block_sync_lds(); // sync before write
 qgrad_gemm_tile_k_blockwise_copy.Run(Gemm2::b_block_desc_n0_n1_n2_k0_k1_k2_k3,
 k_block_buf,
 Gemm2::b_thread_desc_n0_n1_n2_k0_k1_k2_k3,
......
@@ -2142,6 +2142,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 sgrad_slice_idx[I3],
 sgrad_slice_idx[I3] + Gemm2Params::ABlockSliceLengths_M0_K0_M1_K1::At(I3));
+block_sync_lds(); // sync before write
 if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
 {
 qgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
@@ -2158,7 +2159,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
 qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_n0_k_n1,
 Gemm2::b_block_slice_copy_step);
-block_sync_lds(); // sync before write
 qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm2::b_block_desc_k0_n_k1,
 gemm2_b_block_buf);
......