Commit 99ebfeba authored by danyao12

correct deterministic mode

parent 84a81ae2
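
This change keeps the deterministic launch shape (in deterministic mode only one workgroup is launched per batch, and that workgroup walks the output M-tiles serially) but corrects how the serial loop is driven: the per-batch tile count is passed into the batched kernels as a new mblock argument, the per-iteration guard on get_block_1d_id() is removed, and the current tile index i is forwarded to GridwiseGemm::Run so the gridwise GEMM picks its tile from the loop instead of deriving it from the block id. The sketch below is an editor's illustration of that scheduling idea only; the helper names (compute_grid_size, run_one_batch, TileBody) are hypothetical and are not part of this commit.

// Editor's sketch (hypothetical names, C++17): the two launch/loop shapes selected by
// the Deterministic flag. One workgroup per batch that visits every tile in a fixed
// order keeps order-dependent accumulation the same on every run.
#include <cstdint>
using index_t = int32_t;

template <bool Deterministic>
index_t compute_grid_size(index_t tiles_per_batch, index_t batch_count)
{
    // mirrors the pattern used in the diff:
    // grid_size = (Deterministic ? 1 : CalculateGridSize(...)) * batch_count
    return (Deterministic ? 1 : tiles_per_batch) * batch_count;
}

template <bool Deterministic, typename TileBody>
void run_one_batch(index_t tiles_per_batch, index_t my_tile, TileBody&& body)
{
    if constexpr(Deterministic)
    {
        // one block per batch: visit every tile serially, forwarding the tile index
        // (the new trailing argument of GridwiseGemm::Run in the real kernels)
        for(index_t i = 0; i < tiles_per_batch; ++i)
            body(i);
    }
    else
    {
        // one block per tile: handle only the tile this block was mapped to
        body(my_tile);
    }
}
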
@@ -71,10 +71,11 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
static constexpr auto MaskingSpec =
    ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr bool Deterministic = true;
#if(DIM <= 32)
using DeviceGemmInstance =
@@ -145,7 +146,8 @@ using DeviceGemmInstance =
        1,              // CShuffleNXdlPerWavePerShuffle
        S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-       MaskingSpec>;   // MaskingSpecialization
+       MaskingSpec,    // MaskingSpecialization
+       Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstance =
    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
@@ -215,7 +217,8 @@ using DeviceGemmInstance =
        2,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-       MaskingSpec>;   // MaskingSpecialization
+       MaskingSpec,    // MaskingSpecialization
+       Deterministic>;
#elif(DIM <= 128)
using DeviceGemmInstance =
    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
@@ -285,7 +288,8 @@ using DeviceGemmInstance =
        2,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-       MaskingSpec>;   // MaskingSpecialization
+       MaskingSpec,    // MaskingSpecialization
+       Deterministic>;
#endif
// Ref Gemm0: DataType in, AccDataType out
...
@@ -100,10 +100,11 @@ static constexpr auto MaskingSpec =
    ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
#endif
static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr bool Deterministic = true;
// DIM should be a multiple of 8.
// If DIM <= 32 , ues prototype1 1st template.
@@ -178,7 +179,8 @@ using DeviceGemmInstanceFWD =
        1,              // CShuffleNXdlPerWavePerShuffle
        S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-       MaskingSpec>;   // MaskingSpecialization
+       MaskingSpec,    // MaskingSpecialization
+       Deterministic>;
using DeviceGemmInstanceBWD =
    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
@@ -247,7 +249,8 @@ using DeviceGemmInstanceBWD =
        1,              // CShuffleNXdlPerWavePerShuffle
        S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-       MaskingSpec>;   // MaskingSpecialization
+       MaskingSpec,    // MaskingSpecialization
+       Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstanceFWD =
    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
@@ -317,7 +320,8 @@ using DeviceGemmInstanceFWD =
        2,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-       MaskingSpec>;   // MaskingSpecialization
+       MaskingSpec,    // MaskingSpecialization
+       Deterministic>;
using DeviceGemmInstanceBWD =
    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
@@ -386,7 +390,8 @@ using DeviceGemmInstanceBWD =
        2,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-       MaskingSpec>;   // MaskingSpecialization
+       MaskingSpec,    // MaskingSpecialization
+       Deterministic>;
// using DeviceGemmInstanceBWD =
//     ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -455,7 +460,8 @@ using DeviceGemmInstanceBWD =
//         2,              // CShuffleNXdlPerWavePerShuffle
//         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
//         CShuffleBlockTransferScalarPerVector_NPerBlock,
-//         MaskingSpec>;
+//         MaskingSpec,
+//         Deterministic>;
#elif(DIM <= 128)
using DeviceGemmInstanceFWD =
    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
@@ -525,7 +531,8 @@ using DeviceGemmInstanceFWD =
        2,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-       MaskingSpec>;   // MaskingSpecialization
+       MaskingSpec,    // MaskingSpecialization
+       Deterministic>;
using DeviceGemmInstanceBWD =
    ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -594,7 +601,8 @@ using DeviceGemmInstanceBWD =
        4,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-       MaskingSpec>;   // MaskingSpecialization
+       MaskingSpec,    // MaskingSpecialization
+       Deterministic>;
#endif
// Ref Gemm0: S = alpha * Q * K^T
...
@@ -71,10 +71,11 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio
static constexpr auto MaskingSpec =
    ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr bool Deterministic = true;
#if(DIM <= 32)
using DeviceGemmInstance =
@@ -145,7 +146,8 @@ using DeviceGemmInstance =
        1,              // CShuffleNXdlPerWavePerShuffle
        S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-       MaskingSpec>;   // MaskingSpecialization
+       MaskingSpec,    // MaskingSpecialization
+       Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstance =
    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
@@ -215,7 +217,8 @@ using DeviceGemmInstance =
        2,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-       MaskingSpec>;   // MaskingSpecialization
+       MaskingSpec,    // MaskingSpecialization
+       Deterministic>;
#elif(DIM <= 128)
using DeviceGemmInstance =
    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
@@ -285,7 +288,8 @@ using DeviceGemmInstance =
        2,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-       MaskingSpec>;   // MaskingSpecialization
+       MaskingSpec,    // MaskingSpecialization
+       Deterministic>;
#endif
// Ref Gemm0: DataType in, AccDataType out
...
@@ -99,10 +99,11 @@ static constexpr auto MaskingSpec =
    ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
#endif
static constexpr auto TensorSpecQ = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr bool Deterministic = true;
// DIM should be a multiple of 8.
// If DIM <= 32 , ues prototype1 1st template.
@@ -177,7 +178,8 @@ using DeviceGemmInstanceFWD =
        1,              // CShuffleNXdlPerWavePerShuffle
        S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-       MaskingSpec>;   // MaskingSpecialization
+       MaskingSpec,    // MaskingSpecialization
+       Deterministic>;
using DeviceGemmInstanceBWD =
    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1<
@@ -246,7 +248,8 @@ using DeviceGemmInstanceBWD =
        1,              // CShuffleNXdlPerWavePerShuffle
        S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-       MaskingSpec>;   // MaskingSpecialization
+       MaskingSpec,    // MaskingSpecialization
+       Deterministic>;
#elif(DIM <= 64)
using DeviceGemmInstanceFWD =
    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
@@ -316,7 +319,8 @@ using DeviceGemmInstanceFWD =
        2,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-       MaskingSpec>;   // MaskingSpecialization
+       MaskingSpec,    // MaskingSpecialization
+       Deterministic>;
using DeviceGemmInstanceBWD =
    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1<
@@ -385,7 +389,8 @@ using DeviceGemmInstanceBWD =
        2,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-       MaskingSpec>;   // MaskingSpecialization
+       MaskingSpec,    // MaskingSpecialization
+       Deterministic>;
// using DeviceGemmInstanceBWD =
//     ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -454,7 +459,8 @@ using DeviceGemmInstanceBWD =
//         2,              // CShuffleNXdlPerWavePerShuffle
//         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
//         CShuffleBlockTransferScalarPerVector_NPerBlock,
-//         MaskingSpec>;
+//         MaskingSpec,
+//         Deterministic>;
#elif(DIM <= 128)
using DeviceGemmInstanceFWD =
    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
@@ -524,7 +530,8 @@ using DeviceGemmInstanceFWD =
        2,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8,              // CShuffleBlockTransferScalarPerVector_NPerBlock
-       MaskingSpec>;   // MaskingSpecialization
+       MaskingSpec,    // MaskingSpecialization
+       Deterministic>;
using DeviceGemmInstanceBWD =
    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -593,7 +600,8 @@ using DeviceGemmInstanceBWD =
        4,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
-       MaskingSpec>;   // MaskingSpecialization
+       MaskingSpec,    // MaskingSpecialization
+       Deterministic>;
#endif
// Ref Gemm0: S = alpha * Q * K^T
...
@@ -82,6 +82,7 @@ __global__ void
        const YGradGridDesc_O0_M_O1 ygrad_grid_desc_o0_m_o1,
        const Block2CTileMap block_2_ctile_map,
        const index_t batch_count,
+       const index_t mblock,
        const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
        const C0MatrixMask c0_matrix_mask,
        const float p_drop,
@@ -115,40 +116,38 @@ __global__ void
    if constexpr(Deterministic)
    {
-       for(index_t i = 0; i < num_blocks_per_batch; i++)
+       for(index_t i = 0; i < mblock; i++)
        {
-           if(get_block_1d_id() % num_blocks_per_batch == i)
-           {
            GridwiseGemm::template Run<HasMainKBlockLoop>(
                p_a_grid + a_batch_offset,
                p_b_grid + b_batch_offset,
                z_matrix_ptr,
                p_b1_grid + b1_batch_offset,
                p_c_grid + c_batch_offset,
                p_lse_grid + lse_batch_offset,
                p_ygrad_grid + c_batch_offset,
                p_qgrad_grid + a_batch_offset,
                p_kgrad_grid + b_batch_offset,
                p_vgrad_grid + b1_batch_offset,
                p_shared,
                a_element_op,
                b_element_op,
                acc_element_op,
                b1_element_op,
                c_element_op,
                a_grid_desc_ak0_m_ak1,
                b_grid_desc_bk0_n_bk1,
                c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                b1_grid_desc_bk0_n_bk1,
                c_grid_desc_mblock_mperblock_nblock_nperblock,
                lse_grid_desc_m,
                vgrad_grid_desc_n_o,
                ygrad_grid_desc_o0_m_o1,
                block_2_ctile_map,
                c0_matrix_mask,
                p_drop,
-               ph);
-           }
+               ph,
+               i);
        }
    }
    else
@@ -180,7 +179,8 @@ __global__ void
            block_2_ctile_map,
            c0_matrix_mask,
            p_drop,
-           ph);
+           ph,
+           0);
    }
#else
    ignore = p_a_grid;
@@ -707,7 +707,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        LoopSched,
        Transform::matrix_padder.PadN,
-       MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
+       MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
+       Deterministic>;
    // Argument
    struct Argument : public BaseArgument
@@ -941,7 +942,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
        }
        const index_t grid_size =
-           arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_) * arg.batch_count_;
+           (Deterministic ? 1
+                          : arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_)) *
+           arg.batch_count_;
        float ave_time = 0;
@@ -971,41 +974,43 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
            has_main_k_block_loop_,
            Deterministic>;
        return launch_and_time_kernel(
            stream_config,
            kernel,
            dim3(grid_size),
            dim3(BlockSize),
            0,
            arg.p_a_grid_,
            arg.p_b_grid_,
            arg.p_z_grid_,
            arg.p_b1_grid_,
            arg.p_c_grid_,
            arg.p_lse_grid_,
            arg.p_ygrad_grid_,
            arg.p_qgrad_grid_,
            arg.p_kgrad_grid_,
            arg.p_vgrad_grid_,
            arg.a_element_op_,
            arg.b_element_op_,
            arg.acc_element_op_,
            arg.b1_element_op_,
            arg.c_element_op_,
            arg.a_grid_desc_ak0_m_ak1_,
            arg.b_grid_desc_bk0_n_bk1_,
            arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
            arg.b1_grid_desc_bk0_n_bk1_,
            arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
            arg.lse_grid_desc_m_,
            arg.vgrad_grid_desc_n_o_,
            arg.ygrad_grid_desc_o0_m_o1_,
            arg.block_2_ctile_map_,
            arg.batch_count_,
+           arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_),
            arg.compute_base_ptr_of_batch_,
            arg.c0_matrix_mask_,
            arg.p_drop_,
            arg.seed_,
            arg.offset_);
    };
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
...
@@ -81,6 +81,7 @@ __global__ void
        const YGradGridDesc_M0_O_M1 ygrad_grid_desc_m0_o_m1,
        const Block2CTileMap block_2_ctile_map,
        const index_t batch_count,
+       const index_t mblock,
        const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
        const C0MatrixMask c0_matrix_mask,
        const float p_drop,
@@ -114,40 +115,38 @@ __global__ void
    if constexpr(Deterministic)
    {
-       for(index_t i = 0; i < num_blocks_per_batch; i++)
+       for(index_t i = 0; i < mblock; i++)
        {
-           if(get_block_1d_id() % num_blocks_per_batch == i)
-           {
            GridwiseGemm::template Run<HasMainKBlockLoop>(
                p_a_grid + a_batch_offset,
                p_b_grid + b_batch_offset,
                z_matrix_ptr,
                p_b1_grid + b1_batch_offset,
                p_c_grid + c_batch_offset,
                p_lse_grid + lse_batch_offset,
                p_ygrad_grid + c_batch_offset,
                p_qgrad_grid + a_batch_offset,
                p_kgrad_grid + b_batch_offset,
                p_vgrad_grid + b1_batch_offset,
                p_shared,
                a_element_op,
                b_element_op,
                acc_element_op,
                b1_element_op,
                c_element_op,
                a_grid_desc_ak0_m_ak1,
                b_grid_desc_bk0_n_bk1,
                c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
                b1_grid_desc_bk0_n_bk1,
                c_grid_desc_mblock_mperblock_nblock_nperblock,
                lse_grid_desc_m,
                vgrad_grid_desc_n_o,
                ygrad_grid_desc_m0_o_m1,
                block_2_ctile_map,
                c0_matrix_mask,
                p_drop,
-               ph);
-           }
+               ph,
+               i);
        }
    }
    else
@@ -179,7 +178,8 @@ __global__ void
            block_2_ctile_map,
            c0_matrix_mask,
            p_drop,
-           ph);
+           ph,
+           0);
    }
#else
    ignore = p_a_grid;
@@ -706,7 +706,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        LoopSched,
        Transform::matrix_padder.PadN,
-       MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
+       MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
+       Deterministic>;
    // Argument
    struct Argument : public BaseArgument
@@ -939,7 +940,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
        }
        const index_t grid_size =
-           arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_) * arg.batch_count_;
+           (Deterministic ? 1
+                          : arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_)) *
+           arg.batch_count_;
        // Gemm0_K
        const auto K =
@@ -973,41 +976,43 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
            has_main_k_block_loop_,
            Deterministic>;
        return launch_and_time_kernel(
            stream_config,
            kernel,
            dim3(grid_size),
            dim3(BlockSize),
            0,
            arg.p_a_grid_,
            arg.p_b_grid_,
            arg.p_z_grid_,
            arg.p_b1_grid_,
            arg.p_c_grid_,
            arg.p_lse_grid_,
            arg.p_ygrad_grid_,
            arg.p_qgrad_grid_,
            arg.p_kgrad_grid_,
            arg.p_vgrad_grid_,
            arg.a_element_op_,
            arg.b_element_op_,
            arg.acc_element_op_,
            arg.b1_element_op_,
            arg.c_element_op_,
            arg.a_grid_desc_ak0_m_ak1_,
            arg.b_grid_desc_bk0_n_bk1_,
            arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
            arg.b1_grid_desc_bk0_n_bk1_,
            arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
            arg.lse_grid_desc_m_,
            arg.vgrad_grid_desc_n_o_,
            arg.ygrad_grid_desc_m0_o_m1_,
            arg.block_2_ctile_map_,
            arg.batch_count_,
+           arg.block_2_ctile_map_.CalculateGridSize(arg.y_grid_desc_m_o_),
            arg.compute_base_ptr_of_batch_,
            arg.c0_matrix_mask_,
            arg.p_drop_,
            arg.seed_,
            arg.offset_);
    };
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
...
@@ -45,7 +45,8 @@ template <typename GridwiseGemm,
    typename C0MatrixMask,
    bool HasMainKBlockLoop,
    bool IsDropout,
-   bool IsLseStoring>
+   bool IsLseStoring,
+   bool Deterministic>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -72,6 +73,7 @@ __global__ void
        const LSEGridDescriptor_M lse_grid_desc_m,
        const Block2CTileMap block_2_ctile_map,
        const index_t batch_count,
+       const index_t mblock,
        const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
        const C0MatrixMask c0_matrix_mask,
        const ushort p_dropout_in_16bits,
@@ -101,30 +103,65 @@ __global__ void
    const index_t global_thread_id = get_thread_global_1d_id();
    ck::philox ph(seed, global_thread_id, offset);
+   if constexpr(Deterministic)
+   {
+       for(index_t i = 0; i < mblock; i++)
+       {
+           GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
+               p_a_grid + a_batch_offset,
+               p_b_grid + b_batch_offset,
+               p_b1_grid + b1_batch_offset,
+               p_c_grid + c_batch_offset,
+               nullptr ? nullptr : p_z_grid + z_batch_offset,
+               nullptr ? nullptr : p_lse_grid + lse_batch_offset,
+               p_shared,
+               a_element_op,
+               b_element_op,
+               acc_element_op,
+               b1_element_op,
+               c_element_op,
+               a_grid_desc_ak0_m_ak1,
+               b_grid_desc_bk0_n_bk1,
+               b1_grid_desc_bk0_n_bk1,
+               c_grid_desc_mblock_mperblock_nblock_nperblock,
+               z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+               lse_grid_desc_m,
+               block_2_ctile_map,
+               c0_matrix_mask,
+               p_dropout_in_16bits,
+               p_dropout_rescale,
+               ph,
+               i);
+       }
+   }
+   else
+   {
        GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
            p_a_grid + a_batch_offset,
            p_b_grid + b_batch_offset,
            p_b1_grid + b1_batch_offset,
            p_c_grid + c_batch_offset,
            nullptr ? nullptr : p_z_grid + z_batch_offset,
            nullptr ? nullptr : p_lse_grid + lse_batch_offset,
            p_shared,
            a_element_op,
            b_element_op,
            acc_element_op,
            b1_element_op,
            c_element_op,
            a_grid_desc_ak0_m_ak1,
            b_grid_desc_bk0_n_bk1,
            b1_grid_desc_bk0_n_bk1,
            c_grid_desc_mblock_mperblock_nblock_nperblock,
            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
            lse_grid_desc_m,
            block_2_ctile_map,
            c0_matrix_mask,
            p_dropout_in_16bits,
            p_dropout_rescale,
-           ph);
+           ph,
+           0);
+   }
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
@@ -216,6 +253,7 @@ template <index_t NumDimG,
    typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
    index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
    MaskingSpecialization MaskingSpec,
+   bool Deterministic,
    LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
    : public DeviceBatchedMultiheadAttentionForward<NumDimG,
@@ -476,7 +514,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        LoopSched,
        Transform::matrix_padder.PadN,
-       MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
+       MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
+       Deterministic>;
    // Argument
    // FIXME: constness
@@ -695,7 +734,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
        }
        const index_t grid_size =
-           arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;
+           (Deterministic ? 1
+                          : arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_)) *
+           arg.batch_count_;
        // Gemm0_K
        const auto K =
@@ -703,65 +744,67 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
        float ave_time = 0;
        auto launch_kernel =
            [&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
                const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle<
                    GridwiseGemm,
                    ADataType, // TODO: distiguish A/B datatype
                    CDataType,
                    ZDataType,
                    LSEDataType,
                    GemmAccDataType,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    AccElementwiseOperation,
                    B1ElementwiseOperation,
                    CElementwiseOperation,
                    DeviceOp::AGridDesc_AK0_M_AK1,
                    DeviceOp::BGridDesc_BK0_N_BK1,
                    DeviceOp::B1GridDesc_BK0_N_BK1,
                    typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
                    DeviceOp::LSEGridDesc_M,
                    typename GridwiseGemm::DefaultBlock2CTileMap,
                    ComputeBasePtrOfStridedBatch,
                    C0MatrixMask,
                    has_main_k_block_loop_,
                    is_dropout_,
-                   is_lse_storing_>;
+                   is_lse_storing_,
+                   Deterministic>;
                return launch_and_time_kernel(
                    stream_config,
                    kernel,
                    dim3(grid_size),
                    dim3(BlockSize),
                    0,
                    arg.p_a_grid_,
                    arg.p_b_grid_,
                    arg.p_b1_grid_,
                    arg.p_c_grid_,
                    arg.p_z_grid_,
                    arg.p_lse_grid_,
                    arg.a_element_op_,
                    arg.b_element_op_,
                    arg.acc_element_op_,
                    arg.b1_element_op_,
                    arg.c_element_op_,
                    arg.a_grid_desc_ak0_m_ak1_,
                    arg.b_grid_desc_bk0_n_bk1_,
                    arg.b1_grid_desc_bk0_n_bk1_,
                    arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                    arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
                    arg.lse_grid_desc_m_,
                    arg.block_2_ctile_map_,
                    arg.batch_count_,
+                   arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_),
                    arg.compute_base_ptr_of_batch_,
                    arg.c0_matrix_mask_,
                    arg.p_dropout_in_16bits_,
                    arg.p_dropout_rescale_,
                    arg.seed_,
                    arg.offset_);
            };
        // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
        // to concern Gemm0's loop
...
@@ -79,7 +79,7 @@ __global__ void
    // per-group batch offset
    const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
    const index_t g_idx = __builtin_amdgcn_readfirstlane(
-       (block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch);
+       (block_id - arg_ptr[group_id].block_start_) / (Deterministic ? 1 : num_blocks_per_batch));
    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
@@ -104,38 +104,36 @@ __global__ void
    {
        for(index_t i = 0; i < num_blocks_per_batch; i++)
        {
-           if(((block_id - arg_ptr[group_id].block_start_) % num_blocks_per_batch) == i)
-           {
            GridwiseGemm::template Run<HasMainKBlockLoop>(
                arg_ptr[group_id].p_a_grid_ + a_batch_offset,
                arg_ptr[group_id].p_b_grid_ + b_batch_offset,
                z_matrix_ptr,
                arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
                arg_ptr[group_id].p_c_grid_ + c_batch_offset,
                arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
                arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
                arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
                arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
                arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
                p_shared,
                a_element_op,
                b_element_op,
                acc_element_op,
                b1_element_op,
                c_element_op,
                arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
                arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
                arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
                arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
                arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
                arg_ptr[group_id].lse_grid_desc_m_,
                arg_ptr[group_id].vgrad_grid_desc_n_o_,
                arg_ptr[group_id].ygrad_grid_desc_o0_m_o1_,
                arg_ptr[group_id].block_2_ctile_map_,
                arg_ptr[group_id].c0_matrix_mask_,
                p_dropout,
-               ph);
-           }
+               ph,
+               i);
        }
    }
    else
@@ -168,7 +166,8 @@ __global__ void
            arg_ptr[group_id].block_2_ctile_map_,
            arg_ptr[group_id].c0_matrix_mask_,
            p_dropout,
-           ph);
+           ph,
+           0);
    }
#else
    ignore = group_kernel_args;
@@ -643,7 +642,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        LoopSched,
        Transform::matrix_padder.PadN,
-       MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
+       MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
+       Deterministic>;
    using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
@@ -825,7 +825,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
            const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
            const index_t grid_size_grp =
-               block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o) * batch_count;
+               (Deterministic ? 1 : block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o)) *
+               batch_count;
            const index_t BlockEnd = grid_size_ + grid_size_grp;
            // batch stride
            const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
...
@@ -79,7 +79,7 @@ __global__ void
    // per-group batch offset
    const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
    const index_t g_idx = __builtin_amdgcn_readfirstlane(
-       (block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch);
+       (block_id - arg_ptr[group_id].block_start_) / (Deterministic ? 1 : num_blocks_per_batch));
    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
@@ -104,38 +104,36 @@ __global__ void
    {
        for(index_t i = 0; i < num_blocks_per_batch; i++)
        {
-           if(((block_id - arg_ptr[group_id].block_start_) % num_blocks_per_batch) == i)
-           {
            GridwiseGemm::template Run<HasMainKBlockLoop>(
                arg_ptr[group_id].p_a_grid_ + a_batch_offset,
                arg_ptr[group_id].p_b_grid_ + b_batch_offset,
                z_matrix_ptr,
                arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
                arg_ptr[group_id].p_c_grid_ + c_batch_offset,
                arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
                arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
                arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
                arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
                arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
                p_shared,
                a_element_op,
                b_element_op,
                acc_element_op,
                b1_element_op,
                c_element_op,
                arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
                arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
                arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
                arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
                arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
                arg_ptr[group_id].lse_grid_desc_m_,
                arg_ptr[group_id].vgrad_grid_desc_n_o_,
                arg_ptr[group_id].ygrad_grid_desc_m0_o_m1_,
                arg_ptr[group_id].block_2_ctile_map_,
                arg_ptr[group_id].c0_matrix_mask_,
                p_dropout,
-               ph);
-           }
+               ph,
+               i);
        }
    }
    else
@@ -168,7 +166,8 @@ __global__ void
            arg_ptr[group_id].block_2_ctile_map_,
            arg_ptr[group_id].c0_matrix_mask_,
            p_dropout,
-           ph);
+           ph,
+           0);
    }
#else
    ignore = group_kernel_args;
@@ -636,7 +635,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        LoopSched,
        Transform::matrix_padder.PadN,
-       MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>;
+       MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
+       Deterministic>;
    using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
@@ -818,7 +818,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
            const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
            const index_t grid_size_grp =
-               block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o) * batch_count;
+               (Deterministic ? 1 : block_2_ctile_map.CalculateGridSize(y_grid_desc_m_o)) *
+               batch_count;
            const index_t BlockEnd = grid_size_ + grid_size_grp;
            // batch stride
            const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
...
...@@ -33,7 +33,8 @@ template <typename GridwiseGemm, ...@@ -33,7 +33,8 @@ template <typename GridwiseGemm,
typename CElementwiseOperation, typename CElementwiseOperation,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool IsDropout, bool IsDropout,
bool IsLseStoring> bool IsLseStoring,
bool Deterministic>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...@@ -83,7 +84,7 @@ __global__ void ...@@ -83,7 +84,7 @@ __global__ void
// per-group batch offset // per-group batch offset
const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_; const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
const index_t g_idx = __builtin_amdgcn_readfirstlane( const index_t g_idx = __builtin_amdgcn_readfirstlane(
(block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch); (block_id - arg_ptr[group_id].block_start_) / (Deterministic ? 1 : num_blocks_per_batch));
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx))); static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
...@@ -98,33 +99,74 @@ __global__ void ...@@ -98,33 +99,74 @@ __global__ void
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>( const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>( if constexpr(Deterministic)
arg_ptr[group_id].p_a_grid_ + a_batch_offset, {
arg_ptr[group_id].p_b_grid_ + b_batch_offset, for(index_t i = 0; i < num_blocks_per_batch; i++)
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset, {
arg_ptr[group_id].p_c_grid_ + c_batch_offset, GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr arg_ptr[group_id].p_a_grid_ + a_batch_offset,
: arg_ptr[group_id].p_z_grid_ + z_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
arg_ptr[group_id].p_lse_grid_ == nullptr ? nullptr arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
: arg_ptr[group_id].p_lse_grid_ + lse_batch_offset, arg_ptr[group_id].p_c_grid_ + c_batch_offset,
// arg_ptr[group_id].p_lse_grid_ + lse_batch_offset, arg_ptr[group_id].p_z_grid_ == nullptr
p_shared, ? nullptr
a_element_op, : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
b_element_op, arg_ptr[group_id].p_lse_grid_ == nullptr
acc_element_op, ? nullptr
b1_element_op, : arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
c_element_op, // arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, p_shared,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, a_element_op,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, b_element_op,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_, acc_element_op,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, b1_element_op,
arg_ptr[group_id].lse_grid_desc_m_, c_element_op,
arg_ptr[group_id].block_2_ctile_map_, arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].c0_matrix_mask_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
p_dropout_in_16bits, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
p_dropout_rescale, arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
ph); arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_,
p_dropout_in_16bits,
p_dropout_rescale,
ph,
i);
}
}
else
{
GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset,
arg_ptr[group_id].p_lse_grid_ == nullptr
? nullptr
: arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
// arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
p_shared,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_,
p_dropout_in_16bits,
p_dropout_rescale,
ph,
0);
}
#else #else
ignore = group_kernel_args; ignore = group_kernel_args;
ignore = group_count; ignore = group_count;
...@@ -206,6 +248,7 @@ template <index_t NumDimG, ...@@ -206,6 +248,7 @@ template <index_t NumDimG,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
bool Deterministic,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
: public DeviceGroupedMultiheadAttentionForward<NumDimG, : public DeviceGroupedMultiheadAttentionForward<NumDimG,
...@@ -487,7 +530,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
Transform::matrix_padder.PadN,
MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle,
Deterministic>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
...@@ -638,7 +682,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart);
const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0);
const index_t grid_size_grp =
(Deterministic ? 1 : block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n)) *
batch_count;
const index_t BlockEnd = grid_size_ + grid_size_grp;
// batch stride
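A hedged sketch of the launch-size arithmetic above (grid_size_for_group and tiles_per_batch are illustrative names; tiles_per_batch stands in for block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n)): deterministic mode launches one block per batch and loops over the M tiles inside the kernel, instead of one block per tile.

#include <cassert>

// Illustrative helper, not a CK API: grid contribution of one problem group.
int grid_size_for_group(bool deterministic, int tiles_per_batch, int batch_count)
{
    return (deterministic ? 1 : tiles_per_batch) * batch_count;
}

int main()
{
    // e.g. 8 M tiles per batch, 16 batches in the group
    assert(grid_size_for_group(false, 8, 16) == 128); // one block per tile
    assert(grid_size_for_group(true, 8, 16) == 16);   // one block per batch
    return 0;
}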
...@@ -778,7 +823,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle
CElementwiseOperation,
has_main_k_block_loop_,
is_dropout_,
is_lse_storing_,
Deterministic>;
return launch_and_time_kernel(
stream_config,
...
...@@ -86,6 +86,7 @@ template <typename InputDataType,
LoopScheduler LoopSched,
bool PadN,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{
...@@ -1265,7 +1266,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask,
const float p_drop,
ck::philox& ph,
const index_t block_idx_m)
{
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
...@@ -1305,9 +1307,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
return;
}
const index_t block_work_idx_m = Deterministic ? block_idx_m : block_work_idx[I0];
// HACK: this force m/o_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx_m * MPerBlock);
// const index_t o_block_data_idx_on_grid =
// __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
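The same selection is repeated in each gridwise kernel below; a small sketch of the idea (m_block_start_row and MPerBlock's value are illustrative, and the readfirstlane/descriptor details are omitted): when Deterministic is set, the M-block index comes from the serial loop counter handed in by the caller rather than from the block-to-C-tile map, so each pass works on a fixed tile.

#include <cstdio>

constexpr int MPerBlock = 128; // example tile height, value is illustrative

// Illustrative only: mirrors the block_work_idx_m / m_block_data_idx_on_grid selection above.
template <bool Deterministic>
int m_block_start_row(int block_idx_m, int mapped_idx_m)
{
    const int block_work_idx_m = Deterministic ? block_idx_m : mapped_idx_m;
    return block_work_idx_m * MPerBlock; // first row of the M tile this pass works on
}

int main()
{
    // Deterministic pass i = 2 starts at row 256 regardless of the tile map.
    std::printf("%d\n", m_block_start_row<true>(/*block_idx_m=*/2, /*mapped_idx_m=*/7));
    // Non-deterministic path uses whatever the block-to-C-tile map produced.
    std::printf("%d\n", m_block_start_row<false>(/*block_idx_m=*/0, /*mapped_idx_m=*/7));
    return 0;
}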
...@@ -1512,7 +1516,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
1,
false>{
lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
make_multi_index(block_work_idx_m, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4])}; // mperxdl
...@@ -1574,15 +1578,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
...@@ -1720,7 +1724,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
ygrad_thread_cluster_idx * ygrad_thread_desc_m_o.GetLengths();
const auto y_thread_data_on_grid_idx =
make_multi_index(
block_work_idx_m, I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
y_thread_data_on_block_idx;
// performs for y
...@@ -2320,7 +2324,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
qgrad_grid_desc_mblock_mperblock_kblock_kperblock,
make_multi_index(block_work_idx_m, 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for threadwise C in VGPR
...
...@@ -86,6 +86,7 @@ template <typename InputDataType,
LoopScheduler LoopSched,
bool PadN,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
{
...@@ -1175,7 +1176,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const Block2CTileMap& block_2_ctile_map,
const C0MatrixMask& c0_matrix_mask,
const float p_drop,
ck::philox& ph,
const index_t block_idx_m)
{
const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
...@@ -1215,9 +1217,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
return;
}
const index_t block_work_idx_m = Deterministic ? block_idx_m : block_work_idx[I0];
// HACK: this force m/o_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx_m * MPerBlock);
const index_t o_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
...@@ -1444,7 +1448,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
1,
false>{
lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
make_multi_index(block_work_idx_m, // mblock
acc0_thread_origin[I0], // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4])}; // mperxdl
...@@ -1506,7 +1510,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
...@@ -1643,7 +1647,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();
const auto y_thread_data_on_grid_idx =
make_multi_index(
block_work_idx_m, I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
y_thread_data_on_block_idx;
// performs double duty for both y and ygrad
...@@ -2270,7 +2274,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
qgrad_grid_desc_mblock_mperblock_kblock_kperblock,
make_multi_index(block_work_idx_m, 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for threadwise C in VGPR
...
...@@ -85,6 +85,7 @@ template <typename FloatAB,
LoopScheduler LoopSched,
bool PadN,
bool MaskOutUpperTriangle,
bool Deterministic,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
{
...@@ -445,7 +446,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const C0MatrixMask& c0_matrix_mask,
const ushort p_dropout_in_16bits,
FloatGemmAcc p_dropout_rescale,
ck::philox& ph,
const index_t block_idx_m)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
...@@ -470,9 +472,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
return;
}
const index_t block_work_idx_m = Deterministic ? block_idx_m : block_work_idx[I0];
// HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx_m * MPerBlock);
const index_t gemm1_n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
...@@ -835,7 +839,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
InMemoryDataOperationEnum::Set,
1,
false>{lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
make_multi_index(block_work_idx_m, // mblock
0, // mrepeat
acc0_thread_origin[I2], // mwave
acc0_thread_origin[I4]), // mperxdl
...@@ -897,15 +901,15 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
InMemoryDataOperationEnum::Set,
1, // DstScalarStrideInVector
true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
make_multi_index(block_work_idx_m, // MBlockId
0, // NBlockId
0, // mrepeat
0, // nrepeat
wave_id[I0], // MWaveId
wave_id[I1], // NWaveId
wave_m_n_id[I1], // MPerXdl
0, // group
wave_m_n_id[I0], // NInputIndex
0),
tensor_operation::element_wise::PassThrough{}};
...@@ -1319,7 +1323,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx_m, 0, block_work_idx[I1], 0),
c_element_op};
// space filling curve for threadwise C in VGPR
...