Commit 73f0c21b authored by danyao12

Merge branch 'attn-train-develop-qloop' into attn-train-develop-qloop-light

parents 227076ba 4d18cd84
@@ -103,7 +103,7 @@ static constexpr bool Deterministic = false;
// If 64 < DIM <= 128, uses the prototype2 2nd template.
#if(DIM <= 32)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -173,7 +173,7 @@ using DeviceGemmInstance =
Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -243,7 +243,7 @@ using DeviceGemmInstance =
Deterministic>;
// using DeviceGemmInstance =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
// NumDimG,
// NumDimM,
// NumDimN,
@@ -313,7 +313,7 @@ using DeviceGemmInstance =
// Deterministic>;
#elif(DIM <= 128)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
......
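These example files pick the device instance from the compile-time head dimension, as the comment above says: apparently the prototype 1 (V1) template for DIM <= 64 and the prototype 2 (V2) template for 64 < DIM <= 128. Below is a self-contained toy of the same #if ladder; Prototype1/Prototype2 are hypothetical stand-ins for the heavily parameterized CK ops, and DIM is assumed to be injected by the build (e.g. hipcc -DDIM=64).

#include <cstdio>

#ifndef DIM
#define DIM 64 // assumed build-time setting, e.g. hipcc -DDIM=64
#endif

// Hypothetical stand-ins for the two backward prototypes; the real CK
// instances take dozens of template parameters.
template <int MPerBlock> struct Prototype1 { static constexpr const char* name = "prototype1"; };
template <int MPerBlock> struct Prototype2 { static constexpr const char* name = "prototype2"; };

#if(DIM <= 32)
using DeviceGemmInstance = Prototype1<128>;
#elif(DIM <= 64)
using DeviceGemmInstance = Prototype1<64>;
#elif(DIM <= 128)
using DeviceGemmInstance = Prototype2<64>; // 64 < DIM <= 128: 2nd template
#endif

int main() { std::printf("%s\n", DeviceGemmInstance::name); }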
@@ -103,7 +103,7 @@ static constexpr bool Deterministic = false;
// If 64 < DIM <= 128, uses the prototype2 2nd template.
#if(DIM <= 32)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -173,7 +173,7 @@ using DeviceGemmInstance =
Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
......
@@ -79,7 +79,7 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -150,7 +150,7 @@ using DeviceGemmInstance =
Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -221,7 +221,7 @@ using DeviceGemmInstance =
Deterministic>;
#elif(DIM <= 128)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
......
@@ -79,7 +79,7 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
@@ -150,7 +150,7 @@ using DeviceGemmInstance =
Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
@@ -221,7 +221,7 @@ using DeviceGemmInstance =
Deterministic>;
#elif(DIM <= 128)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
......
@@ -112,7 +112,7 @@ static constexpr bool Deterministic = false;
// If 64 < DIM <= 128, uses the prototype2 2nd template.
#if(DIM <= 32)
using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -183,7 +183,7 @@ using DeviceGemmInstanceFWD =
Deterministic>;
using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -253,7 +253,7 @@ using DeviceGemmInstanceBWD =
Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -324,7 +324,7 @@ using DeviceGemmInstanceFWD =
Deterministic>;
using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -394,7 +394,7 @@ using DeviceGemmInstanceBWD =
Deterministic>;
// using DeviceGemmInstanceBWD =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
// NumDimG,
// NumDimM,
// NumDimN,
@@ -464,7 +464,7 @@ using DeviceGemmInstanceBWD =
// Deterministic>;
#elif(DIM <= 128)
using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -535,7 +535,7 @@ using DeviceGemmInstanceFWD =
Deterministic>;
using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
......
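The fused example above aliases both a forward instance and a backward instance in one translation unit, which is presumably why this commit gives each backward variant a distinct class name: before it, the Kloop, Qloop, and Qloop_Phased headers further down all defined a class literally named DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1, so a file could not include more than one of them without the definitions colliding. A toy illustration with hypothetical stand-ins (reading the suffixes as outer-loop variants is an inference from the names):

// Hypothetical stand-ins, not the real CK classes. Distinct names let one
// translation unit refer to several backward variants side by side.
struct DeviceMhaBwd_Kloop_Xdl_CShuffle_V1 { };        // presumably loops over K blocks
struct DeviceMhaBwd_Qloop_Xdl_CShuffle_V1 { };        // presumably loops over Q blocks
struct DeviceMhaBwd_Qloop_Phased_Xdl_CShuffle_V1 { }; // phased Q-loop variant

using DeviceGemmInstanceBWD = DeviceMhaBwd_Kloop_Xdl_CShuffle_V1; // pick one variant

int main() {}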
@@ -102,7 +102,7 @@ static constexpr bool Deterministic = false;
// If 64 < DIM <= 128, uses the prototype2 2nd template.
#if(DIM <= 32)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1<
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -172,7 +172,7 @@ using DeviceGemmInstance =
Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1<
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -242,7 +242,7 @@ using DeviceGemmInstance =
Deterministic>;
// using DeviceGemmInstance =
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
// NumDimG,
// NumDimM,
// NumDimN,
@@ -312,7 +312,7 @@ using DeviceGemmInstance =
// Deterministic>;
#elif(DIM <= 128)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
......
@@ -79,7 +79,7 @@ static constexpr bool Deterministic = true;
#if(DIM <= 32)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -150,7 +150,7 @@ using DeviceGemmInstance =
Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -221,7 +221,7 @@ using DeviceGemmInstance =
Deterministic>;
#elif(DIM <= 128)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
......
@@ -79,7 +79,7 @@ static constexpr bool Deterministic = true;
#if(DIM <= 32)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
@@ -150,7 +150,7 @@ using DeviceGemmInstance =
Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
@@ -221,7 +221,7 @@ using DeviceGemmInstance =
Deterministic>;
#elif(DIM <= 128)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
......
@@ -111,7 +111,7 @@ static constexpr bool Deterministic = true;
// If 64 < DIM <= 128, uses the prototype2 2nd template.
#if(DIM <= 32)
using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -182,7 +182,7 @@ using DeviceGemmInstanceFWD =
Deterministic>;
using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1<
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -252,7 +252,7 @@ using DeviceGemmInstanceBWD =
Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -323,7 +323,7 @@ using DeviceGemmInstanceFWD =
Deterministic>;
using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1<
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -393,7 +393,7 @@ using DeviceGemmInstanceBWD =
Deterministic>;
// using DeviceGemmInstanceBWD =
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
// NumDimG,
// NumDimM,
// NumDimN,
@@ -463,7 +463,7 @@ using DeviceGemmInstanceBWD =
// Deterministic>;
#elif(DIM <= 128)
using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -534,7 +534,7 @@ using DeviceGemmInstanceFWD =
Deterministic>;
using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
......
@@ -54,7 +54,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1(
kernel_batched_multihead_attention_backward_kloop_xdl_cshuffle_v1(
const InputDataType* __restrict__ p_a_grid,
const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid,
@@ -277,7 +277,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
: public BaseOperator // TODO inherit atten bwd op once API stabilizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -299,7 +299,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1;
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -949,7 +949,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
const auto kernel =
kernel_batched_multihead_attention_backward_kloop_xdl_cshuffle_v1<
GridwiseGemm,
InputDataType,
OutputDataType,
@@ -1284,7 +1285,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1"
str << "DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
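In the Run path above, the runtime has_main_k_block_loop flag is lifted into a template parameter through a generic lambda before the (now renamed) __global__ kernel template is instantiated. A minimal sketch of that dispatch idiom with toy types; run_kernel is a hypothetical stand-in for the real kernel selection:

#include <cstdio>
#include <type_traits>

// Toy stand-in for the kernel template chosen inside launch_kernel.
template <bool HasMainKBlockLoop>
void run_kernel() { std::printf("has_main_k_block_loop = %d\n", HasMainKBlockLoop); }

float Run(bool has_main_k_block_loop)
{
    float ave_time = 0;
    auto launch_kernel = [&](auto has_main_k_block_loop_) {
        // has_main_k_block_loop_ is a std::integral_constant, so its ::value
        // is a compile-time constant usable as a template argument.
        run_kernel<decltype(has_main_k_block_loop_)::value>();
        return ave_time; // stands in for the measured kernel time
    };
    return has_main_k_block_loop ? launch_kernel(std::integral_constant<bool, true>{})
                                 : launch_kernel(std::integral_constant<bool, false>{});
}

int main() { Run(true); Run(false); }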
@@ -53,7 +53,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v2(
kernel_batched_multihead_attention_backward_kloop_xdl_cshuffle_v2(
const InputDataType* __restrict__ p_a_grid,
const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid,
@@ -276,7 +276,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
: public BaseOperator // TODO inherit atten bwd op once API stabilizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -298,7 +298,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2;
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -951,7 +951,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2<
const auto kernel =
kernel_batched_multihead_attention_backward_kloop_xdl_cshuffle_v2<
GridwiseGemm,
InputDataType,
OutputDataType,
@@ -1284,7 +1285,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2"
str << "DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
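Each kernel above is declared with a guarded __launch_bounds__, which caps the threads per block so the compiler can budget registers for the target occupancy. A minimal HIP sketch of the same guard; the CK_* macro values are stubbed here, while the real ones come from CK's configuration headers:

#include <hip/hip_runtime.h>

#define CK_USE_LAUNCH_BOUNDS 1      // stub; CK's config normally defines this
#define CK_MAX_THREAD_PER_BLOCK 256 // stub value for illustration

__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
toy_kernel(float* __restrict__ p_out)
{
    // The bound promises blockDim.x never exceeds CK_MAX_THREAD_PER_BLOCK,
    // letting the compiler allocate registers less conservatively.
    p_out[blockIdx.x * blockDim.x + threadIdx.x] = 1.f;
}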
@@ -53,7 +53,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1(
kernel_batched_multihead_attention_backward_qloop_phased_xdl_cshuffle_v1(
const InputDataType* __restrict__ p_a_grid,
const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid,
@@ -273,7 +273,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1
: public BaseOperator // TODO inherit atten bwd op once API stabilizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -295,7 +295,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1;
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -945,7 +945,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
const auto kernel =
kernel_batched_multihead_attention_backward_qloop_phased_xdl_cshuffle_v1<
GridwiseGemm,
InputDataType,
OutputDataType,
@@ -1277,7 +1278,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1"
str << "DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
@@ -53,7 +53,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1(
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v1(
const InputDataType* __restrict__ p_a_grid,
const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid,
@@ -273,7 +273,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
: public BaseOperator // TODO inherit atten bwd op once API stabilizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -285,7 +285,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented");
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1;
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -933,7 +933,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
const auto kernel =
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v1<
GridwiseGemm,
InputDataType,
OutputDataType,
@@ -1248,7 +1249,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1"
str << "DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
@@ -52,7 +52,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v2(
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v2(
const InputDataType* __restrict__ p_a_grid,
const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid,
@@ -279,7 +279,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
: public BaseOperator // TODO inherit atten bwd op once API stabilizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -291,7 +291,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc1Bias == 0, "Bias addition is unimplemented");
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2;
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -950,7 +950,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2<
const auto kernel =
kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v2<
GridwiseGemm,
InputDataType,
OutputDataType,
@@ -1279,7 +1280,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2"
str << "DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
@@ -51,7 +51,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_multiheadattention_forward_xdl_cshuffle(
kernel_batched_multiheadattention_forward_xdl_cshuffle_v1(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid,
@@ -255,7 +255,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
: public DeviceBatchedMultiheadAttentionForward<NumDimG,
NumDimM,
NumDimN,
@@ -295,7 +295,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle;
using DeviceOp = DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -746,7 +746,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
auto launch_kernel =
[&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle<
const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
@@ -1116,7 +1116,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle"
str << "DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
@@ -51,7 +51,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_multiheadattention_forward_xdl_cshuffle(
kernel_batched_multiheadattention_forward_xdl_cshuffle_v2(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid,
@@ -263,7 +263,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
: public DeviceBatchedMultiheadAttentionForward<NumDimG,
NumDimM,
NumDimN,
@@ -303,7 +303,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle;
using DeviceOp = DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -761,7 +761,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
auto launch_kernel =
[&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle<
const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle_v2<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
CDataType,
@@ -1133,7 +1133,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle"
str << "DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
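Each device op also streams a human-readable description of itself, and every rename in this commit updates the corresponding string literal so the reported name stays in sync with the class. A toy version of that string-builder pattern (DeviceToyOp and its GetTypeString are illustrative names; the real bodies print many more tuning parameters):

#include <sstream>
#include <string>

template <int BlockSize, int MPerBlock>
struct DeviceToyOp
{
    static std::string GetTypeString()
    {
        auto str = std::stringstream();
        // The literal must track the class name by hand, which is why each
        // rename in this commit also touches the stream statement.
        // clang-format off
        str << "DeviceToyOp"
            << "<"
            << BlockSize << ", "
            << MPerBlock
            << ">";
        // clang-format on
        return str.str();
    }
};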