Commit 73f0c21b authored by danyao12's avatar danyao12
Browse files

Merge branch 'attn-train-develop-qloop' into attn-train-develop-qloop-light

parents 227076ba 4d18cd84
...@@ -103,7 +103,7 @@ static constexpr bool Deterministic = false; ...@@ -103,7 +103,7 @@ static constexpr bool Deterministic = false;
// If 64 < DIM <= 128, ues prototype2 2nd template. // If 64 < DIM <= 128, ues prototype2 2nd template.
#if(DIM <= 32) #if(DIM <= 32)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -173,7 +173,7 @@ using DeviceGemmInstance = ...@@ -173,7 +173,7 @@ using DeviceGemmInstance =
Deterministic>; Deterministic>;
#elif(DIM <= 64) #elif(DIM <= 64)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -243,7 +243,7 @@ using DeviceGemmInstance = ...@@ -243,7 +243,7 @@ using DeviceGemmInstance =
Deterministic>; Deterministic>;
// using DeviceGemmInstance = // using DeviceGemmInstance =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< // ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
// NumDimG, // NumDimG,
// NumDimM, // NumDimM,
// NumDimN, // NumDimN,
...@@ -313,7 +313,7 @@ using DeviceGemmInstance = ...@@ -313,7 +313,7 @@ using DeviceGemmInstance =
// Deterministic>; // Deterministic>;
#elif(DIM <= 128) #elif(DIM <= 128)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
......
...@@ -103,7 +103,7 @@ static constexpr bool Deterministic = false; ...@@ -103,7 +103,7 @@ static constexpr bool Deterministic = false;
// If 64 < DIM <= 128, ues prototype2 2nd template. // If 64 < DIM <= 128, ues prototype2 2nd template.
#if(DIM <= 32) #if(DIM <= 32)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -173,7 +173,7 @@ using DeviceGemmInstance = ...@@ -173,7 +173,7 @@ using DeviceGemmInstance =
Deterministic>; Deterministic>;
#elif(DIM <= 64) #elif(DIM <= 64)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
......
...@@ -79,7 +79,7 @@ static constexpr bool Deterministic = false; ...@@ -79,7 +79,7 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32) #if(DIM <= 32)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -150,7 +150,7 @@ using DeviceGemmInstance = ...@@ -150,7 +150,7 @@ using DeviceGemmInstance =
Deterministic>; Deterministic>;
#elif(DIM <= 64) #elif(DIM <= 64)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -221,7 +221,7 @@ using DeviceGemmInstance = ...@@ -221,7 +221,7 @@ using DeviceGemmInstance =
Deterministic>; Deterministic>;
#elif(DIM <= 128) #elif(DIM <= 128)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
......
...@@ -79,7 +79,7 @@ static constexpr bool Deterministic = false; ...@@ -79,7 +79,7 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32) #if(DIM <= 32)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -150,7 +150,7 @@ using DeviceGemmInstance = ...@@ -150,7 +150,7 @@ using DeviceGemmInstance =
Deterministic>; Deterministic>;
#elif(DIM <= 64) #elif(DIM <= 64)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -221,7 +221,7 @@ using DeviceGemmInstance = ...@@ -221,7 +221,7 @@ using DeviceGemmInstance =
Deterministic>; Deterministic>;
#elif(DIM <= 128) #elif(DIM <= 128)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
......
...@@ -112,7 +112,7 @@ static constexpr bool Deterministic = false; ...@@ -112,7 +112,7 @@ static constexpr bool Deterministic = false;
// If 64 < DIM <= 128, ues prototype2 2nd template. // If 64 < DIM <= 128, ues prototype2 2nd template.
#if(DIM <= 32) #if(DIM <= 32)
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -183,7 +183,7 @@ using DeviceGemmInstanceFWD = ...@@ -183,7 +183,7 @@ using DeviceGemmInstanceFWD =
Deterministic>; Deterministic>;
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -253,7 +253,7 @@ using DeviceGemmInstanceBWD = ...@@ -253,7 +253,7 @@ using DeviceGemmInstanceBWD =
Deterministic>; Deterministic>;
#elif(DIM <= 64) #elif(DIM <= 64)
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -324,7 +324,7 @@ using DeviceGemmInstanceFWD = ...@@ -324,7 +324,7 @@ using DeviceGemmInstanceFWD =
Deterministic>; Deterministic>;
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -394,7 +394,7 @@ using DeviceGemmInstanceBWD = ...@@ -394,7 +394,7 @@ using DeviceGemmInstanceBWD =
Deterministic>; Deterministic>;
// using DeviceGemmInstanceBWD = // using DeviceGemmInstanceBWD =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< // ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
// NumDimG, // NumDimG,
// NumDimM, // NumDimM,
// NumDimN, // NumDimN,
...@@ -464,7 +464,7 @@ using DeviceGemmInstanceBWD = ...@@ -464,7 +464,7 @@ using DeviceGemmInstanceBWD =
// Deterministic>; // Deterministic>;
#elif(DIM <= 128) #elif(DIM <= 128)
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -535,7 +535,7 @@ using DeviceGemmInstanceFWD = ...@@ -535,7 +535,7 @@ using DeviceGemmInstanceFWD =
Deterministic>; Deterministic>;
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
......
...@@ -102,7 +102,7 @@ static constexpr bool Deterministic = false; ...@@ -102,7 +102,7 @@ static constexpr bool Deterministic = false;
// If 64 < DIM <= 128, ues prototype2 2nd template. // If 64 < DIM <= 128, ues prototype2 2nd template.
#if(DIM <= 32) #if(DIM <= 32)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1< ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -172,7 +172,7 @@ using DeviceGemmInstance = ...@@ -172,7 +172,7 @@ using DeviceGemmInstance =
Deterministic>; Deterministic>;
#elif(DIM <= 64) #elif(DIM <= 64)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1< ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -242,7 +242,7 @@ using DeviceGemmInstance = ...@@ -242,7 +242,7 @@ using DeviceGemmInstance =
Deterministic>; Deterministic>;
// using DeviceGemmInstance = // using DeviceGemmInstance =
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2< // ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
// NumDimG, // NumDimG,
// NumDimM, // NumDimM,
// NumDimN, // NumDimN,
...@@ -312,7 +312,7 @@ using DeviceGemmInstance = ...@@ -312,7 +312,7 @@ using DeviceGemmInstance =
// Deterministic>; // Deterministic>;
#elif(DIM <= 128) #elif(DIM <= 128)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2< ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
......
...@@ -79,7 +79,7 @@ static constexpr bool Deterministic = true; ...@@ -79,7 +79,7 @@ static constexpr bool Deterministic = true;
#if(DIM <= 32) #if(DIM <= 32)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -150,7 +150,7 @@ using DeviceGemmInstance = ...@@ -150,7 +150,7 @@ using DeviceGemmInstance =
Deterministic>; Deterministic>;
#elif(DIM <= 64) #elif(DIM <= 64)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -221,7 +221,7 @@ using DeviceGemmInstance = ...@@ -221,7 +221,7 @@ using DeviceGemmInstance =
Deterministic>; Deterministic>;
#elif(DIM <= 128) #elif(DIM <= 128)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
......
...@@ -79,7 +79,7 @@ static constexpr bool Deterministic = true; ...@@ -79,7 +79,7 @@ static constexpr bool Deterministic = true;
#if(DIM <= 32) #if(DIM <= 32)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -150,7 +150,7 @@ using DeviceGemmInstance = ...@@ -150,7 +150,7 @@ using DeviceGemmInstance =
Deterministic>; Deterministic>;
#elif(DIM <= 64) #elif(DIM <= 64)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -221,7 +221,7 @@ using DeviceGemmInstance = ...@@ -221,7 +221,7 @@ using DeviceGemmInstance =
Deterministic>; Deterministic>;
#elif(DIM <= 128) #elif(DIM <= 128)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
......
...@@ -111,7 +111,7 @@ static constexpr bool Deterministic = true; ...@@ -111,7 +111,7 @@ static constexpr bool Deterministic = true;
// If 64 < DIM <= 128, ues prototype2 2nd template. // If 64 < DIM <= 128, ues prototype2 2nd template.
#if(DIM <= 32) #if(DIM <= 32)
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -182,7 +182,7 @@ using DeviceGemmInstanceFWD = ...@@ -182,7 +182,7 @@ using DeviceGemmInstanceFWD =
Deterministic>; Deterministic>;
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1< ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -252,7 +252,7 @@ using DeviceGemmInstanceBWD = ...@@ -252,7 +252,7 @@ using DeviceGemmInstanceBWD =
Deterministic>; Deterministic>;
#elif(DIM <= 64) #elif(DIM <= 64)
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -323,7 +323,7 @@ using DeviceGemmInstanceFWD = ...@@ -323,7 +323,7 @@ using DeviceGemmInstanceFWD =
Deterministic>; Deterministic>;
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1< ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -393,7 +393,7 @@ using DeviceGemmInstanceBWD = ...@@ -393,7 +393,7 @@ using DeviceGemmInstanceBWD =
Deterministic>; Deterministic>;
// using DeviceGemmInstanceBWD = // using DeviceGemmInstanceBWD =
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2< // ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
// NumDimG, // NumDimG,
// NumDimM, // NumDimM,
// NumDimN, // NumDimN,
...@@ -463,7 +463,7 @@ using DeviceGemmInstanceBWD = ...@@ -463,7 +463,7 @@ using DeviceGemmInstanceBWD =
// Deterministic>; // Deterministic>;
#elif(DIM <= 128) #elif(DIM <= 128)
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -534,7 +534,7 @@ using DeviceGemmInstanceFWD = ...@@ -534,7 +534,7 @@ using DeviceGemmInstanceFWD =
Deterministic>; Deterministic>;
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2< ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
......
...@@ -54,7 +54,7 @@ __global__ void ...@@ -54,7 +54,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif #endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1( kernel_batched_multihead_attention_backward_kloop_xdl_cshuffle_v1(
const InputDataType* __restrict__ p_a_grid, const InputDataType* __restrict__ p_a_grid,
const InputDataType* __restrict__ p_b_grid, const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
...@@ -277,7 +277,7 @@ template <index_t NumDimG, ...@@ -277,7 +277,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
bool Deterministic, bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
: public BaseOperator // TODO inherit atten bwd op once API stablizes : public BaseOperator // TODO inherit atten bwd op once API stablizes
{ {
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
...@@ -299,7 +299,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -299,7 +299,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
static constexpr index_t NumDimGemm1K = NumDimN; static constexpr index_t NumDimGemm1K = NumDimN;
#endif #endif
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1; using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -949,30 +949,31 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -949,30 +949,31 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1< const auto kernel =
GridwiseGemm, kernel_batched_multihead_attention_backward_kloop_xdl_cshuffle_v1<
InputDataType, GridwiseGemm,
OutputDataType, InputDataType,
ZDataType, OutputDataType,
LSEDataType, ZDataType,
AElementwiseOperation, LSEDataType,
BElementwiseOperation, AElementwiseOperation,
AccElementwiseOperation, BElementwiseOperation,
B1ElementwiseOperation, AccElementwiseOperation,
CElementwiseOperation, B1ElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, CElementwiseOperation,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::AGridDesc_AK0_M_AK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::B1GridDesc_BK0_N_BK1, typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, DeviceOp::B1GridDesc_BK0_N_BK1,
DeviceOp::LSEGridDesc_M, typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::VGradGridDesc_N_O, DeviceOp::LSEGridDesc_M,
DeviceOp::YGradGridDesc_O0_M_O1, DeviceOp::VGradGridDesc_N_O,
typename GridwiseGemm::DefaultBlock2CTileMap, DeviceOp::YGradGridDesc_O0_M_O1,
ComputeBasePtrOfStridedBatch, typename GridwiseGemm::DefaultBlock2CTileMap,
C0MatrixMask, ComputeBasePtrOfStridedBatch,
has_main_k_block_loop_, C0MatrixMask,
Deterministic>; has_main_k_block_loop_,
Deterministic>;
return launch_and_time_kernel( return launch_and_time_kernel(
stream_config, stream_config,
...@@ -1284,7 +1285,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1284,7 +1285,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1" str << "DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
...@@ -53,7 +53,7 @@ __global__ void ...@@ -53,7 +53,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif #endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v2( kernel_batched_multihead_attention_backward_kloop_xdl_cshuffle_v2(
const InputDataType* __restrict__ p_a_grid, const InputDataType* __restrict__ p_a_grid,
const InputDataType* __restrict__ p_b_grid, const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
...@@ -276,7 +276,7 @@ template <index_t NumDimG, ...@@ -276,7 +276,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
bool Deterministic, bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
: public BaseOperator // TODO inherit atten bwd op once API stablizes : public BaseOperator // TODO inherit atten bwd op once API stablizes
{ {
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
...@@ -298,7 +298,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -298,7 +298,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
static constexpr index_t NumDimGemm1K = NumDimN; static constexpr index_t NumDimGemm1K = NumDimN;
#endif #endif
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2; using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -951,30 +951,31 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -951,30 +951,31 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2< const auto kernel =
GridwiseGemm, kernel_batched_multihead_attention_backward_kloop_xdl_cshuffle_v2<
InputDataType, GridwiseGemm,
OutputDataType, InputDataType,
ZDataType, OutputDataType,
LSEDataType, ZDataType,
AElementwiseOperation, LSEDataType,
BElementwiseOperation, AElementwiseOperation,
AccElementwiseOperation, BElementwiseOperation,
B1ElementwiseOperation, AccElementwiseOperation,
CElementwiseOperation, B1ElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, CElementwiseOperation,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::AGridDesc_AK0_M_AK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::B1GridDesc_BK0_N_BK1, typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, DeviceOp::B1GridDesc_BK0_N_BK1,
DeviceOp::LSEGridDesc_M, typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::VGradGridDesc_N_O, DeviceOp::LSEGridDesc_M,
DeviceOp::YGradGridDesc_M0_O_M1, DeviceOp::VGradGridDesc_N_O,
typename GridwiseGemm::DefaultBlock2CTileMap, DeviceOp::YGradGridDesc_M0_O_M1,
ComputeBasePtrOfStridedBatch, typename GridwiseGemm::DefaultBlock2CTileMap,
C0MatrixMask, ComputeBasePtrOfStridedBatch,
has_main_k_block_loop_, C0MatrixMask,
Deterministic>; has_main_k_block_loop_,
Deterministic>;
return launch_and_time_kernel( return launch_and_time_kernel(
stream_config, stream_config,
...@@ -1284,7 +1285,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1284,7 +1285,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2" str << "DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
...@@ -53,7 +53,7 @@ __global__ void ...@@ -53,7 +53,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif #endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1( kernel_batched_multihead_attention_backward_qloop_phased_xdl_cshuffle_v1(
const InputDataType* __restrict__ p_a_grid, const InputDataType* __restrict__ p_a_grid,
const InputDataType* __restrict__ p_b_grid, const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
...@@ -273,7 +273,7 @@ template <index_t NumDimG, ...@@ -273,7 +273,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
bool Deterministic, bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 struct DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1
: public BaseOperator // TODO inherit atten bwd op once API stablizes : public BaseOperator // TODO inherit atten bwd op once API stablizes
{ {
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
...@@ -295,7 +295,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -295,7 +295,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
static constexpr index_t NumDimGemm1K = NumDimN; static constexpr index_t NumDimGemm1K = NumDimN;
#endif #endif
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1; using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -945,29 +945,30 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -945,29 +945,30 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1< const auto kernel =
GridwiseGemm, kernel_batched_multihead_attention_backward_qloop_phased_xdl_cshuffle_v1<
InputDataType, GridwiseGemm,
OutputDataType, InputDataType,
ZDataType, OutputDataType,
LSEDataType, ZDataType,
AElementwiseOperation, LSEDataType,
BElementwiseOperation, AElementwiseOperation,
AccElementwiseOperation, BElementwiseOperation,
B1ElementwiseOperation, AccElementwiseOperation,
CElementwiseOperation, B1ElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, CElementwiseOperation,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::AGridDesc_AK0_M_AK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::B1GridDesc_BK0_N_BK1, typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, DeviceOp::B1GridDesc_BK0_N_BK1,
DeviceOp::LSEGridDesc_M, typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::YGradGridDesc_O0_M_O1, DeviceOp::LSEGridDesc_M,
typename GridwiseGemm::DefaultBlock2CTileMap, DeviceOp::YGradGridDesc_O0_M_O1,
ComputeBasePtrOfStridedBatch, typename GridwiseGemm::DefaultBlock2CTileMap,
C0MatrixMask, ComputeBasePtrOfStridedBatch,
has_main_k_block_loop_, C0MatrixMask,
Deterministic>; has_main_k_block_loop_,
Deterministic>;
return launch_and_time_kernel( return launch_and_time_kernel(
stream_config, stream_config,
...@@ -1277,7 +1278,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1277,7 +1278,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1" str << "DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
...@@ -53,7 +53,7 @@ __global__ void ...@@ -53,7 +53,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif #endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1( kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v1(
const InputDataType* __restrict__ p_a_grid, const InputDataType* __restrict__ p_a_grid,
const InputDataType* __restrict__ p_b_grid, const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
...@@ -273,7 +273,7 @@ template <index_t NumDimG, ...@@ -273,7 +273,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
bool Deterministic, bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
: public BaseOperator // TODO inherit atten bwd op once API stablizes : public BaseOperator // TODO inherit atten bwd op once API stablizes
{ {
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
...@@ -285,7 +285,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -285,7 +285,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// TODO: implement bias combination // TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1; using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -933,29 +933,30 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -933,29 +933,30 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1< const auto kernel =
GridwiseGemm, kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v1<
InputDataType, GridwiseGemm,
OutputDataType, InputDataType,
ZDataType, OutputDataType,
LSEDataType, ZDataType,
AElementwiseOperation, LSEDataType,
BElementwiseOperation, AElementwiseOperation,
AccElementwiseOperation, BElementwiseOperation,
B1ElementwiseOperation, AccElementwiseOperation,
CElementwiseOperation, B1ElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, CElementwiseOperation,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::AGridDesc_AK0_M_AK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3, DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::B1GridDesc_BK0_N_BK1, typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, DeviceOp::B1GridDesc_BK0_N_BK1,
DeviceOp::LSEGridDesc_M, typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::YGradGridDesc_O0_M_O1, DeviceOp::LSEGridDesc_M,
typename GridwiseGemm::DefaultBlock2CTileMap, DeviceOp::YGradGridDesc_O0_M_O1,
ComputeBasePtrOfStridedBatch, typename GridwiseGemm::DefaultBlock2CTileMap,
C0MatrixMask, ComputeBasePtrOfStridedBatch,
has_main_k_block_loop_, C0MatrixMask,
Deterministic>; has_main_k_block_loop_,
Deterministic>;
return launch_and_time_kernel( return launch_and_time_kernel(
stream_config, stream_config,
...@@ -1248,7 +1249,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1248,7 +1249,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1" str << "DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
...@@ -52,7 +52,7 @@ __global__ void ...@@ -52,7 +52,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif #endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v2( kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v2(
const InputDataType* __restrict__ p_a_grid, const InputDataType* __restrict__ p_a_grid,
const InputDataType* __restrict__ p_b_grid, const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
...@@ -279,7 +279,7 @@ template <index_t NumDimG, ...@@ -279,7 +279,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
bool Deterministic, bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
: public BaseOperator // TODO inherit atten bwd op once API stablizes : public BaseOperator // TODO inherit atten bwd op once API stablizes
{ {
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
...@@ -291,7 +291,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -291,7 +291,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// TODO: implement bias combination // TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2; using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -950,29 +950,30 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -950,29 +950,30 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2< const auto kernel =
GridwiseGemm, kernel_batched_multihead_attention_backward_qloop_xdl_cshuffle_v2<
InputDataType, GridwiseGemm,
OutputDataType, InputDataType,
ZDataType, OutputDataType,
LSEDataType, ZDataType,
AElementwiseOperation, LSEDataType,
BElementwiseOperation, AElementwiseOperation,
AccElementwiseOperation, BElementwiseOperation,
B1ElementwiseOperation, AccElementwiseOperation,
CElementwiseOperation, B1ElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, CElementwiseOperation,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::AGridDesc_AK0_M_AK1,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3, DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::B1GridDesc_BK0_N_BK1, typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, DeviceOp::B1GridDesc_BK0_N_BK1,
DeviceOp::LSEGridDesc_M, typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
DeviceOp::YGradGridDesc_M0_O_M1, DeviceOp::LSEGridDesc_M,
typename GridwiseGemm::DefaultBlock2CTileMap, DeviceOp::YGradGridDesc_M0_O_M1,
ComputeBasePtrOfStridedBatch, typename GridwiseGemm::DefaultBlock2CTileMap,
C0MatrixMask, ComputeBasePtrOfStridedBatch,
has_main_k_block_loop_, C0MatrixMask,
Deterministic>; has_main_k_block_loop_,
Deterministic>;
return launch_and_time_kernel( return launch_and_time_kernel(
stream_config, stream_config,
...@@ -1279,7 +1280,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1279,7 +1280,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2" str << "DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
...@@ -51,7 +51,7 @@ __global__ void ...@@ -51,7 +51,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_batched_multiheadattention_forward_xdl_cshuffle( kernel_batched_multiheadattention_forward_xdl_cshuffle_v1(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid, const FloatAB* __restrict__ p_b1_grid,
...@@ -255,7 +255,7 @@ template <index_t NumDimG, ...@@ -255,7 +255,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
bool Deterministic, bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
: public DeviceBatchedMultiheadAttentionForward<NumDimG, : public DeviceBatchedMultiheadAttentionForward<NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -295,7 +295,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -295,7 +295,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
static constexpr index_t NumDimGemm1K = NumDimN; static constexpr index_t NumDimGemm1K = NumDimN;
#endif #endif
using DeviceOp = DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle; using DeviceOp = DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -746,7 +746,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -746,7 +746,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
auto launch_kernel = auto launch_kernel =
[&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) { [&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle< const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle_v1<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
...@@ -1116,7 +1116,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1116,7 +1116,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle" str << "DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V1"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
...@@ -51,7 +51,7 @@ __global__ void ...@@ -51,7 +51,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_batched_multiheadattention_forward_xdl_cshuffle( kernel_batched_multiheadattention_forward_xdl_cshuffle_v2(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid, const FloatAB* __restrict__ p_b1_grid,
...@@ -263,7 +263,7 @@ template <index_t NumDimG, ...@@ -263,7 +263,7 @@ template <index_t NumDimG,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
bool Deterministic, bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
: public DeviceBatchedMultiheadAttentionForward<NumDimG, : public DeviceBatchedMultiheadAttentionForward<NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -303,7 +303,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -303,7 +303,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
static constexpr index_t NumDimGemm1K = NumDimN; static constexpr index_t NumDimGemm1K = NumDimN;
#endif #endif
using DeviceOp = DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle; using DeviceOp = DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -761,7 +761,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -761,7 +761,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
auto launch_kernel = auto launch_kernel =
[&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) { [&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle< const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle_v2<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
...@@ -1133,7 +1133,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -1133,7 +1133,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle" str << "DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment