Commit 3b57967f authored by danyao12

batched & grouped train: output datatype can be selected

parent d042e931
@@ -71,7 +71,8 @@ using Scale = ck::tensor_operation::element_wise::Scale;
 using QKVElementOp = PassThrough;
 using YElementOp = PassThrough;
-using DataType = BF16;
+using InputDataType = BF16;
+using OutputDataType = F32;
 using GemmDataType = BF16;
 using AccDataType = F32;
 using ShuffleDataType = F32;
@@ -85,6 +86,9 @@ static constexpr ck::index_t NumDimM = 1;
 static constexpr ck::index_t NumDimN = 1;
 static constexpr ck::index_t NumDimK = 1;
 static constexpr ck::index_t NumDimO = 1;
+// When OutputDataType == F32, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 4
+// When OutputDataType == F16/BF16, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 8
+static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4;
 static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
 #if USING_MASK
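Note: the new constant above is what the BWD instances below consume in place of the hard-coded 8. The two comments encode a store-width invariant: the CShuffle output stage writes 16 bytes per vectorized store, so an F32 output allows 4 elements per vector while F16/BF16 allow 8. As a minimal sketch (not part of this commit), the value could be derived from the type instead of being hand-edited whenever OutputDataType changes:

// Hypothetical alternative, assuming OutputDataType is one of F32/F16/BF16:
// keep the 16-byte store width stated in the comments above.
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock =
    16 / sizeof(OutputDataType); // 4 for F32 (4 B), 8 for F16/BF16 (2 B)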
@@ -112,10 +116,10 @@ using DeviceGemmInstanceFWD =
 NumDimN,
 NumDimK,
 NumDimO,
-DataType,
-DataType,
-DataType,
-DataType,
+InputDataType,
+InputDataType,
+InputDataType,
+InputDataType,
 GemmDataType,
 ZDataType,
 LSEDataType,
@@ -182,7 +186,8 @@ using DeviceGemmInstanceBWD =
 NumDimN,
 NumDimK,
 NumDimO,
-DataType,
+InputDataType,
+OutputDataType,
 GemmDataType,
 ZDataType,
 LSEDataType,
@@ -240,8 +245,8 @@ using DeviceGemmInstanceBWD =
 1, // CShuffleMXdlPerWavePerShuffle
 1, // CShuffleNXdlPerWavePerShuffle
 S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-8, // CShuffleBlockTransferScalarPerVector_NPerBlock
+CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
 MaskingSpec>; // MaskingSpecialization
 #elif(DIM <= 64)
 using DeviceGemmInstanceFWD =
 ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
@@ -250,10 +255,10 @@ using DeviceGemmInstanceFWD =
 NumDimN,
 NumDimK,
 NumDimO,
-DataType,
-DataType,
-DataType,
-DataType,
+InputDataType,
+InputDataType,
+InputDataType,
+InputDataType,
 GemmDataType,
 ZDataType,
 LSEDataType,
@@ -320,7 +325,8 @@ using DeviceGemmInstanceBWD =
 NumDimN,
 NumDimK,
 NumDimO,
-DataType,
+InputDataType,
+OutputDataType,
 GemmDataType,
 ZDataType,
 LSEDataType,
@@ -378,8 +384,8 @@ using DeviceGemmInstanceBWD =
 1, // CShuffleMXdlPerWavePerShuffle
 2, // CShuffleNXdlPerWavePerShuffle
 S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-8, // CShuffleBlockTransferScalarPerVector_NPerBlock
+CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
 MaskingSpec>; // MaskingSpecialization
 // using DeviceGemmInstanceBWD =
 // ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
@@ -388,7 +394,8 @@ using DeviceGemmInstanceBWD =
 // NumDimN,
 // NumDimK,
 // NumDimO,
-// DataType,
+// InputDataType,
+// OutputDataType,
 // GemmDataType,
 // ZDataType,
 // LSEDataType,
@@ -446,8 +453,8 @@ using DeviceGemmInstanceBWD =
 // 1, // CShuffleMXdlPerWavePerShuffle
 // 2, // CShuffleNXdlPerWavePerShuffle
 // S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-// 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
-// MaskingSpec>; // MaskingSpecialization
+// CShuffleBlockTransferScalarPerVector_NPerBlock,
+// MaskingSpec>;
 #elif(DIM <= 128)
 using DeviceGemmInstanceFWD =
 ck::tensor_operation::device::DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle<
@@ -456,10 +463,10 @@ using DeviceGemmInstanceFWD =
 NumDimN,
 NumDimK,
 NumDimO,
-DataType,
-DataType,
-DataType,
-DataType,
+InputDataType,
+InputDataType,
+InputDataType,
+InputDataType,
 GemmDataType,
 ZDataType,
 LSEDataType,
@@ -526,7 +533,8 @@ using DeviceGemmInstanceBWD =
 NumDimN,
 NumDimK,
 NumDimO,
-DataType,
+InputDataType,
+OutputDataType,
 GemmDataType,
 ZDataType,
 LSEDataType,
@@ -584,14 +592,14 @@ using DeviceGemmInstanceBWD =
 1, // CShuffleMXdlPerWavePerShuffle
 4, // CShuffleNXdlPerWavePerShuffle
 S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
-8, // CShuffleBlockTransferScalarPerVector_NPerBlock
+CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
 MaskingSpec>; // MaskingSpecialization
 #endif
 // Ref Gemm0: S = alpha * Q * K^T
 // fp16 in, fp32 out
-using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType,
-DataType,
+using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
+InputDataType,
 AccDataType,
 AccDataType,
 PassThrough,
@@ -601,13 +609,13 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
 // Ref Softmax: P = Softmax(S)
 // fp32 in, fp16 out
 using ReferenceSoftmaxInstance =
-ck::tensor_operation::host::ReferenceSoftmax<AccDataType, DataType, AccDataType>;
+ck::tensor_operation::host::ReferenceSoftmax<AccDataType, InputDataType, AccDataType>;
 // Ref Gemm1: Y = P * V
 // fp16 in, fp16 out
-using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType,
-DataType,
-DataType,
+using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
+InputDataType,
+InputDataType,
 AccDataType,
 PassThrough,
 PassThrough,
@@ -615,16 +623,25 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
 // Ref Gemm for backward pass
 // fp16 in, fp16 out
-using ReferenceGemmGradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType,
-DataType,
-DataType,
+using ReferenceGemm0GradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
+InputDataType,
+InputDataType,
 AccDataType,
 PassThrough,
 PassThrough,
 Scale>;
+using ReferenceGemm1GradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
+InputDataType,
+OutputDataType,
+AccDataType,
+PassThrough,
+PassThrough,
+Scale>;
 // Ref dropout
 using ReferenceDropoutInstance =
-ck::tensor_operation::host::ReferenceDropout<ushort, DataType, DataType>;
+ck::tensor_operation::host::ReferenceDropout<ushort, InputDataType, InputDataType>;
 template <typename TensorQ,
 typename TensorK,
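Note: the split of ReferenceGemmGradInstance into Gemm0Grad/Gemm1Grad tracks the two output precisions now in play. Reading the template arguments in the usual CK order for ReferenceBatchedGemm (A, B, C, accumulator, then the three element-wise ops, which matches how the instances are filled in here), Gemm0Grad keeps its C matrix in InputDataType because dP/dS feed further host-side GEMMs, while Gemm1Grad writes C in OutputDataType because dQ/dK/dV must match the device kernel's outputs during verification. A sketch of the pairing used in the verification code below, under that reading:

// dP = dY * V^T         -> Gemm0Grad, C in InputDataType (consumed again by host math)
// dV = P_drop^T * dY    -> Gemm1Grad, C in OutputDataType (compared against device dV)
// dQ = alpha * dS * K   -> Gemm1Grad, C in OutputDataType
// dK = alpha * dS^T * Q -> Gemm1Grad, C in OutputDataType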
@@ -811,17 +828,17 @@ int run(int argc, char* argv[])
 std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
 std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
-Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
+Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
+Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
 Tensor<ZDataType> z_fwd_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
 Tensor<ZDataType> z_bwd_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
-Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
-Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
-Tensor<DataType> y_gs_ms_os_device_result(y_gs_ms_os_lengths, y_gs_ms_os_strides);
+Tensor<InputDataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
+Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
+Tensor<InputDataType> y_gs_ms_os_device_result(y_gs_ms_os_lengths, y_gs_ms_os_strides);
 Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
-Tensor<DataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-Tensor<DataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-Tensor<DataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
+Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
+Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
+Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
 std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
 std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
@@ -830,46 +847,46 @@ int run(int argc, char* argv[])
 std::cout << "y_gs_ms_os: " << y_gs_ms_os_device_result.mDesc << std::endl;
 std::cout << "lse_gs_ms_os: " << lse_gs_ms_device_result.mDesc << std::endl;
-z_fwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
-z_bwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0});
+z_fwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
+z_bwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
 switch(init_method)
 {
 case 0: break;
 case 1:
-q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2});
-k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2});
-v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2});
-ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2});
+q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
+k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
+v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
+ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
 break;
 case 2:
-q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0});
-k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0});
-v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
-ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5});
+q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
+k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
+v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
+ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
 break;
 case 3:
-q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5});
-k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
+q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-5, 5});
+k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
+v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
+ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
 break;
 case 4:
-q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
-k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
-v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
-ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
+q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
+k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
+v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
+ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
 break;
 case 5:
-q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
-k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
+q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
+k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
+v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
 ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
 // dO dot O = [0; 1; 2; ...]
 break;
 case 6:
-q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
-k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
+q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
+k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
+v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
 ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
 // assume mnko = 256
 // P = softmax(QK) = 0.0039 * ones
@@ -880,10 +897,10 @@ int run(int argc, char* argv[])
 //
 break;
 default:
-q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1});
-k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{});
-ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); // dy[g0, g1, m, o]
+q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
+k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
+v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
+ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
 // assume mnko = 256
 // P = softmax(QK) = 0.0039 * ones
 // O = P V = 0.0039 * ones
@@ -895,21 +912,22 @@ int run(int argc, char* argv[])
 }
 // qkv gradients have the same descriptor as with qkv
-DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
-DeviceMem k_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
+DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
+DeviceMem k_device_buf(sizeof(InputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
 DeviceMem z_fwd_device_buf(sizeof(ZDataType) * z_fwd_gs_ms_ns.mDesc.GetElementSpaceSize());
 DeviceMem z_bwd_device_buf(sizeof(ZDataType) * z_bwd_gs_ms_ns.mDesc.GetElementSpaceSize());
-DeviceMem v_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
-DeviceMem y_device_buf(sizeof(DataType) * y_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
+DeviceMem v_device_buf(sizeof(InputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
+DeviceMem y_device_buf(sizeof(InputDataType) *
+                       y_gs_ms_os_device_result.mDesc.GetElementSpaceSize());
 DeviceMem lse_device_buf(sizeof(LSEDataType) *
                          lse_gs_ms_device_result.mDesc.GetElementSpaceSize());
-DeviceMem qgrad_device_buf(sizeof(DataType) *
+DeviceMem qgrad_device_buf(sizeof(OutputDataType) *
                            qgrad_gs_ms_ks_device_result.mDesc.GetElementSpaceSize());
-DeviceMem kgrad_device_buf(sizeof(DataType) *
+DeviceMem kgrad_device_buf(sizeof(OutputDataType) *
                            kgrad_gs_ns_ks_device_result.mDesc.GetElementSpaceSize());
-DeviceMem vgrad_device_buf(sizeof(DataType) *
+DeviceMem vgrad_device_buf(sizeof(OutputDataType) *
                            vgrad_gs_os_ns_device_result.mDesc.GetElementSpaceSize());
-DeviceMem ygrad_device_buf(sizeof(DataType) * ygrad_gs_ms_os.mDesc.GetElementSpaceSize());
+DeviceMem ygrad_device_buf(sizeof(InputDataType) * ygrad_gs_ms_os.mDesc.GetElementSpaceSize());
 q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
 k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
@@ -926,10 +944,10 @@ int run(int argc, char* argv[])
 if(time_kernel)
 {
 auto argument_fwd = gemm_fwd.MakeArgument(
-static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
-static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
-static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
-static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
+static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
+static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
+static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
+static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
 static_cast<ZDataType*>(nullptr),
 static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
 {}, // std::array<void*, 1> p_acc0_biases;
@@ -967,10 +985,11 @@ int run(int argc, char* argv[])
 float ave_time_fwd = invoker_fwd.Run(argument_fwd, StreamConfig{nullptr, true});
 std::size_t flop_fwd = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
-std::size_t num_btype_fwd = (sizeof(DataType) * M * K + sizeof(DataType) * K * N +
-                             sizeof(DataType) * N * O + sizeof(DataType) * M * O) *
-                            BatchCount;
+std::size_t num_btype_fwd =
+    (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
+     sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O) *
+    BatchCount;
 float tflops_fwd = static_cast<float>(flop_fwd) / 1.E9 / ave_time_fwd;
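Note on the units: assuming the invoker reports average time in milliseconds (CK's usual StreamConfig timing), flop / 1e9 / t_ms is indeed TFLOPS, since flop / (t_ms * 10^9) = flop / (t_s * 10^12). For example, 2e12 flop in 1 ms gives 2e12 / 1e9 / 1 = 2000, i.e. 2000 TFLOPS = 2e15 flop/s.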
@@ -981,16 +1000,16 @@ int run(int argc, char* argv[])
 // not need output z matrix
 auto argument_bwd = gemm_bwd.MakeArgument(
-static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
-static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
+static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
+static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
 static_cast<ZDataType*>(nullptr), // set to nullptr
-static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
-static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
+static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
+static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
 static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
-static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()),
-static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()),
-static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()),
-static_cast<DataType*>(vgrad_device_buf.GetDeviceBuffer()),
+static_cast<InputDataType*>(ygrad_device_buf.GetDeviceBuffer()),
+static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
+static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
+static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
 {}, // std::array<void*, 1> p_acc0_biases;
 {}, // std::array<void*, 1> p_acc1_biases;
 q_gs_ms_ks_lengths,
@@ -1026,10 +1045,13 @@ int run(int argc, char* argv[])
 // 3x MNK + 2x MNO
 std::size_t flop_bwd = (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
 // Q/K/V/Y, dQ/dK/dV/dY, LSE
-std::size_t num_btype_bwd = (sizeof(DataType) * M * K + sizeof(DataType) * K * N +
-                             sizeof(DataType) * N * O + sizeof(DataType) * M * O) *
-                            size_t(2) * BatchCount +
-                            sizeof(LSEDataType) * M * BatchCount;
+std::size_t num_btype_bwd =
+    (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
+     sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
+     sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
+     sizeof(OutputDataType) * N * O) *
+    BatchCount +
+    sizeof(LSEDataType) * M * BatchCount;
 float tflops_bwd = static_cast<float>(flop_bwd) / 1.E9 / ave_time_bwd;
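Note: the rewritten num_btype_bwd now prices reads and writes in their own precisions. Q, K, V and both Y and dY (hence the M * O * size_t(2) term) are read as InputDataType, dQ/dK/dV are written as OutputDataType, and LSE adds M elements per batch. As a worked illustration under this commit's defaults (BF16 in at 2 B, F32 out at 4 B) with M = N = K = O = 256 and BatchCount = 1: inputs 2 * 5 * 256^2 = 655,360 B, outputs 4 * 3 * 256^2 = 786,432 B, LSE 4 * 256 = 1,024 B, roughly 1.44 MB in total.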
@@ -1042,39 +1064,39 @@ int run(int argc, char* argv[])
 bool pass = true;
 if(do_verification)
 {
-Tensor<DataType> y_gs_ms_os_host_result(y_gs_ms_os_lengths, y_gs_ms_os_strides);
+Tensor<InputDataType> y_gs_ms_os_host_result(y_gs_ms_os_lengths, y_gs_ms_os_strides);
 Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
-Tensor<DataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-Tensor<DataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-Tensor<DataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
+Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
+Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
+Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
-Tensor<DataType> q_g_m_k({BatchCount, M, K});
-Tensor<DataType> k_g_n_k({BatchCount, N, K});
+Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
+Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
 Tensor<ZDataType> z_fwd_g_m_n({BatchCount, M, N});
 Tensor<ZDataType> z_bwd_g_m_n({BatchCount, M, N});
-Tensor<DataType> v_g_n_o({BatchCount, N, O});
+Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
 Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
-Tensor<DataType> p_g_m_n({BatchCount, M, N});
-Tensor<DataType> p_drop_g_m_n({BatchCount, M, N});
-Tensor<DataType> y_g_m_o({BatchCount, M, O});
+Tensor<InputDataType> p_g_m_n({BatchCount, M, N});
+Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
+Tensor<InputDataType> y_g_m_o({BatchCount, M, O});
 Tensor<LSEDataType> lse_g_m({BatchCount, M});
-Tensor<DataType> qgrad_g_m_k({BatchCount, M, K});
-Tensor<DataType> kgrad_g_n_k({BatchCount, N, K});
-Tensor<DataType> vgrad_g_n_o({BatchCount, N, O});
+Tensor<OutputDataType> qgrad_g_m_k({BatchCount, M, K});
+Tensor<OutputDataType> kgrad_g_n_k({BatchCount, N, K});
+Tensor<OutputDataType> vgrad_g_n_o({BatchCount, N, O});
-Tensor<DataType> sgrad_g_m_n({BatchCount, M, N});
-Tensor<DataType> pgrad_g_m_n({BatchCount, M, N});
-Tensor<DataType> pgrad_drop_g_m_n({BatchCount, M, N});
-Tensor<DataType> ygrad_g_m_o({BatchCount, M, O});
-Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M});
+Tensor<InputDataType> sgrad_g_m_n({BatchCount, M, N});
+Tensor<InputDataType> pgrad_g_m_n({BatchCount, M, N});
+Tensor<InputDataType> pgrad_drop_g_m_n({BatchCount, M, N});
+Tensor<InputDataType> ygrad_g_m_o({BatchCount, M, O});
+Tensor<InputDataType> ygrad_dot_y_g_m({BatchCount, M});
 // get kernel output matrixes
 {
 auto argument_fwd = gemm_fwd.MakeArgument(
-static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
-static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
-static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
-static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
+static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
+static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
+static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
+static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
 static_cast<ZDataType*>(z_fwd_device_buf.GetDeviceBuffer()),
 static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
 {}, // std::array<void*, 1> p_acc0_biases;
@@ -1125,16 +1147,16 @@ int run(int argc, char* argv[])
 vgrad_device_buf.SetZero();
 auto argument_bwd = gemm_bwd.MakeArgument(
-static_cast<DataType*>(q_device_buf.GetDeviceBuffer()),
-static_cast<DataType*>(k_device_buf.GetDeviceBuffer()),
+static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
+static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
 static_cast<ZDataType*>(z_bwd_device_buf.GetDeviceBuffer()),
-static_cast<DataType*>(v_device_buf.GetDeviceBuffer()),
-static_cast<DataType*>(y_device_buf.GetDeviceBuffer()),
+static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
+static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
 static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
-static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()),
-static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()),
-static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()),
-static_cast<DataType*>(vgrad_device_buf.GetDeviceBuffer()),
+static_cast<InputDataType*>(ygrad_device_buf.GetDeviceBuffer()),
+static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
+static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
+static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
 {}, // std::array<void*, 1> p_acc0_biases;
 {}, // std::array<void*, 1> p_acc1_biases;
 q_gs_ms_ks_lengths,
@@ -1223,13 +1245,16 @@ int run(int argc, char* argv[])
 #endif
 // Gradients
-auto ref_gemm_grad = ReferenceGemmGradInstance{};
-auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker();
-using RefGemmGradArg = ReferenceGemmGradInstance::Argument;
+auto ref_gemm0_grad = ReferenceGemm0GradInstance{};
+auto ref_gemm0_grad_invoker = ref_gemm0_grad.MakeInvoker();
+using RefGemm0GradArg = ReferenceGemm0GradInstance::Argument;
+auto ref_gemm1_grad = ReferenceGemm1GradInstance{};
+auto ref_gemm1_grad_invoker = ref_gemm1_grad.MakeInvoker();
+using RefGemm1GradArg = ReferenceGemm1GradInstance::Argument;
 // dP_dropout = dY * V^T
 auto v_g_o_n = v_g_n_o.Transpose({0, 2, 1});
-ref_gemm_grad_invoker.Run(RefGemmGradArg{
+ref_gemm0_grad_invoker.Run(RefGemm0GradArg{
 ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
 #if PRINT_HOST
 {
@@ -1256,7 +1281,7 @@ int run(int argc, char* argv[])
 ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) *
 ck::type_convert<AccDataType>(y_g_m_o(idx_gmo));
 }
-self(idx_gmn) = ck::type_convert<DataType>(
+self(idx_gmn) = ck::type_convert<InputDataType>(
 ck::type_convert<AccDataType>(p_g_m_n(idx_gmn)) *
 (ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y));
 });
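Note: the ygrad_dot_y subtraction in this ForEach is the softmax Jacobian folded into dS. A short derivation, stated here for context: with $P = \mathrm{softmax}(S)$ applied row-wise,

$dS_{mn} = P_{mn}\left(dP_{mn} - \sum_{n'} dP_{mn'} P_{mn'}\right),$

and since $Y = P V$ and $dP = dY\,V^T$, the row sum collapses to $\sum_{n'} dP_{mn'} P_{mn'} = \sum_{o} dY_{mo} Y_{mo}$, which is exactly the ygrad_dot_y accumulated over idx_gmo above.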
@@ -1272,7 +1297,7 @@ int run(int argc, char* argv[])
 #endif
 // dV = P_drop^T * dY
 auto p_drop_g_n_m = p_drop_g_m_n.Transpose({0, 2, 1});
-ref_gemm_grad_invoker.Run(RefGemmGradArg{
+ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
 p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}});
 #if PRINT_HOST
 {
@@ -1284,7 +1309,7 @@ int run(int argc, char* argv[])
 #endif
 // dQ = alpha * dS * K
-ref_gemm_grad_invoker.Run(RefGemmGradArg{
+ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
 sgrad_g_m_n, k_g_n_k, qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
 #if PRINT_HOST
 {
@@ -1297,7 +1322,7 @@ int run(int argc, char* argv[])
 // dK = alpha * dS^T * Q
 auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
-ref_gemm_grad_invoker.Run(RefGemmGradArg{
+ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
 sgrad_g_n_m, q_g_m_k, kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
 #if PRINT_HOST
 {
@@ -1355,7 +1380,7 @@ int run(int argc, char* argv[])
 double atol = 1e-3;
 // when BF16 is taken, set absolute error and relative error to 0.01
-if(std::is_same_v<DataType, ck::bhalf_t> || std::is_same_v<GemmDataType, ck::bhalf_t>)
+if(std::is_same_v<InputDataType, ck::bhalf_t> || std::is_same_v<GemmDataType, ck::bhalf_t>)
 {
 rtol = 1e-2;
 atol = 1e-2;
...
...@@ -70,7 +70,8 @@ using Scale = ck::tensor_operation::element_wise::Scale; ...@@ -70,7 +70,8 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using QKVElementOp = PassThrough; using QKVElementOp = PassThrough;
using YElementOp = PassThrough; using YElementOp = PassThrough;
using DataType = BF16; using InputDataType = BF16;
using OutputDataType = F32;
using GemmDataType = BF16; using GemmDataType = BF16;
using AccDataType = F32; using AccDataType = F32;
using ShuffleDataType = F32; using ShuffleDataType = F32;
...@@ -84,6 +85,9 @@ static constexpr ck::index_t NumDimM = 1; ...@@ -84,6 +85,9 @@ static constexpr ck::index_t NumDimM = 1;
static constexpr ck::index_t NumDimN = 1; static constexpr ck::index_t NumDimN = 1;
static constexpr ck::index_t NumDimK = 1; static constexpr ck::index_t NumDimK = 1;
static constexpr ck::index_t NumDimO = 1; static constexpr ck::index_t NumDimO = 1;
// When OutputDataType == F32, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 4
// When OutputDataType == F16/BF16, bwd CShuffleBlockTransferScalarPerVector_NPerBlock = 8
static constexpr ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock = 4;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
#if USING_MASK #if USING_MASK
...@@ -111,10 +115,10 @@ using DeviceGemmInstanceFWD = ...@@ -111,10 +115,10 @@ using DeviceGemmInstanceFWD =
NumDimN, NumDimN,
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, InputDataType,
DataType, InputDataType,
DataType, InputDataType,
DataType, InputDataType,
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
...@@ -181,7 +185,8 @@ using DeviceGemmInstanceBWD = ...@@ -181,7 +185,8 @@ using DeviceGemmInstanceBWD =
NumDimN, NumDimN,
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, InputDataType,
OutputDataType,
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
...@@ -239,8 +244,8 @@ using DeviceGemmInstanceBWD = ...@@ -239,8 +244,8 @@ using DeviceGemmInstanceBWD =
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#elif(DIM <= 64) #elif(DIM <= 64)
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
...@@ -249,10 +254,10 @@ using DeviceGemmInstanceFWD = ...@@ -249,10 +254,10 @@ using DeviceGemmInstanceFWD =
NumDimN, NumDimN,
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, InputDataType,
DataType, InputDataType,
DataType, InputDataType,
DataType, InputDataType,
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
...@@ -319,7 +324,8 @@ using DeviceGemmInstanceBWD = ...@@ -319,7 +324,8 @@ using DeviceGemmInstanceBWD =
NumDimN, NumDimN,
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, InputDataType,
OutputDataType,
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
...@@ -377,8 +383,8 @@ using DeviceGemmInstanceBWD = ...@@ -377,8 +383,8 @@ using DeviceGemmInstanceBWD =
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
// using DeviceGemmInstanceBWD = // using DeviceGemmInstanceBWD =
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2< // ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
...@@ -387,7 +393,8 @@ using DeviceGemmInstanceBWD = ...@@ -387,7 +393,8 @@ using DeviceGemmInstanceBWD =
// NumDimN, // NumDimN,
// NumDimK, // NumDimK,
// NumDimO, // NumDimO,
// DataType, // InputDataType,
// OutputDataType,
// GemmDataType, // GemmDataType,
// ZDataType, // ZDataType,
// LSEDataType, // LSEDataType,
...@@ -445,8 +452,8 @@ using DeviceGemmInstanceBWD = ...@@ -445,8 +452,8 @@ using DeviceGemmInstanceBWD =
// 1, // CShuffleMXdlPerWavePerShuffle // 1, // CShuffleMXdlPerWavePerShuffle
// 2, // CShuffleNXdlPerWavePerShuffle // 2, // CShuffleNXdlPerWavePerShuffle
// S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock // S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
// 8, // CShuffleBlockTransferScalarPerVector_NPerBlock // CShuffleBlockTransferScalarPerVector_NPerBlock,
// MaskingSpec>; // MaskingSpecialization // MaskingSpec>;
#elif(DIM <= 128) #elif(DIM <= 128)
using DeviceGemmInstanceFWD = using DeviceGemmInstanceFWD =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle<
...@@ -455,10 +462,10 @@ using DeviceGemmInstanceFWD = ...@@ -455,10 +462,10 @@ using DeviceGemmInstanceFWD =
NumDimN, NumDimN,
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, InputDataType,
DataType, InputDataType,
DataType, InputDataType,
DataType, InputDataType,
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
...@@ -525,7 +532,8 @@ using DeviceGemmInstanceBWD = ...@@ -525,7 +532,8 @@ using DeviceGemmInstanceBWD =
NumDimN, NumDimN,
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, InputDataType,
OutputDataType,
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
...@@ -583,14 +591,14 @@ using DeviceGemmInstanceBWD = ...@@ -583,14 +591,14 @@ using DeviceGemmInstanceBWD =
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
4, // CShuffleNXdlPerWavePerShuffle 4, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#endif #endif
// Ref Gemm0: S = alpha * Q * K^T // Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out // fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType, using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
DataType, InputDataType,
AccDataType, AccDataType,
AccDataType, AccDataType,
PassThrough, PassThrough,
...@@ -600,13 +608,13 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm< ...@@ -600,13 +608,13 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
// Ref Softmax: P = Softmax(S) // Ref Softmax: P = Softmax(S)
// fp32 in, fp16 out // fp32 in, fp16 out
using ReferenceSoftmaxInstance = using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<AccDataType, DataType, AccDataType>; ck::tensor_operation::host::ReferenceSoftmax<AccDataType, InputDataType, AccDataType>;
// Ref Gemm1: Y = P * V // Ref Gemm1: Y = P * V
// fp16 in, fp16 out // fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType, using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
DataType, InputDataType,
DataType, InputDataType,
AccDataType, AccDataType,
PassThrough, PassThrough,
PassThrough, PassThrough,
...@@ -614,16 +622,25 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm< ...@@ -614,16 +622,25 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
// Ref Gemm for backward pass // Ref Gemm for backward pass
// fp16 in, fp16 out // fp16 in, fp16 out
using ReferenceGemmGradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType, using ReferenceGemm0GradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
DataType, InputDataType,
DataType, InputDataType,
AccDataType, AccDataType,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>; Scale>;
using ReferenceGemm1GradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
InputDataType,
OutputDataType,
AccDataType,
PassThrough,
PassThrough,
Scale>;
// Ref dropout // Ref dropout
using ReferenceDropoutInstance = using ReferenceDropoutInstance =
ck::tensor_operation::host::ReferenceDropout<ushort, DataType, DataType>; ck::tensor_operation::host::ReferenceDropout<ushort, InputDataType, InputDataType>;
template <typename TensorQ, template <typename TensorQ,
typename TensorK, typename TensorK,
...@@ -764,28 +781,28 @@ int run(int argc, char* argv[]) ...@@ -764,28 +781,28 @@ int run(int argc, char* argv[])
std::vector<void*> p_vgrad; std::vector<void*> p_vgrad;
std::vector<const void*> p_ygrad; std::vector<const void*> p_ygrad;
std::vector<Tensor<DataType>> q_g_m_ks; std::vector<Tensor<InputDataType>> q_g_m_ks;
std::vector<Tensor<DataType>> k_g_n_ks; std::vector<Tensor<InputDataType>> k_g_n_ks;
std::vector<Tensor<ZDataType>> z_fwd_g_m_ns; std::vector<Tensor<ZDataType>> z_fwd_g_m_ns;
std::vector<Tensor<ZDataType>> z_bwd_g_m_ns; std::vector<Tensor<ZDataType>> z_bwd_g_m_ns;
std::vector<Tensor<DataType>> v_g_n_os; std::vector<Tensor<InputDataType>> v_g_n_os;
std::vector<Tensor<AccDataType>> s_g_m_ns; std::vector<Tensor<AccDataType>> s_g_m_ns;
std::vector<Tensor<DataType>> p_g_m_ns; std::vector<Tensor<InputDataType>> p_g_m_ns;
std::vector<Tensor<DataType>> y_g_m_os; std::vector<Tensor<InputDataType>> y_g_m_os;
std::vector<Tensor<LSEDataType>> lse_g_ms; std::vector<Tensor<LSEDataType>> lse_g_ms;
std::vector<Tensor<DataType>> p_drop_g_m_ns; std::vector<Tensor<InputDataType>> p_drop_g_m_ns;
std::vector<Tensor<DataType>> q_tensors; std::vector<Tensor<InputDataType>> q_tensors;
std::vector<Tensor<DataType>> k_tensors; std::vector<Tensor<InputDataType>> k_tensors;
std::vector<Tensor<DataType>> v_tensors; std::vector<Tensor<InputDataType>> v_tensors;
std::vector<Tensor<DataType>> y_tensors; std::vector<Tensor<InputDataType>> y_tensors;
std::vector<Tensor<ZDataType>> z_fwd_tensors; std::vector<Tensor<ZDataType>> z_fwd_tensors;
std::vector<Tensor<ZDataType>> z_bwd_tensors; std::vector<Tensor<ZDataType>> z_bwd_tensors;
std::vector<Tensor<LSEDataType>> lse_tensors; std::vector<Tensor<LSEDataType>> lse_tensors;
std::vector<Tensor<DataType>> qgrad_tensors; std::vector<Tensor<OutputDataType>> qgrad_tensors;
std::vector<Tensor<DataType>> kgrad_tensors; std::vector<Tensor<OutputDataType>> kgrad_tensors;
std::vector<Tensor<DataType>> vgrad_tensors; std::vector<Tensor<OutputDataType>> vgrad_tensors;
std::vector<Tensor<DataType>> ygrad_tensors; std::vector<Tensor<InputDataType>> ygrad_tensors;
std::vector<DeviceMemPtr> q_tensors_device; std::vector<DeviceMemPtr> q_tensors_device;
std::vector<DeviceMemPtr> k_tensors_device; std::vector<DeviceMemPtr> k_tensors_device;
...@@ -885,24 +902,26 @@ int run(int argc, char* argv[]) ...@@ -885,24 +902,26 @@ int run(int argc, char* argv[])
int BatchCount = G0 * G1; int BatchCount = G0 * G1;
flop_fwd += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; flop_fwd += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
num_byte_fwd += (sizeof(DataType) * M * K + sizeof(DataType) * K * N + num_byte_fwd += (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
sizeof(DataType) * N * O + sizeof(DataType) * M * O) * sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O) *
BatchCount; BatchCount;
flop_bwd += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount; flop_bwd += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
// Q/K/V/Y, dQ/dK/dV/dY, LSE // Q/K/V/Y, dQ/dK/dV/dY, LSE
num_byte_bwd += (sizeof(DataType) * M * K + sizeof(DataType) * K * N + num_byte_bwd += (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
sizeof(DataType) * N * O + sizeof(DataType) * M * O) * sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
size_t(2) * BatchCount + sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
sizeof(OutputDataType) * N * O) *
BatchCount +
sizeof(LSEDataType) * M * BatchCount; sizeof(LSEDataType) * M * BatchCount;
Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<ZDataType> z_fwd_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); Tensor<ZDataType> z_fwd_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<ZDataType> z_bwd_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); Tensor<ZDataType> z_bwd_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides); Tensor<InputDataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
if(i < 4) if(i < 4)
{ {
...@@ -913,46 +932,46 @@ int run(int argc, char* argv[]) ...@@ -913,46 +932,46 @@ int run(int argc, char* argv[])
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl; std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl; std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
} }
z_fwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0}); z_fwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
z_bwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0}); z_bwd_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
switch(init_method) switch(init_method)
{ {
case 0: break; case 0: break;
case 1: case 1:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
break; break;
case 2: case 2:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
break; break;
case 3: case 3:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-5, 5});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
break; break;
case 4: case 4:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{2}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{2});
break; break;
case 5: case 5:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
// dO dot O = [0; 1; 2; ...] // dO dot O = [0; 1; 2; ...]
break; break;
case 6: case 6:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
// assume mnko = 256 // assume mnko = 256
// P = softmax(QK) = 0.0039 * ones // P = softmax(QK) = 0.0039 * ones
...@@ -963,10 +982,11 @@ int run(int argc, char* argv[]) ...@@ -963,10 +982,11 @@ int run(int argc, char* argv[])
// //
break; break;
default: default:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
// assume mnko = 256 // assume mnko = 256
// P = softmax(QK) = 0.0039 * ones // P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones // O = P V = 0.0039 * ones
...@@ -977,16 +997,16 @@ int run(int argc, char* argv[]) ...@@ -977,16 +997,16 @@ int run(int argc, char* argv[])
// = 0 // = 0
} }
Tensor<DataType> q_g_m_k({BatchCount, M, K}); Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
Tensor<DataType> k_g_n_k({BatchCount, N, K}); Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
Tensor<ZDataType> z_fwd_g_m_n({BatchCount, M, N}); Tensor<ZDataType> z_fwd_g_m_n({BatchCount, M, N});
Tensor<ZDataType> z_bwd_g_m_n({BatchCount, M, N}); Tensor<ZDataType> z_bwd_g_m_n({BatchCount, M, N});
Tensor<DataType> v_g_n_o({BatchCount, N, O}); Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N}); Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
Tensor<DataType> p_g_m_n({BatchCount, M, N}); Tensor<InputDataType> p_g_m_n({BatchCount, M, N});
Tensor<DataType> y_g_m_o({BatchCount, M, O}); Tensor<InputDataType> y_g_m_o({BatchCount, M, O});
Tensor<LSEDataType> lse_g_m({BatchCount, M}); Tensor<LSEDataType> lse_g_m({BatchCount, M});
Tensor<DataType> p_drop_g_m_n({BatchCount, M, N}); Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
q_gs_ms_ks.ForEach([&](auto& self, auto idx) { q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
...@@ -1019,27 +1039,27 @@ int run(int argc, char* argv[]) ...@@ -1019,27 +1039,27 @@ int run(int argc, char* argv[])
ygrad_tensors.push_back(ygrad_gs_ms_os); ygrad_tensors.push_back(ygrad_gs_ms_os);
q_tensors_device.emplace_back( q_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(InputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
k_tensors_device.emplace_back( k_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * k_gs_ns_ks.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(InputDataType) * k_gs_ns_ks.GetElementSpaceSize()));
z_fwd_tensors_device.emplace_back( z_fwd_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(ZDataType) * z_fwd_gs_ms_ns.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(ZDataType) * z_fwd_gs_ms_ns.GetElementSpaceSize()));
z_bwd_tensors_device.emplace_back( z_bwd_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(ZDataType) * z_bwd_gs_ms_ns.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(ZDataType) * z_bwd_gs_ms_ns.GetElementSpaceSize()));
v_tensors_device.emplace_back( v_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * v_gs_os_ns.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(InputDataType) * v_gs_os_ns.GetElementSpaceSize()));
y_tensors_device.emplace_back( y_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * y_gs_ms_os.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(InputDataType) * y_gs_ms_os.GetElementSpaceSize()));
lse_tensors_device.emplace_back( lse_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(LSEDataType) * lse_gs_ms.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(LSEDataType) * lse_gs_ms.GetElementSpaceSize()));
qgrad_tensors_device.emplace_back( qgrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(OutputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
kgrad_tensors_device.emplace_back( kgrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * k_gs_ns_ks.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(OutputDataType) * k_gs_ns_ks.GetElementSpaceSize()));
vgrad_tensors_device.emplace_back( vgrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * v_gs_os_ns.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(OutputDataType) * v_gs_os_ns.GetElementSpaceSize()));
ygrad_tensors_device.emplace_back( ygrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * y_gs_ms_os.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(InputDataType) * y_gs_ms_os.GetElementSpaceSize()));
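// Sizing note: the gradient buffers (dQ, dK, dV) are now allocated with
// OutputDataType while the activations keep InputDataType. Assuming this
// example's defaults (BF16 inputs at 2 bytes, F32 outputs at 4 bytes), each
// gradient buffer is twice the byte size of its same-shape activation buffer.
// dY stays in InputDataType because it is an input to the backward kernel.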
q_tensors_device.back()->ToDevice(q_gs_ms_ks.data()); q_tensors_device.back()->ToDevice(q_gs_ms_ks.data());
k_tensors_device.back()->ToDevice(k_gs_ns_ks.data()); k_tensors_device.back()->ToDevice(k_gs_ns_ks.data());
...@@ -1258,23 +1278,26 @@ int run(int argc, char* argv[]) ...@@ -1258,23 +1278,26 @@ int run(int argc, char* argv[])
int M = q_tensors[i].GetLengths()[2]; int M = q_tensors[i].GetLengths()[2];
int K = q_tensors[i].GetLengths()[3]; int K = q_tensors[i].GetLengths()[3];
int BatchCount = G0 * G1; int BatchCount = G0 * G1;
Tensor<DataType> qgrad_g_m_k({BatchCount, M, K}); Tensor<OutputDataType> qgrad_g_m_k({BatchCount, M, K});
Tensor<DataType> kgrad_g_n_k({BatchCount, N, K}); Tensor<OutputDataType> kgrad_g_n_k({BatchCount, N, K});
Tensor<DataType> vgrad_g_n_o({BatchCount, N, O}); Tensor<OutputDataType> vgrad_g_n_o({BatchCount, N, O});
Tensor<DataType> sgrad_g_m_n({BatchCount, M, N}); Tensor<InputDataType> sgrad_g_m_n({BatchCount, M, N});
Tensor<DataType> pgrad_g_m_n({BatchCount, M, N}); Tensor<InputDataType> pgrad_g_m_n({BatchCount, M, N});
Tensor<DataType> pgrad_drop_g_m_n({BatchCount, M, N}); Tensor<InputDataType> pgrad_drop_g_m_n({BatchCount, M, N});
Tensor<DataType> ygrad_g_m_o({BatchCount, M, O}); Tensor<InputDataType> ygrad_g_m_o({BatchCount, M, O});
ygrad_tensors[i].ForEach([&](auto& self, auto idx) { ygrad_tensors[i].ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
}); });
auto ref_gemm_grad = ReferenceGemmGradInstance{}; auto ref_gemm0_grad = ReferenceGemm0GradInstance{};
auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker(); auto ref_gemm0_grad_invoker = ref_gemm0_grad.MakeInvoker();
using RefGemmGradArg = ReferenceGemmGradInstance::Argument; using RefGemm0GradArg = ReferenceGemm0GradInstance::Argument;
auto ref_gemm1_grad = ReferenceGemm1GradInstance{};
auto ref_gemm1_grad_invoker = ref_gemm1_grad.MakeInvoker();
using RefGemm1GradArg = ReferenceGemm1GradInstance::Argument;
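// The reference backward chain below follows the usual attention gradients for
// S = alpha * Q K^T, P = softmax(S), Y = P V (with (*) denoting elementwise
// product; the ReferenceDropout step in between accounts for the dropout mask):
//   dP = dY V^T
//   dS = P (*) (dP - rowsum(dY (*) Y))
//   dV = P_drop^T dY
//   dQ = alpha * dS K
//   dK = alpha * dS^T Q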
// dP = dY * V^T // dP = dY * V^T
auto v_g_o_n = v_g_n_os[i].Transpose({0, 2, 1}); auto v_g_o_n = v_g_n_os[i].Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{ ref_gemm0_grad_invoker.Run(RefGemm0GradArg{
ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}}); ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
auto ref_dropout = ReferenceDropoutInstance{}; auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker(); auto ref_dropout_invoker = ref_dropout.MakeInvoker();
...@@ -1291,39 +1314,39 @@ int run(int argc, char* argv[]) ...@@ -1291,39 +1314,39 @@ int run(int argc, char* argv[])
ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) * ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) *
ck::type_convert<AccDataType>(y_g_m_os[i](idx_gmo)); ck::type_convert<AccDataType>(y_g_m_os[i](idx_gmo));
} }
self(idx_gmn) = ck::type_convert<DataType>( self(idx_gmn) = ck::type_convert<InputDataType>(
ck::type_convert<AccDataType>(p_g_m_ns[i](idx_gmn)) * ck::type_convert<AccDataType>(p_g_m_ns[i](idx_gmn)) *
(ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y)); (ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y));
}); });
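// This ForEach is the softmax backward identity
//   dS[g,m,n] = P[g,m,n] * (dP[g,m,n] - sum_o dY[g,m,o] * Y[g,m,o]),
// accumulated in AccDataType (F32) before converting back to InputDataType.
// Note that when every dP entry of a row equals that row's dY.Y dot product,
// dS vanishes, which is the "= 0" expectation from the setup comments earlier.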
auto p_drop_g_n_m = p_drop_g_m_ns[i].Transpose({0, 2, 1}); auto p_drop_g_n_m = p_drop_g_m_ns[i].Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{ ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}}); p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}});
ref_gemm_grad_invoker.Run(RefGemmGradArg{ ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
sgrad_g_m_n, k_g_n_ks[i], qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}}); sgrad_g_m_n, k_g_n_ks[i], qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1}); auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{ ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
sgrad_g_n_m, q_g_m_ks[i], kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}}); sgrad_g_n_m, q_g_m_ks[i], kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
Tensor<DataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(), Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(),
q_tensors[i].GetStrides()); q_tensors[i].GetStrides());
Tensor<DataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(), Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(),
k_tensors[i].GetStrides()); k_tensors[i].GetStrides());
Tensor<DataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(), Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(),
v_tensors[i].GetStrides()); v_tensors[i].GetStrides());
Tensor<DataType> y_gs_ms_os_host_result(y_tensors[i].GetLengths(), Tensor<InputDataType> y_gs_ms_os_host_result(y_tensors[i].GetLengths(),
y_tensors[i].GetStrides()); y_tensors[i].GetStrides());
Tensor<LSEDataType> lse_gs_ms_host_result(lse_tensors[i].GetLengths(), Tensor<LSEDataType> lse_gs_ms_host_result(lse_tensors[i].GetLengths(),
lse_tensors[i].GetStrides()); lse_tensors[i].GetStrides());
Tensor<DataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(), Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(),
q_tensors[i].GetStrides()); q_tensors[i].GetStrides());
Tensor<DataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(), Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(),
k_tensors[i].GetStrides()); k_tensors[i].GetStrides());
Tensor<DataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(), Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(),
v_tensors[i].GetStrides()); v_tensors[i].GetStrides());
Tensor<DataType> y_gs_ms_os_device_result(y_tensors[i].GetLengths(), Tensor<InputDataType> y_gs_ms_os_device_result(y_tensors[i].GetLengths(),
y_tensors[i].GetStrides()); y_tensors[i].GetStrides());
Tensor<LSEDataType> lse_gs_ms_device_result(lse_tensors[i].GetLengths(), Tensor<LSEDataType> lse_gs_ms_device_result(lse_tensors[i].GetLengths(),
lse_tensors[i].GetStrides()); lse_tensors[i].GetStrides());
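// Each quantity gets a host_result / device_result tensor pair constructed with
// identical lengths and strides, so the verification below can compare them
// elementwise.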
...@@ -1380,7 +1403,8 @@ int run(int argc, char* argv[]) ...@@ -1380,7 +1403,8 @@ int run(int argc, char* argv[])
double atol = 1e-3; double atol = 1e-3;
// When BF16 is used, relax both the absolute and relative error tolerances to 0.01 // When BF16 is used, relax both the absolute and relative error tolerances to 0.01
if(std::is_same_v<DataType, ck::bhalf_t> || std::is_same_v<GemmDataType, ck::bhalf_t>) if(std::is_same_v<InputDataType, ck::bhalf_t> ||
std::is_same_v<GemmDataType, ck::bhalf_t>)
{ {
rtol = 1e-2; rtol = 1e-2;
atol = 1e-2; atol = 1e-2;
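// BF16 carries only 7 mantissa bits (roughly 2-3 significant decimal digits),
// so the default 1e-3 bounds are too tight for BF16 runs. A hedged sketch of
// how rtol/atol feed the comparison, assuming CK's
// ck::utils::check_err(out, ref, msg, rtol, atol) helper:
//   pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
//                                qgrad_gs_ms_ks_host_result.mData,
//                                "Error: incorrect dQ", rtol, atol);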
......