"...git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "90e753c97ab4fb8920fe6ddc4a3dc0572a097b47"
Commit f3e61c0a authored by danyao12

Data type of the backward-pass outputs (dQ/dK/dV) can be selected

parent f7e05f9e
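
Before the diff itself, a minimal, self-contained sketch of what selecting the backward output data type means in this example: activations and the incoming gradient dY keep one element type (InputDataType, BF16 in the diff), while the produced gradients dQ/dK/dV use a second type (OutputDataType, F32 in the diff), so the gradient buffers are allocated and cast with the latter. The 16-bit stand-in type, DIM = 64, and the batch shape G0 = 4, G1 = 6 below are illustrative assumptions, not CK's API.

// Toy CPU sketch (not CK's API) of the input/output type split this commit
// introduces; a 16-bit integer stands in for the BF16 storage type.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

using InputDataType  = std::uint16_t; // Q/K/V/Y and dY storage (stand-in for BF16)
using OutputDataType = float;         // selected precision of dQ/dK/dV

int main()
{
    // Shapes follow the example's new defaults: M = 512, K = DIM (assumed 64),
    // BatchCount = G0 * G1 = 4 * 6.
    const std::size_t M = 512, K = 64, BatchCount = 4 * 6;
    const std::size_t q_elems = BatchCount * M * K;

    // Q is read in the input type, dQ is written in the output type, so the two
    // buffers no longer share a byte size (cf. the DeviceMem changes below).
    std::vector<InputDataType>  q(q_elems);
    std::vector<OutputDataType> qgrad(q_elems);

    std::cout << "Q bytes:  " << q.size() * sizeof(InputDataType) << '\n'       // 1572864
              << "dQ bytes: " << qgrad.size() * sizeof(OutputDataType) << '\n'; // 3145728
    return 0;
}

The rest of the commit threads the same split through the device op: the kernel template and Argument gain an OutputDataType parameter, and the dQ/dK/dV grid pointers become OutputDataType*.
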
...@@ -62,8 +62,9 @@ using Scale = ck::tensor_operation::element_wise::Scale; ...@@ -62,8 +62,9 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using QKVElementOp = PassThrough; using QKVElementOp = PassThrough;
using YElementOp = PassThrough; using YElementOp = PassThrough;
using DataType = F16; using InputDataType = BF16;
using GemmDataType = F16; using OutputDataType = F32;
using GemmDataType = BF16;
using AccDataType = F32; using AccDataType = F32;
using ShuffleDataType = F32; using ShuffleDataType = F32;
using LSEDataType = F32; using LSEDataType = F32;
...@@ -103,7 +104,8 @@ using DeviceGemmInstance = ...@@ -103,7 +104,8 @@ using DeviceGemmInstance =
NumDimN, NumDimN,
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, InputDataType,
OutputDataType,
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
...@@ -161,7 +163,7 @@ using DeviceGemmInstance = ...@@ -161,7 +163,7 @@ using DeviceGemmInstance =
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 4, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#elif(DIM <= 64) #elif(DIM <= 64)
using DeviceGemmInstance = using DeviceGemmInstance =
...@@ -171,7 +173,8 @@ using DeviceGemmInstance = ...@@ -171,7 +173,8 @@ using DeviceGemmInstance =
NumDimN, NumDimN,
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, InputDataType,
OutputDataType,
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
...@@ -229,7 +232,7 @@ using DeviceGemmInstance = ...@@ -229,7 +232,7 @@ using DeviceGemmInstance =
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 4, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
// using DeviceGemmInstance = // using DeviceGemmInstance =
...@@ -239,7 +242,8 @@ using DeviceGemmInstance = ...@@ -239,7 +242,8 @@ using DeviceGemmInstance =
// NumDimN, // NumDimN,
// NumDimK, // NumDimK,
// NumDimO, // NumDimO,
// DataType, // InputDataType,
// OutputDataType,
// GemmDataType, // GemmDataType,
// ZDataType, // ZDataType,
// LSEDataType, // LSEDataType,
...@@ -297,7 +301,7 @@ using DeviceGemmInstance = ...@@ -297,7 +301,7 @@ using DeviceGemmInstance =
// 1, // CShuffleMXdlPerWavePerShuffle // 1, // CShuffleMXdlPerWavePerShuffle
// 2, // CShuffleNXdlPerWavePerShuffle // 2, // CShuffleNXdlPerWavePerShuffle
// S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock // S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
// 8, // CShuffleBlockTransferScalarPerVector_NPerBlock // 4, // CShuffleBlockTransferScalarPerVector_NPerBlock
// MaskingSpec>; // MaskingSpecialization // MaskingSpec>; // MaskingSpecialization
#elif(DIM <= 128) #elif(DIM <= 128)
using DeviceGemmInstance = using DeviceGemmInstance =
...@@ -307,7 +311,8 @@ using DeviceGemmInstance = ...@@ -307,7 +311,8 @@ using DeviceGemmInstance =
NumDimN, NumDimN,
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, InputDataType,
OutputDataType,
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
...@@ -365,14 +370,14 @@ using DeviceGemmInstance = ...@@ -365,14 +370,14 @@ using DeviceGemmInstance =
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
4, // CShuffleNXdlPerWavePerShuffle 4, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 4, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#endif #endif
// Ref Gemm0: S = alpha * Q * K^T // Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out // fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType, using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
DataType, InputDataType,
AccDataType, AccDataType,
AccDataType, AccDataType,
PassThrough, PassThrough,
...@@ -382,13 +387,13 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm< ...@@ -382,13 +387,13 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
// Ref Softmax: P = Softmax(S) // Ref Softmax: P = Softmax(S)
// fp32 in, fp16 out // fp32 in, fp16 out
using ReferenceSoftmaxInstance = using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<AccDataType, DataType, AccDataType>; ck::tensor_operation::host::ReferenceSoftmax<AccDataType, InputDataType, AccDataType>;
// Ref Gemm1: Y = P * V // Ref Gemm1: Y = P * V
// fp16 in, fp16 out // fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType, using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
DataType, InputDataType,
DataType, InputDataType,
AccDataType, AccDataType,
PassThrough, PassThrough,
PassThrough, PassThrough,
...@@ -396,16 +401,25 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm< ...@@ -396,16 +401,25 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
// Ref Gemm for backward pass // Ref Gemm for backward pass
// fp16 in, fp16 out // fp16 in, fp16 out
using ReferenceGemmGradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType, using ReferenceGemm0GradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
DataType, InputDataType,
DataType, InputDataType,
AccDataType, AccDataType,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>; Scale>;
using ReferenceGemm1GradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
InputDataType,
OutputDataType,
AccDataType,
PassThrough,
PassThrough,
Scale>;
// Ref dropout // Ref dropout
using ReferenceDropoutInstance = using ReferenceDropoutInstance =
ck::tensor_operation::host::ReferenceDropout<ushort, DataType, DataType>; ck::tensor_operation::host::ReferenceDropout<ushort, InputDataType, InputDataType>;
template <typename TensorQ, template <typename TensorQ,
typename TensorK, typename TensorK,
...@@ -482,8 +496,8 @@ int run(int argc, char* argv[]) ...@@ -482,8 +496,8 @@ int run(int argc, char* argv[])
ck::index_t N = 512; ck::index_t N = 512;
ck::index_t K = DIM; ck::index_t K = DIM;
ck::index_t O = DIM; ck::index_t O = DIM;
ck::index_t G0 = 54; ck::index_t G0 = 4;
ck::index_t G1 = 16; ck::index_t G1 = 6;
bool input_permute = false; bool input_permute = false;
bool output_permute = false; bool output_permute = false;
...@@ -592,12 +606,12 @@ int run(int argc, char* argv[]) ...@@ -592,12 +606,12 @@ int run(int argc, char* argv[])
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M}; std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M] std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides); Tensor<InputDataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl; std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
...@@ -607,45 +621,45 @@ int run(int argc, char* argv[]) ...@@ -607,45 +621,45 @@ int run(int argc, char* argv[])
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl; std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl; std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0}); z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
switch(init_method) switch(init_method)
{ {
case 0: break; case 0: break;
case 1: case 1:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
break; break;
case 2: case 2:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
break; break;
case 3: case 3:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-5, 5});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
break; break;
case 4: case 4:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
break; break;
case 5: case 5:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
// dO dot O = [0; 1; 2; ...] // dO dot O = [0; 1; 2; ...]
break; break;
case 6: case 6:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
// assume mnko = 256 // assume mnko = 256
// P = softmax(QK) = 0.0039 * ones // P = softmax(QK) = 0.0039 * ones
...@@ -656,10 +670,10 @@ int run(int argc, char* argv[]) ...@@ -656,10 +670,10 @@ int run(int argc, char* argv[])
// //
break; break;
default: default:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
// assume mnko = 256 // assume mnko = 256
// P = softmax(QK) = 0.0039 * ones // P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones // O = P V = 0.0039 * ones
...@@ -670,14 +684,14 @@ int run(int argc, char* argv[]) ...@@ -670,14 +684,14 @@ int run(int argc, char* argv[])
// = 0 // = 0
} }
Tensor<DataType> q_g_m_k({BatchCount, M, K}); Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
Tensor<DataType> k_g_n_k({BatchCount, N, K}); Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N}); Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
Tensor<DataType> v_g_n_o({BatchCount, N, O}); Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N}); Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
Tensor<DataType> p_g_m_n({BatchCount, M, N}); Tensor<InputDataType> p_g_m_n({BatchCount, M, N});
Tensor<DataType> p_drop_g_m_n({BatchCount, M, N}); Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
Tensor<DataType> y_g_m_o({BatchCount, M, O}); Tensor<InputDataType> y_g_m_o({BatchCount, M, O});
Tensor<LSEDataType> lse_g_m({BatchCount, M}); Tensor<LSEDataType> lse_g_m({BatchCount, M});
q_gs_ms_ks.ForEach( q_gs_ms_ks.ForEach(
...@@ -688,16 +702,16 @@ int run(int argc, char* argv[]) ...@@ -688,16 +702,16 @@ int run(int argc, char* argv[])
[&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); }); [&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); });
// qkv gradients have the same descriptor as with qkv // qkv gradients have the same descriptor as with qkv
DeviceMem q_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem k_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize()); DeviceMem k_device_buf(sizeof(InputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize()); DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem v_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize()); DeviceMem v_device_buf(sizeof(InputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem y_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize()); DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize()); DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
DeviceMem qgrad_device_buf(sizeof(DataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem qgrad_device_buf(sizeof(OutputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem kgrad_device_buf(sizeof(DataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize()); DeviceMem kgrad_device_buf(sizeof(OutputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem vgrad_device_buf(sizeof(DataType) * v_gs_os_ns.mDesc.GetElementSpaceSize()); DeviceMem vgrad_device_buf(sizeof(OutputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem ygrad_device_buf(sizeof(DataType) * y_gs_ms_os.mDesc.GetElementSpaceSize()); DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
q_device_buf.ToDevice(q_gs_ms_ks.mData.data()); q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
k_device_buf.ToDevice(k_gs_ns_ks.mData.data()); k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
...@@ -710,16 +724,16 @@ int run(int argc, char* argv[]) ...@@ -710,16 +724,16 @@ int run(int argc, char* argv[])
// get z matrix // get z matrix
{ {
auto argument = gemm.MakeArgument( auto argument = gemm.MakeArgument(
static_cast<DataType*>(q_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(k_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()), static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(v_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(y_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(ygrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(vgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases; {}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases; {}, // std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths, q_gs_ms_ks_lengths,
...@@ -755,16 +769,16 @@ int run(int argc, char* argv[]) ...@@ -755,16 +769,16 @@ int run(int argc, char* argv[])
} }
// not need output z matrix // not need output z matrix
auto argument = gemm.MakeArgument( auto argument = gemm.MakeArgument(
static_cast<DataType*>(q_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(k_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(k_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr), // set to nullptr static_cast<ZDataType*>(nullptr), // set to nullptr
static_cast<DataType*>(v_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(v_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(y_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(ygrad_device_buf.GetDeviceBuffer()), static_cast<InputDataType*>(ygrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(qgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(kgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<DataType*>(vgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases; {}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases; {}, // std::array<void*, 1> p_acc1_biases;
q_gs_ms_ks_lengths, q_gs_ms_ks_lengths,
...@@ -800,10 +814,13 @@ int run(int argc, char* argv[]) ...@@ -800,10 +814,13 @@ int run(int argc, char* argv[])
// 3x MNK + 2x MNO // 3x MNK + 2x MNO
std::size_t flop = (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount; std::size_t flop = (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
// Q/K/V/Y, dQ/dK/dV/dY, LSE // Q/K/V/Y, dQ/dK/dV/dY, LSE
std::size_t num_btype = (sizeof(DataType) * M * K + sizeof(DataType) * K * N + std::size_t num_btype =
sizeof(DataType) * N * O + sizeof(DataType) * M * O) * (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
size_t(2) * BatchCount + sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
sizeof(LSEDataType) * M * BatchCount; sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
sizeof(OutputDataType) * N * O) *
BatchCount +
sizeof(LSEDataType) * M * BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -847,14 +864,14 @@ int run(int argc, char* argv[]) ...@@ -847,14 +864,14 @@ int run(int argc, char* argv[])
vgrad_device_buf.SetZero(); vgrad_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false}); invoker.Run(argument, StreamConfig{nullptr, false});
Tensor<DataType> qgrad_g_m_k({BatchCount, M, K}); Tensor<OutputDataType> qgrad_g_m_k({BatchCount, M, K});
Tensor<DataType> kgrad_g_n_k({BatchCount, N, K}); Tensor<OutputDataType> kgrad_g_n_k({BatchCount, N, K});
Tensor<DataType> vgrad_g_n_o({BatchCount, N, O}); Tensor<OutputDataType> vgrad_g_n_o({BatchCount, N, O});
Tensor<DataType> sgrad_g_m_n({BatchCount, M, N}); Tensor<InputDataType> sgrad_g_m_n({BatchCount, M, N});
Tensor<DataType> pgrad_g_m_n({BatchCount, M, N}); Tensor<InputDataType> pgrad_g_m_n({BatchCount, M, N});
Tensor<DataType> pgrad_drop_g_m_n({BatchCount, M, N}); Tensor<InputDataType> pgrad_drop_g_m_n({BatchCount, M, N});
Tensor<DataType> ygrad_g_m_o({BatchCount, M, O}); Tensor<InputDataType> ygrad_g_m_o({BatchCount, M, O});
Tensor<DataType> ygrad_dot_y_g_m({BatchCount, M}); Tensor<InputDataType> ygrad_dot_y_g_m({BatchCount, M});
ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) { ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
...@@ -870,13 +887,16 @@ int run(int argc, char* argv[]) ...@@ -870,13 +887,16 @@ int run(int argc, char* argv[])
#endif #endif
// Gradients // Gradients
auto ref_gemm_grad = ReferenceGemmGradInstance{}; auto ref_gemm0_grad = ReferenceGemm0GradInstance{};
auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker(); auto ref_gemm0_grad_invoker = ref_gemm0_grad.MakeInvoker();
using RefGemmGradArg = ReferenceGemmGradInstance::Argument; using RefGemm0GradArg = ReferenceGemm0GradInstance::Argument;
auto ref_gemm1_grad = ReferenceGemm1GradInstance{};
auto ref_gemm1_grad_invoker = ref_gemm1_grad.MakeInvoker();
using RefGemm1GradArg = ReferenceGemm1GradInstance::Argument;
// dP_dropout = dY * V^T // dP_dropout = dY * V^T
auto v_g_o_n = v_g_n_o.Transpose({0, 2, 1}); auto v_g_o_n = v_g_n_o.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{ ref_gemm0_grad_invoker.Run(RefGemm0GradArg{
ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}}); ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
#if PRINT_HOST #if PRINT_HOST
{ {
...@@ -903,7 +923,7 @@ int run(int argc, char* argv[]) ...@@ -903,7 +923,7 @@ int run(int argc, char* argv[])
ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) * ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) *
ck::type_convert<AccDataType>(y_g_m_o(idx_gmo)); ck::type_convert<AccDataType>(y_g_m_o(idx_gmo));
} }
self(idx_gmn) = ck::type_convert<DataType>( self(idx_gmn) = ck::type_convert<InputDataType>(
ck::type_convert<AccDataType>(p_g_m_n(idx_gmn)) * ck::type_convert<AccDataType>(p_g_m_n(idx_gmn)) *
(ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y)); (ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y));
}); });
...@@ -919,7 +939,7 @@ int run(int argc, char* argv[]) ...@@ -919,7 +939,7 @@ int run(int argc, char* argv[])
#endif #endif
// dV = P_drop^T * dY // dV = P_drop^T * dY
auto p_drop_g_n_m = p_drop_g_m_n.Transpose({0, 2, 1}); auto p_drop_g_n_m = p_drop_g_m_n.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{ ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}}); p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}});
#if PRINT_HOST #if PRINT_HOST
{ {
...@@ -931,7 +951,7 @@ int run(int argc, char* argv[]) ...@@ -931,7 +951,7 @@ int run(int argc, char* argv[])
#endif #endif
// dQ = alpha * dS * K // dQ = alpha * dS * K
ref_gemm_grad_invoker.Run(RefGemmGradArg{ ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
sgrad_g_m_n, k_g_n_k, qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}}); sgrad_g_m_n, k_g_n_k, qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
#if PRINT_HOST #if PRINT_HOST
{ {
...@@ -944,7 +964,7 @@ int run(int argc, char* argv[]) ...@@ -944,7 +964,7 @@ int run(int argc, char* argv[])
// dK = alpha * dS^T * Q // dK = alpha * dS^T * Q
auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1}); auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{ ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
sgrad_g_n_m, q_g_m_k, kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}}); sgrad_g_n_m, q_g_m_k, kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
#if PRINT_HOST #if PRINT_HOST
{ {
...@@ -955,13 +975,13 @@ int run(int argc, char* argv[]) ...@@ -955,13 +975,13 @@ int run(int argc, char* argv[])
} }
#endif #endif
Tensor<DataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<DataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<DataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides); Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<DataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<DataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<DataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides); Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data()); qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data());
kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data()); kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data());
......
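
A side effect of the type split shows up in the bandwidth accounting of both example files (the num_btype hunks above and below): Q/K/V/Y and dY are now counted at sizeof(InputDataType), while dQ/dK/dV are counted at sizeof(OutputDataType). A standalone restatement with assumed sizes (the new defaults M = N = 512, G0 = 4, G1 = 6, and DIM = 64 so K = O = 64):

// Restatement of the updated num_btype formula with concrete (assumed) sizes:
// BF16 inputs are 2 bytes, F32 gradients and LSE entries are 4 bytes.
#include <cstddef>
#include <iostream>

int main()
{
    const std::size_t M = 512, N = 512, K = 64, O = 64;
    const std::size_t BatchCount = 4 * 6;                 // G0 * G1
    const std::size_t in_b = 2, out_b = 4, lse_b = 4;

    const std::size_t num_btype =
        (in_b * (M * K + K * N + N * O + 2 * M * O)       // Q, K, V, Y, dY reads
         + out_b * (M * K + K * N + N * O)) * BatchCount  // dQ, dK, dV writes
        + lse_b * M * BatchCount;                         // LSE

    std::cout << num_btype << " bytes\n";                 // 17350656, roughly 17 MB
    return 0;
}
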
...@@ -61,8 +61,9 @@ using Scale = ck::tensor_operation::element_wise::Scale; ...@@ -61,8 +61,9 @@ using Scale = ck::tensor_operation::element_wise::Scale;
using QKVElementOp = PassThrough; using QKVElementOp = PassThrough;
using YElementOp = PassThrough; using YElementOp = PassThrough;
using DataType = F16; using InputDataType = BF16;
using GemmDataType = F16; using OutputDataType = F32;
using GemmDataType = BF16;
using AccDataType = F32; using AccDataType = F32;
using ShuffleDataType = F32; using ShuffleDataType = F32;
using LSEDataType = F32; using LSEDataType = F32;
...@@ -102,7 +103,8 @@ using DeviceGemmInstance = ...@@ -102,7 +103,8 @@ using DeviceGemmInstance =
NumDimN, NumDimN,
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, InputDataType,
OutputDataType,
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
...@@ -160,7 +162,7 @@ using DeviceGemmInstance = ...@@ -160,7 +162,7 @@ using DeviceGemmInstance =
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 4, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#elif(DIM <= 64) #elif(DIM <= 64)
using DeviceGemmInstance = using DeviceGemmInstance =
...@@ -170,7 +172,8 @@ using DeviceGemmInstance = ...@@ -170,7 +172,8 @@ using DeviceGemmInstance =
NumDimN, NumDimN,
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, InputDataType,
OutputDataType,
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
...@@ -228,7 +231,7 @@ using DeviceGemmInstance = ...@@ -228,7 +231,7 @@ using DeviceGemmInstance =
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 4, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
// using DeviceGemmInstance = // using DeviceGemmInstance =
...@@ -238,7 +241,8 @@ using DeviceGemmInstance = ...@@ -238,7 +241,8 @@ using DeviceGemmInstance =
// NumDimN, // NumDimN,
// NumDimK, // NumDimK,
// NumDimO, // NumDimO,
// DataType, // InputDataType,
// OutputDataType,
// GemmDataType, // GemmDataType,
// ZDataType, // ZDataType,
// LSEDataType, // LSEDataType,
...@@ -296,7 +300,7 @@ using DeviceGemmInstance = ...@@ -296,7 +300,7 @@ using DeviceGemmInstance =
// 1, // CShuffleMXdlPerWavePerShuffle // 1, // CShuffleMXdlPerWavePerShuffle
// 2, // CShuffleNXdlPerWavePerShuffle // 2, // CShuffleNXdlPerWavePerShuffle
// S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock // S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
// 8, // CShuffleBlockTransferScalarPerVector_NPerBlock // 4, // CShuffleBlockTransferScalarPerVector_NPerBlock
// MaskingSpec>; // MaskingSpecialization // MaskingSpec>; // MaskingSpecialization
#elif(DIM <= 128) #elif(DIM <= 128)
using DeviceGemmInstance = using DeviceGemmInstance =
...@@ -306,7 +310,8 @@ using DeviceGemmInstance = ...@@ -306,7 +310,8 @@ using DeviceGemmInstance =
NumDimN, NumDimN,
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, InputDataType,
OutputDataType,
GemmDataType, GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
...@@ -364,14 +369,14 @@ using DeviceGemmInstance = ...@@ -364,14 +369,14 @@ using DeviceGemmInstance =
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
4, // CShuffleNXdlPerWavePerShuffle 4, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 4, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#endif #endif
// Ref Gemm0: S = alpha * Q * K^T // Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out // fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType, using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
DataType, InputDataType,
AccDataType, AccDataType,
AccDataType, AccDataType,
PassThrough, PassThrough,
...@@ -381,13 +386,13 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm< ...@@ -381,13 +386,13 @@ using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
// Ref Softmax: P = Softmax(S) // Ref Softmax: P = Softmax(S)
// fp32 in, fp16 out // fp32 in, fp16 out
using ReferenceSoftmaxInstance = using ReferenceSoftmaxInstance =
ck::tensor_operation::host::ReferenceSoftmax<AccDataType, DataType, AccDataType>; ck::tensor_operation::host::ReferenceSoftmax<AccDataType, InputDataType, AccDataType>;
// Ref Gemm1: Y = P * V // Ref Gemm1: Y = P * V
// fp16 in, fp16 out // fp16 in, fp16 out
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType, using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
DataType, InputDataType,
DataType, InputDataType,
AccDataType, AccDataType,
PassThrough, PassThrough,
PassThrough, PassThrough,
...@@ -395,16 +400,25 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm< ...@@ -395,16 +400,25 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
// Ref Gemm for backward pass // Ref Gemm for backward pass
// fp16 in, fp16 out // fp16 in, fp16 out
using ReferenceGemmGradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType, using ReferenceGemm0GradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
DataType, InputDataType,
DataType, InputDataType,
AccDataType, AccDataType,
PassThrough, PassThrough,
PassThrough, PassThrough,
Scale>; Scale>;
using ReferenceGemm1GradInstance = ck::tensor_operation::host::ReferenceBatchedGemm<InputDataType,
InputDataType,
OutputDataType,
AccDataType,
PassThrough,
PassThrough,
Scale>;
// Ref dropout // Ref dropout
using ReferenceDropoutInstance = using ReferenceDropoutInstance =
ck::tensor_operation::host::ReferenceDropout<ushort, DataType, DataType>; ck::tensor_operation::host::ReferenceDropout<ushort, InputDataType, InputDataType>;
template <typename TensorQ, template <typename TensorQ,
typename TensorK, typename TensorK,
...@@ -539,26 +553,26 @@ int run(int argc, char* argv[]) ...@@ -539,26 +553,26 @@ int run(int argc, char* argv[])
std::vector<void*> p_vgrad; std::vector<void*> p_vgrad;
std::vector<const void*> p_ygrad; std::vector<const void*> p_ygrad;
std::vector<Tensor<DataType>> q_g_m_ks; std::vector<Tensor<InputDataType>> q_g_m_ks;
std::vector<Tensor<DataType>> k_g_n_ks; std::vector<Tensor<InputDataType>> k_g_n_ks;
std::vector<Tensor<ZDataType>> z_g_m_ns; std::vector<Tensor<ZDataType>> z_g_m_ns;
std::vector<Tensor<DataType>> v_g_n_os; std::vector<Tensor<InputDataType>> v_g_n_os;
std::vector<Tensor<AccDataType>> s_g_m_ns; std::vector<Tensor<AccDataType>> s_g_m_ns;
std::vector<Tensor<DataType>> p_g_m_ns; std::vector<Tensor<InputDataType>> p_g_m_ns;
std::vector<Tensor<DataType>> y_g_m_os; std::vector<Tensor<InputDataType>> y_g_m_os;
std::vector<Tensor<LSEDataType>> lse_g_ms; std::vector<Tensor<LSEDataType>> lse_g_ms;
std::vector<Tensor<DataType>> p_drop_g_m_ns; std::vector<Tensor<InputDataType>> p_drop_g_m_ns;
std::vector<Tensor<DataType>> q_tensors; std::vector<Tensor<InputDataType>> q_tensors;
std::vector<Tensor<DataType>> k_tensors; std::vector<Tensor<InputDataType>> k_tensors;
std::vector<Tensor<DataType>> v_tensors; std::vector<Tensor<InputDataType>> v_tensors;
std::vector<Tensor<DataType>> y_tensors; std::vector<Tensor<InputDataType>> y_tensors;
std::vector<Tensor<ZDataType>> z_tensors; std::vector<Tensor<ZDataType>> z_tensors;
std::vector<Tensor<LSEDataType>> lse_tensors; std::vector<Tensor<LSEDataType>> lse_tensors;
std::vector<Tensor<DataType>> qgrad_tensors; std::vector<Tensor<OutputDataType>> qgrad_tensors;
std::vector<Tensor<DataType>> kgrad_tensors; std::vector<Tensor<OutputDataType>> kgrad_tensors;
std::vector<Tensor<DataType>> vgrad_tensors; std::vector<Tensor<OutputDataType>> vgrad_tensors;
std::vector<Tensor<DataType>> ygrad_tensors; std::vector<Tensor<InputDataType>> ygrad_tensors;
std::vector<DeviceMemPtr> q_tensors_device; std::vector<DeviceMemPtr> q_tensors_device;
std::vector<DeviceMemPtr> k_tensors_device; std::vector<DeviceMemPtr> k_tensors_device;
...@@ -639,17 +653,19 @@ int run(int argc, char* argv[]) ...@@ -639,17 +653,19 @@ int run(int argc, char* argv[])
int BatchCount = G0 * G1; int BatchCount = G0 * G1;
flop += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount; flop += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
// Q/K/V/Y, dQ/dK/dV/dY, LSE // Q/K/V/Y, dQ/dK/dV/dY, LSE
num_byte += (sizeof(DataType) * M * K + sizeof(DataType) * K * N + num_byte += (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
sizeof(DataType) * N * O + sizeof(DataType) * M * O) * sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
size_t(2) * BatchCount + sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
sizeof(OutputDataType) * N * O) *
BatchCount +
sizeof(LSEDataType) * M * BatchCount; sizeof(LSEDataType) * M * BatchCount;
Tensor<DataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<DataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<DataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides); Tensor<InputDataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<DataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<DataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
if(i < 4) if(i < 4)
{ {
...@@ -660,45 +676,45 @@ int run(int argc, char* argv[]) ...@@ -660,45 +676,45 @@ int run(int argc, char* argv[])
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl; std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl; std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
} }
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<DataType>{0}); z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
switch(init_method) switch(init_method)
{ {
case 0: break; case 0: break;
case 1: case 1:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<DataType>{-2, 2}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
break; break;
case 2: case 2:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<DataType>{0.0, 1.0}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<DataType>{-0.5, 0.5}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
break; break;
case 3: case 3:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<DataType>{-5, 5}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-5, 5});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
break; break;
case 4: case 4:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{2}); ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{2});
break; break;
case 5: case 5:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
// dO dot O = [0; 1; 2; ...] // dO dot O = [0; 1; 2; ...]
break; break;
case 6: case 6:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
// assume mnko = 256 // assume mnko = 256
// P = softmax(QK) = 0.0039 * ones // P = softmax(QK) = 0.0039 * ones
...@@ -709,10 +725,11 @@ int run(int argc, char* argv[]) ...@@ -709,10 +725,11 @@ int run(int argc, char* argv[])
// //
break; break;
default: default:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<DataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<DataType>{1}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
// assume mnko = 256 // assume mnko = 256
// P = softmax(QK) = 0.0039 * ones // P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones // O = P V = 0.0039 * ones
...@@ -722,15 +739,15 @@ int run(int argc, char* argv[]) ...@@ -722,15 +739,15 @@ int run(int argc, char* argv[])
// = 0.0039 * ones * (ones - 1) // = 0.0039 * ones * (ones - 1)
// = 0 // = 0
} }
Tensor<DataType> q_g_m_k({BatchCount, M, K}); Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
Tensor<DataType> k_g_n_k({BatchCount, N, K}); Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N}); Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
Tensor<DataType> v_g_n_o({BatchCount, N, O}); Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N}); Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
Tensor<DataType> p_g_m_n({BatchCount, M, N}); Tensor<InputDataType> p_g_m_n({BatchCount, M, N});
Tensor<DataType> y_g_m_o({BatchCount, M, O}); Tensor<InputDataType> y_g_m_o({BatchCount, M, O});
Tensor<LSEDataType> lse_g_m({BatchCount, M}); Tensor<LSEDataType> lse_g_m({BatchCount, M});
Tensor<DataType> p_drop_g_m_n({BatchCount, M, N}); Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
q_gs_ms_ks.ForEach([&](auto& self, auto idx) { q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
...@@ -759,25 +776,25 @@ int run(int argc, char* argv[]) ...@@ -759,25 +776,25 @@ int run(int argc, char* argv[])
lse_tensors.push_back(lse_gs_ms); lse_tensors.push_back(lse_gs_ms);
ygrad_tensors.push_back(ygrad_gs_ms_os); ygrad_tensors.push_back(ygrad_gs_ms_os);
q_tensors_device.emplace_back( q_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(InputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
k_tensors_device.emplace_back( k_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * k_gs_ns_ks.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(InputDataType) * k_gs_ns_ks.GetElementSpaceSize()));
z_tensors_device.emplace_back( z_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(ZDataType) * z_gs_ms_ns.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(ZDataType) * z_gs_ms_ns.GetElementSpaceSize()));
v_tensors_device.emplace_back( v_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * v_gs_os_ns.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(InputDataType) * v_gs_os_ns.GetElementSpaceSize()));
y_tensors_device.emplace_back( y_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * y_gs_ms_os.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(InputDataType) * y_gs_ms_os.GetElementSpaceSize()));
lse_tensors_device.emplace_back( lse_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(LSEDataType) * lse_gs_ms.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(LSEDataType) * lse_gs_ms.GetElementSpaceSize()));
qgrad_tensors_device.emplace_back( qgrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(OutputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
kgrad_tensors_device.emplace_back( kgrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * k_gs_ns_ks.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(OutputDataType) * k_gs_ns_ks.GetElementSpaceSize()));
vgrad_tensors_device.emplace_back( vgrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * v_gs_os_ns.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(OutputDataType) * v_gs_os_ns.GetElementSpaceSize()));
ygrad_tensors_device.emplace_back( ygrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * y_gs_ms_os.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(InputDataType) * y_gs_ms_os.GetElementSpaceSize()));
q_tensors_device.back()->ToDevice(q_gs_ms_ks.data()); q_tensors_device.back()->ToDevice(q_gs_ms_ks.data());
k_tensors_device.back()->ToDevice(k_gs_ns_ks.data()); k_tensors_device.back()->ToDevice(k_gs_ns_ks.data());
z_tensors_device.back()->ToDevice(z_gs_ms_ns.data()); z_tensors_device.back()->ToDevice(z_gs_ms_ns.data());
...@@ -918,23 +935,26 @@ int run(int argc, char* argv[]) ...@@ -918,23 +935,26 @@ int run(int argc, char* argv[])
int M = q_tensors[i].GetLengths()[2]; int M = q_tensors[i].GetLengths()[2];
int K = q_tensors[i].GetLengths()[3]; int K = q_tensors[i].GetLengths()[3];
int BatchCount = G0 * G1; int BatchCount = G0 * G1;
Tensor<DataType> qgrad_g_m_k({BatchCount, M, K}); Tensor<OutputDataType> qgrad_g_m_k({BatchCount, M, K});
Tensor<DataType> kgrad_g_n_k({BatchCount, N, K}); Tensor<OutputDataType> kgrad_g_n_k({BatchCount, N, K});
Tensor<DataType> vgrad_g_n_o({BatchCount, N, O}); Tensor<OutputDataType> vgrad_g_n_o({BatchCount, N, O});
Tensor<DataType> sgrad_g_m_n({BatchCount, M, N}); Tensor<InputDataType> sgrad_g_m_n({BatchCount, M, N});
Tensor<DataType> pgrad_g_m_n({BatchCount, M, N}); Tensor<InputDataType> pgrad_g_m_n({BatchCount, M, N});
Tensor<DataType> pgrad_drop_g_m_n({BatchCount, M, N}); Tensor<InputDataType> pgrad_drop_g_m_n({BatchCount, M, N});
Tensor<DataType> ygrad_g_m_o({BatchCount, M, O}); Tensor<InputDataType> ygrad_g_m_o({BatchCount, M, O});
ygrad_tensors[i].ForEach([&](auto& self, auto idx) { ygrad_tensors[i].ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
}); });
auto ref_gemm_grad = ReferenceGemmGradInstance{}; auto ref_gemm0_grad = ReferenceGemm0GradInstance{};
auto ref_gemm_grad_invoker = ref_gemm_grad.MakeInvoker(); auto ref_gemm0_grad_invoker = ref_gemm0_grad.MakeInvoker();
using RefGemmGradArg = ReferenceGemmGradInstance::Argument; using RefGemm0GradArg = ReferenceGemm0GradInstance::Argument;
auto ref_gemm1_grad = ReferenceGemm1GradInstance{};
auto ref_gemm1_grad_invoker = ref_gemm1_grad.MakeInvoker();
using RefGemm1GradArg = ReferenceGemm1GradInstance::Argument;
// dP = dY * V^T // dP = dY * V^T
auto v_g_o_n = v_g_n_os[i].Transpose({0, 2, 1}); auto v_g_o_n = v_g_n_os[i].Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{ ref_gemm0_grad_invoker.Run(RefGemm0GradArg{
ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}}); ygrad_g_m_o, v_g_o_n, pgrad_drop_g_m_n, PassThrough{}, PassThrough{}, Scale{1.f}});
auto ref_dropout = ReferenceDropoutInstance{}; auto ref_dropout = ReferenceDropoutInstance{};
auto ref_dropout_invoker = ref_dropout.MakeInvoker(); auto ref_dropout_invoker = ref_dropout.MakeInvoker();
...@@ -951,33 +971,33 @@ int run(int argc, char* argv[]) ...@@ -951,33 +971,33 @@ int run(int argc, char* argv[])
ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) * ygrad_dot_y += ck::type_convert<AccDataType>(ygrad_g_m_o(idx_gmo)) *
ck::type_convert<AccDataType>(y_g_m_os[i](idx_gmo)); ck::type_convert<AccDataType>(y_g_m_os[i](idx_gmo));
} }
self(idx_gmn) = ck::type_convert<DataType>( self(idx_gmn) = ck::type_convert<InputDataType>(
ck::type_convert<AccDataType>(p_g_m_ns[i](idx_gmn)) * ck::type_convert<AccDataType>(p_g_m_ns[i](idx_gmn)) *
(ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y)); (ck::type_convert<AccDataType>(pgrad_g_m_n(idx_gmn)) - ygrad_dot_y));
}); });
auto p_drop_g_n_m = p_drop_g_m_ns[i].Transpose({0, 2, 1}); auto p_drop_g_n_m = p_drop_g_m_ns[i].Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{ ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}}); p_drop_g_n_m, ygrad_g_m_o, vgrad_g_n_o, PassThrough{}, PassThrough{}, Scale{1.0f}});
ref_gemm_grad_invoker.Run(RefGemmGradArg{ ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
sgrad_g_m_n, k_g_n_ks[i], qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}}); sgrad_g_m_n, k_g_n_ks[i], qgrad_g_m_k, PassThrough{}, PassThrough{}, Scale{alpha}});
auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1}); auto sgrad_g_n_m = sgrad_g_m_n.Transpose({0, 2, 1});
ref_gemm_grad_invoker.Run(RefGemmGradArg{ ref_gemm1_grad_invoker.Run(RefGemm1GradArg{
sgrad_g_n_m, q_g_m_ks[i], kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}}); sgrad_g_n_m, q_g_m_ks[i], kgrad_g_n_k, PassThrough{}, PassThrough{}, Scale{alpha}});
Tensor<DataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(), Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(),
q_tensors[i].GetStrides()); q_tensors[i].GetStrides());
Tensor<DataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(), Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(),
k_tensors[i].GetStrides()); k_tensors[i].GetStrides());
Tensor<DataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(), Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(),
v_tensors[i].GetStrides()); v_tensors[i].GetStrides());
Tensor<DataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(), Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(),
q_tensors[i].GetStrides()); q_tensors[i].GetStrides());
Tensor<DataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(), Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(),
k_tensors[i].GetStrides()); k_tensors[i].GetStrides());
Tensor<DataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(), Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(),
v_tensors[i].GetStrides()); v_tensors[i].GetStrides());
qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data()); qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data());
kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data()); kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data());
......
...@@ -28,7 +28,8 @@ namespace tensor_operation { ...@@ -28,7 +28,8 @@ namespace tensor_operation {
namespace device { namespace device {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename DataType, typename InputDataType,
typename OutputDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
...@@ -53,16 +54,16 @@ __global__ void ...@@ -53,16 +54,16 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif #endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1( kernel_batched_multihead_attention_backward_xdl_cshuffle_v1(
const DataType* __restrict__ p_a_grid, const InputDataType* __restrict__ p_a_grid,
const DataType* __restrict__ p_b_grid, const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
const DataType* __restrict__ p_b1_grid, const InputDataType* __restrict__ p_b1_grid,
const DataType* __restrict__ p_c_grid, const InputDataType* __restrict__ p_c_grid,
const LSEDataType* __restrict__ p_lse_grid, const LSEDataType* __restrict__ p_lse_grid,
const DataType* __restrict__ p_ygrad_grid, const InputDataType* __restrict__ p_ygrad_grid,
DataType* __restrict__ p_qgrad_grid, OutputDataType* __restrict__ p_qgrad_grid,
DataType* __restrict__ p_kgrad_grid, OutputDataType* __restrict__ p_kgrad_grid,
DataType* __restrict__ p_vgrad_grid, OutputDataType* __restrict__ p_vgrad_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op, const AccElementwiseOperation acc_element_op,
...@@ -171,7 +172,8 @@ template <index_t NumDimG, ...@@ -171,7 +172,8 @@ template <index_t NumDimG,
index_t NumDimN, index_t NumDimN,
index_t NumDimK, index_t NumDimK,
index_t NumDimO, // NumDimGemm1N index_t NumDimO, // NumDimGemm1N
typename DataType, typename InputDataType,
typename OutputDataType,
typename GemmDataType, typename GemmDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
...@@ -597,7 +599,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -597,7 +599,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
DataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
GemmDataType, GemmDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
...@@ -666,16 +669,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -666,16 +669,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument( Argument(
const DataType* p_a_grid, const InputDataType* p_a_grid,
const DataType* p_b_grid, const InputDataType* p_b_grid,
ZDataType* p_z_grid, ZDataType* p_z_grid,
const DataType* p_b1_grid, const InputDataType* p_b1_grid,
const DataType* p_c_grid, // for dS const InputDataType* p_c_grid, // for dS
const LSEDataType* p_lse_grid, const LSEDataType* p_lse_grid,
const DataType* p_ygrad_grid, const InputDataType* p_ygrad_grid,
DataType* p_qgrad_grid, OutputDataType* p_qgrad_grid,
DataType* p_kgrad_grid, OutputDataType* p_kgrad_grid,
DataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
...@@ -820,16 +823,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -820,16 +823,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
} }
// pointers // pointers
const DataType* p_a_grid_; const InputDataType* p_a_grid_;
const DataType* p_b_grid_; const InputDataType* p_b_grid_;
ZDataType* p_z_grid_; ZDataType* p_z_grid_;
const DataType* p_b1_grid_; const InputDataType* p_b1_grid_;
const DataType* p_c_grid_; const InputDataType* p_c_grid_;
const LSEDataType* p_lse_grid_; const LSEDataType* p_lse_grid_;
const DataType* p_ygrad_grid_; const InputDataType* p_ygrad_grid_;
DataType* p_qgrad_grid_; OutputDataType* p_qgrad_grid_;
DataType* p_kgrad_grid_; OutputDataType* p_kgrad_grid_;
DataType* p_vgrad_grid_; OutputDataType* p_vgrad_grid_;
// tensor descriptor // tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
...@@ -901,7 +904,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -901,7 +904,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1< const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
GridwiseGemm, GridwiseGemm,
DataType, InputDataType,
OutputDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
AElementwiseOperation, AElementwiseOperation,
...@@ -1067,16 +1071,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1067,16 +1071,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
} }
static auto MakeArgument( static auto MakeArgument(
const DataType* p_a, const InputDataType* p_a,
const DataType* p_b, const InputDataType* p_b,
ZDataType* p_z, ZDataType* p_z,
const DataType* p_b1, const InputDataType* p_b1,
const DataType* p_c, const InputDataType* p_c,
const LSEDataType* p_lse, const LSEDataType* p_lse,
const DataType* p_ygrad_grid, const InputDataType* p_ygrad_grid,
DataType* p_qgrad_grid, OutputDataType* p_qgrad_grid,
DataType* p_kgrad_grid, OutputDataType* p_kgrad_grid,
DataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
...@@ -1182,16 +1186,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1182,16 +1186,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
float p_drop, float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) // override std::tuple<unsigned long long, unsigned long long> seeds) // override
{ {
return std::make_unique<Argument>(static_cast<const DataType*>(p_a), return std::make_unique<Argument>(static_cast<const InputDataType*>(p_a),
static_cast<const DataType*>(p_b), static_cast<const InputDataType*>(p_b),
static_cast<ZDataType*>(p_z), static_cast<ZDataType*>(p_z),
static_cast<const DataType*>(p_b1), static_cast<const InputDataType*>(p_b1),
static_cast<const DataType*>(p_c), static_cast<const InputDataType*>(p_c),
static_cast<const LSEDataType*>(p_lse), static_cast<const LSEDataType*>(p_lse),
static_cast<const DataType*>(p_ygrad_grid), static_cast<const InputDataType*>(p_ygrad_grid),
static_cast<DataType*>(p_qgrad_grid), static_cast<OutputDataType*>(p_qgrad_grid),
static_cast<DataType*>(p_kgrad_grid), static_cast<OutputDataType*>(p_kgrad_grid),
static_cast<DataType*>(p_vgrad_grid), static_cast<OutputDataType*>(p_vgrad_grid),
p_acc0_biases, // cast in struct Argument p_acc0_biases, // cast in struct Argument
p_acc1_biases, // cast in struct Argument p_acc1_biases, // cast in struct Argument
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
......
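In the V1 device op above, every pointer that used to be plain `DataType` is split by role: the read-only operands (Q, K, V, the dS buffer, dY) become `InputDataType`, while the gradient outputs dQ/dK/dV become `OutputDataType`, and `MakeArgumentPointer` casts the incoming `void*` arguments to the matching type. The standalone sketch below illustrates only that casting pattern, with made-up names rather than CK's `Argument`/`MakeArgument` signatures.

```cpp
// Illustrative sketch of the pointer-splitting pattern: inputs and outputs
// are cast from void* to different element types.
#include <cstdint>

template <typename InputDataType, typename OutputDataType>
struct BwdPointers
{
    const InputDataType* p_q;     // forward operand, read-only
    const InputDataType* p_ygrad; // incoming gradient dY, read-only
    OutputDataType*      p_qgrad; // gradient result, possibly a wider type
};

template <typename InputDataType, typename OutputDataType>
BwdPointers<InputDataType, OutputDataType>
make_bwd_pointers(const void* p_q, const void* p_ygrad, void* p_qgrad)
{
    return {static_cast<const InputDataType*>(p_q),
            static_cast<const InputDataType*>(p_ygrad),
            static_cast<OutputDataType*>(p_qgrad)};
}

int main()
{
    std::uint16_t q = 0, dy = 0; // bf16 stand-ins
    float dq = 0.f;
    auto args = make_bwd_pointers<std::uint16_t, float>(&q, &dy, &dq);
    return args.p_qgrad == &dq ? 0 : 1;
}
```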
...@@ -27,7 +27,8 @@ namespace tensor_operation { ...@@ -27,7 +27,8 @@ namespace tensor_operation {
namespace device { namespace device {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename DataType, typename InputDataType,
typename OutputDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
...@@ -52,16 +53,16 @@ __global__ void ...@@ -52,16 +53,16 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif #endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_v2( kernel_batched_multihead_attention_backward_xdl_cshuffle_v2(
const DataType* __restrict__ p_a_grid, const InputDataType* __restrict__ p_a_grid,
const DataType* __restrict__ p_b_grid, const InputDataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
const DataType* __restrict__ p_b1_grid, const InputDataType* __restrict__ p_b1_grid,
const DataType* __restrict__ p_c_grid, const InputDataType* __restrict__ p_c_grid,
const LSEDataType* __restrict__ p_lse_grid, const LSEDataType* __restrict__ p_lse_grid,
const DataType* __restrict__ p_ygrad_grid, const InputDataType* __restrict__ p_ygrad_grid,
DataType* __restrict__ p_qgrad_grid, OutputDataType* __restrict__ p_qgrad_grid,
DataType* __restrict__ p_kgrad_grid, OutputDataType* __restrict__ p_kgrad_grid,
DataType* __restrict__ p_vgrad_grid, OutputDataType* __restrict__ p_vgrad_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op, const AccElementwiseOperation acc_element_op,
...@@ -170,7 +171,8 @@ template <index_t NumDimG, ...@@ -170,7 +171,8 @@ template <index_t NumDimG,
index_t NumDimN, index_t NumDimN,
index_t NumDimK, index_t NumDimK,
index_t NumDimO, // NumDimGemm1N index_t NumDimO, // NumDimGemm1N
typename DataType, typename InputDataType,
typename OutputDataType,
typename GemmDataType, typename GemmDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
...@@ -596,7 +598,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -596,7 +598,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
DataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
GemmDataType, GemmDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
...@@ -665,16 +668,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -665,16 +668,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument( Argument(
const DataType* p_a_grid, const InputDataType* p_a_grid,
const DataType* p_b_grid, const InputDataType* p_b_grid,
ZDataType* p_z_grid, ZDataType* p_z_grid,
const DataType* p_b1_grid, const InputDataType* p_b1_grid,
const DataType* p_c_grid, // for dS const InputDataType* p_c_grid, // for dS
const LSEDataType* p_lse_grid, const LSEDataType* p_lse_grid,
const DataType* p_ygrad_grid, const InputDataType* p_ygrad_grid,
DataType* p_qgrad_grid, OutputDataType* p_qgrad_grid,
DataType* p_kgrad_grid, OutputDataType* p_kgrad_grid,
DataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
...@@ -818,16 +821,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -818,16 +821,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
} }
// pointers // pointers
const DataType* p_a_grid_; const InputDataType* p_a_grid_;
const DataType* p_b_grid_; const InputDataType* p_b_grid_;
ZDataType* p_z_grid_; ZDataType* p_z_grid_;
const DataType* p_b1_grid_; const InputDataType* p_b1_grid_;
const DataType* p_c_grid_; const InputDataType* p_c_grid_;
const LSEDataType* p_lse_grid_; const LSEDataType* p_lse_grid_;
const DataType* p_ygrad_grid_; const InputDataType* p_ygrad_grid_;
DataType* p_qgrad_grid_; OutputDataType* p_qgrad_grid_;
DataType* p_kgrad_grid_; OutputDataType* p_kgrad_grid_;
DataType* p_vgrad_grid_; OutputDataType* p_vgrad_grid_;
// tensor descriptor // tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
...@@ -903,7 +906,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -903,7 +906,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2< const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v2<
GridwiseGemm, GridwiseGemm,
DataType, InputDataType,
OutputDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
AElementwiseOperation, AElementwiseOperation,
...@@ -1067,16 +1071,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1067,16 +1071,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
} }
static auto MakeArgument( static auto MakeArgument(
const DataType* p_a, const InputDataType* p_a,
const DataType* p_b, const InputDataType* p_b,
ZDataType* p_z, ZDataType* p_z,
const DataType* p_b1, const InputDataType* p_b1,
const DataType* p_c, const InputDataType* p_c,
const LSEDataType* p_lse, const LSEDataType* p_lse,
const DataType* p_ygrad_grid, const InputDataType* p_ygrad_grid,
DataType* p_qgrad_grid, OutputDataType* p_qgrad_grid,
DataType* p_kgrad_grid, OutputDataType* p_kgrad_grid,
DataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const std::array<void*, NumAcc0Bias> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::array<void*, NumAcc1Bias> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
...@@ -1182,16 +1186,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1182,16 +1186,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
float p_drop, float p_drop,
std::tuple<unsigned long long, unsigned long long> seeds) // override std::tuple<unsigned long long, unsigned long long> seeds) // override
{ {
return std::make_unique<Argument>(static_cast<const DataType*>(p_a), return std::make_unique<Argument>(static_cast<const InputDataType*>(p_a),
static_cast<const DataType*>(p_b), static_cast<const InputDataType*>(p_b),
static_cast<ZDataType*>(p_z), static_cast<ZDataType*>(p_z),
static_cast<const DataType*>(p_b1), static_cast<const InputDataType*>(p_b1),
static_cast<const DataType*>(p_c), static_cast<const InputDataType*>(p_c),
static_cast<const LSEDataType*>(p_lse), static_cast<const LSEDataType*>(p_lse),
static_cast<const DataType*>(p_ygrad_grid), static_cast<const InputDataType*>(p_ygrad_grid),
static_cast<DataType*>(p_qgrad_grid), static_cast<OutputDataType*>(p_qgrad_grid),
static_cast<DataType*>(p_kgrad_grid), static_cast<OutputDataType*>(p_kgrad_grid),
static_cast<DataType*>(p_vgrad_grid), static_cast<OutputDataType*>(p_vgrad_grid),
p_acc0_biases, // cast in struct Argument p_acc0_biases, // cast in struct Argument
p_acc1_biases, // cast in struct Argument p_acc1_biases, // cast in struct Argument
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
......
...@@ -150,7 +150,8 @@ template <index_t NumDimG, ...@@ -150,7 +150,8 @@ template <index_t NumDimG,
index_t NumDimN, index_t NumDimN,
index_t NumDimK, index_t NumDimK,
index_t NumDimO, // NumDimGemm1N index_t NumDimO, // NumDimGemm1N
typename DataType, typename InputDataType,
typename OutputDataType,
typename GemmDataType, typename GemmDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
...@@ -534,7 +535,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -534,7 +535,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
DataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
GemmDataType, GemmDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
...@@ -604,16 +606,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -604,16 +606,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
struct GroupKernelArg struct GroupKernelArg
{ {
// pointers // pointers
const DataType* p_a_grid_; const InputDataType* p_a_grid_;
const DataType* p_b_grid_; const InputDataType* p_b_grid_;
ZDataType* p_z_grid_; ZDataType* p_z_grid_;
const DataType* p_b1_grid_; const InputDataType* p_b1_grid_;
const DataType* p_c_grid_; const InputDataType* p_c_grid_;
const LSEDataType* p_lse_grid_; const LSEDataType* p_lse_grid_;
const DataType* p_ygrad_grid_; const InputDataType* p_ygrad_grid_;
DataType* p_qgrad_grid_; OutputDataType* p_qgrad_grid_;
DataType* p_kgrad_grid_; OutputDataType* p_kgrad_grid_;
DataType* p_vgrad_grid_; OutputDataType* p_vgrad_grid_;
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
...@@ -712,16 +714,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -712,16 +714,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
grid_size_ = 0; grid_size_ = 0;
for(index_t i = 0; i < group_count_; i++) for(index_t i = 0; i < group_count_; i++)
{ {
const auto p_a_grid = static_cast<const DataType*>(p_As[i]); const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
const auto p_b_grid = static_cast<const DataType*>(p_Bs[i]); const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]);
auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]); auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]);
const auto p_b1_grid = static_cast<const DataType*>(p_B1s[i]); const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]);
const auto p_c_grid = static_cast<const DataType*>(p_Cs[i]); const auto p_c_grid = static_cast<const InputDataType*>(p_Cs[i]);
const auto p_lse_grid = static_cast<const LSEDataType*>(p_LSEs[i]); const auto p_lse_grid = static_cast<const LSEDataType*>(p_LSEs[i]);
const auto p_ygrad_grid = static_cast<const DataType*>(p_Ygrads[i]); const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
auto p_qgrad_grid = static_cast<DataType*>(p_Qgrads[i]); auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]);
auto p_kgrad_grid = static_cast<DataType*>(p_Kgrads[i]); auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]);
auto p_vgrad_grid = static_cast<DataType*>(p_Vgrads[i]); auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
const auto& problem_desc = problem_desc_vec[i]; const auto& problem_desc = problem_desc_vec[i];
......
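The grouped V1 device op applies the same split per group: `GroupKernelArg` stores `InputDataType` pointers for the operands and `OutputDataType` pointers for dQ/dK/dV, and the constructor loops over the group arrays (`p_As`, `p_Qgrads`, ...) casting each raw `void*` element-wise. A small sketch of that per-group loop, with hypothetical names and only two pointer fields shown:

```cpp
// Sketch of per-group pointer handling: each group's void* operands are cast
// to InputDataType / OutputDataType before being stored in the kernel args.
#include <cstddef>
#include <cstdint>
#include <vector>

using InputDataType  = std::uint16_t; // bf16 stand-in
using OutputDataType = float;

struct GroupKernelArgSketch
{
    const InputDataType* p_a_grid_;
    OutputDataType*      p_qgrad_grid_;
};

int main()
{
    std::vector<InputDataType>  a0(16), a1(16);
    std::vector<OutputDataType> dq0(16), dq1(16);
    std::vector<const void*> p_As     = {a0.data(), a1.data()};
    std::vector<void*>       p_Qgrads = {dq0.data(), dq1.data()};

    std::vector<GroupKernelArgSketch> group_args;
    for(std::size_t i = 0; i < p_As.size(); i++)
    {
        group_args.push_back({static_cast<const InputDataType*>(p_As[i]),
                              static_cast<OutputDataType*>(p_Qgrads[i])});
    }
    return group_args.size() == 2 ? 0 : 1;
}
```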
...@@ -150,7 +150,8 @@ template <index_t NumDimG, ...@@ -150,7 +150,8 @@ template <index_t NumDimG,
index_t NumDimN, index_t NumDimN,
index_t NumDimK, index_t NumDimK,
index_t NumDimO, // NumDimGemm1N index_t NumDimO, // NumDimGemm1N
typename DataType, typename InputDataType,
typename OutputDataType,
typename GemmDataType, typename GemmDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
...@@ -527,7 +528,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -527,7 +528,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
DataType, // TODO: distinguish A/B datatype InputDataType, // TODO: distinguish A/B datatype
OutputDataType,
GemmDataType, GemmDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
...@@ -597,16 +599,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -597,16 +599,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
struct GroupKernelArg struct GroupKernelArg
{ {
// pointers // pointers
const DataType* p_a_grid_; const InputDataType* p_a_grid_;
const DataType* p_b_grid_; const InputDataType* p_b_grid_;
ZDataType* p_z_grid_; ZDataType* p_z_grid_;
const DataType* p_b1_grid_; const InputDataType* p_b1_grid_;
const DataType* p_c_grid_; const InputDataType* p_c_grid_;
const LSEDataType* p_lse_grid_; const LSEDataType* p_lse_grid_;
const DataType* p_ygrad_grid_; const InputDataType* p_ygrad_grid_;
DataType* p_qgrad_grid_; OutputDataType* p_qgrad_grid_;
DataType* p_kgrad_grid_; OutputDataType* p_kgrad_grid_;
DataType* p_vgrad_grid_; OutputDataType* p_vgrad_grid_;
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
...@@ -705,16 +707,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -705,16 +707,16 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
grid_size_ = 0; grid_size_ = 0;
for(index_t i = 0; i < group_count_; i++) for(index_t i = 0; i < group_count_; i++)
{ {
const auto p_a_grid = static_cast<const DataType*>(p_As[i]); const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
const auto p_b_grid = static_cast<const DataType*>(p_Bs[i]); const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]);
auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]); auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]);
const auto p_b1_grid = static_cast<const DataType*>(p_B1s[i]); const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]);
const auto p_c_grid = static_cast<const DataType*>(p_Cs[i]); const auto p_c_grid = static_cast<const InputDataType*>(p_Cs[i]);
const auto p_lse_grid = static_cast<const LSEDataType*>(p_LSEs[i]); const auto p_lse_grid = static_cast<const LSEDataType*>(p_LSEs[i]);
const auto p_ygrad_grid = static_cast<const DataType*>(p_Ygrads[i]); const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
auto p_qgrad_grid = static_cast<DataType*>(p_Qgrads[i]); auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]);
auto p_kgrad_grid = static_cast<DataType*>(p_Kgrads[i]); auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]);
auto p_vgrad_grid = static_cast<DataType*>(p_Vgrads[i]); auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
const auto& problem_desc = problem_desc_vec[i]; const auto& problem_desc = problem_desc_vec[i];
......
...@@ -20,7 +20,8 @@ ...@@ -20,7 +20,8 @@
namespace ck { namespace ck {
template <typename DataType, template <typename InputDataType,
typename OutputDataType,
typename GemmDataType, typename GemmDataType,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
...@@ -381,7 +382,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -381,7 +382,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence<AK0, MPerBlock, AK1>, Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
DataType, InputDataType,
GemmDataType, GemmDataType,
GridDesc_K0_M_K1, GridDesc_K0_M_K1,
decltype(q_block_desc_k0_m_k1), decltype(q_block_desc_k0_m_k1),
...@@ -406,7 +407,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -406,7 +407,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence<BK0, NPerBlock, BK1>, Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
DataType, InputDataType,
GemmDataType, GemmDataType,
GridDesc_K0_N_K1, GridDesc_K0_N_K1,
decltype(k_block_desc_k0_n_k1), decltype(k_block_desc_k0_n_k1),
...@@ -431,7 +432,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -431,7 +432,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence<BK0, NPerBlock, BK1>, Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
DataType, InputDataType,
GemmDataType, GemmDataType,
GridDesc_K0_N_K1, GridDesc_K0_N_K1,
decltype(v_block_desc_k0_n_k1), decltype(v_block_desc_k0_n_k1),
...@@ -456,7 +457,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -456,7 +457,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
Sequence<AK0, MPerBlock, AK1>, Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
DataType, InputDataType,
GemmDataType, GemmDataType,
GridDesc_K0_M_K1, GridDesc_K0_M_K1,
decltype(ygrad_block_desc_k0_m_k1), decltype(ygrad_block_desc_k0_m_k1),
...@@ -1043,7 +1044,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1043,7 +1044,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename ElementwiseOp = tensor_operation::element_wise::PassThrough> typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3< using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc, FloatGemmAcc,
DataType, OutputDataType,
decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4), decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4),
CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4, CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4,
ElementwiseOp, // CElementwiseOperation ElementwiseOp, // CElementwiseOperation
...@@ -1059,7 +1060,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1059,7 +1060,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_> template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_ struct YDotYGrad_M_O_
{ {
static constexpr index_t SrcScalarPerVector = 16 / sizeof(DataType); static constexpr index_t SrcScalarPerVector = 16 / sizeof(InputDataType);
static constexpr auto ThreadClusterLength_O = static constexpr auto ThreadClusterLength_O =
Number<BlockSliceLength_O_ / SrcScalarPerVector>{}; Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{}; static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
...@@ -1234,16 +1235,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1234,16 +1235,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename C0MatrixMask, typename C0MatrixMask,
typename VGradGridDescriptor_N_O, typename VGradGridDescriptor_N_O,
typename YGradGridDesc_O0_M_O1> typename YGradGridDesc_O0_M_O1>
__device__ static void Run(const DataType* __restrict__ p_q_grid, __device__ static void Run(const InputDataType* __restrict__ p_q_grid,
const DataType* __restrict__ p_k_grid, const InputDataType* __restrict__ p_k_grid,
unsigned short* __restrict__ p_z_grid, unsigned short* __restrict__ p_z_grid,
const DataType* __restrict__ p_v_grid, const InputDataType* __restrict__ p_v_grid,
const DataType* __restrict__ p_y_grid, const InputDataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid, const FloatLSE* __restrict__ p_lse_grid,
const DataType* __restrict__ p_ygrad_grid, const InputDataType* __restrict__ p_ygrad_grid,
DataType* __restrict__ p_qgrad_grid, OutputDataType* __restrict__ p_qgrad_grid,
DataType* __restrict__ p_kgrad_grid, OutputDataType* __restrict__ p_kgrad_grid,
DataType* __restrict__ p_vgrad_grid, OutputDataType* __restrict__ p_vgrad_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
...@@ -1723,7 +1724,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -1723,7 +1724,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
// performs for y // performs for y
auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2< auto y_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
DataType, InputDataType,
FloatGemmAcc, FloatGemmAcc,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1), decltype(y_thread_desc_m0_m1_o0_o1),
...@@ -2307,7 +2308,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1 ...@@ -2307,7 +2308,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData, FloatCShuffle, // typename SrcData,
DataType, // typename DstData, OutputDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(qgrad_grid_desc_mblock_mperblock_kblock_kperblock), decltype(qgrad_grid_desc_mblock_mperblock_kblock_kperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder, Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
......
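Inside the gridwise V1 kernel, the blockwise copies that load Q/K/V/dY keep `InputDataType` as their source type, the Y·dY reduction sizes its per-thread vector to a 16-byte load of `InputDataType`, and the final C-shuffle write converts the f32 shuffle data to `OutputDataType`. The sketch below reproduces only the vector-width arithmetic under that 16-byte-load assumption; it also shows why an f32 output fits fewer elements per vector than bf16.

```cpp
// Sketch of the SrcScalarPerVector computation: the per-thread vector width
// is however many elements of the given type fit in a 16-byte load.
#include <cstdint>

template <typename DataT>
constexpr int scalar_per_16byte_vector()
{
    return 16 / static_cast<int>(sizeof(DataT));
}

static_assert(scalar_per_16byte_vector<std::uint16_t>() == 8, "bf16/f16: 8 per 16B");
static_assert(scalar_per_16byte_vector<float>() == 4, "f32: 4 per 16B");

int main() { return 0; }
```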
...@@ -20,7 +20,8 @@ ...@@ -20,7 +20,8 @@
namespace ck { namespace ck {
template <typename DataType, template <typename InputDataType,
typename OutputDataType,
typename GemmDataType, typename GemmDataType,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
...@@ -457,7 +458,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -457,7 +458,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
Sequence<AK0, MPerBlock, AK1>, Sequence<AK0, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
DataType, InputDataType,
GemmDataType, GemmDataType,
GridDesc_K0_M_K1, GridDesc_K0_M_K1,
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
...@@ -482,7 +483,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -482,7 +483,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
Sequence<BK0, NPerBlock, BK1>, Sequence<BK0, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
DataType, InputDataType,
GemmDataType, GemmDataType,
GridDesc_K0_N_K1, GridDesc_K0_N_K1,
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
...@@ -585,7 +586,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -585,7 +586,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
Sequence<B1K0, Gemm1NPerBlock, B1K1>, Sequence<B1K0, Gemm1NPerBlock, B1K1>,
B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferThreadClusterArrangeOrder,
DataType, InputDataType,
GemmDataType, GemmDataType,
GridDesc_K0_N_K1, GridDesc_K0_N_K1,
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
...@@ -823,7 +824,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -823,7 +824,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename Gemm2Params_N_O_M::BBlockSliceLengths, typename Gemm2Params_N_O_M::BBlockSliceLengths,
typename Gemm2Params_N_O_M::BThreadClusterLengths, typename Gemm2Params_N_O_M::BThreadClusterLengths,
typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder, typename Gemm2Params_N_O_M::BThreadClusterArrangeOrder,
DataType, InputDataType,
GemmDataType, GemmDataType,
GridDesc_M0_O_M1, GridDesc_M0_O_M1,
decltype(b_block_desc_m0_o_m1), decltype(b_block_desc_m0_o_m1),
...@@ -892,7 +893,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -892,7 +893,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename ElementwiseOp = tensor_operation::element_wise::PassThrough> typename ElementwiseOp = tensor_operation::element_wise::PassThrough>
using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3< using CBlockwiseCopy = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc, FloatGemmAcc,
DataType, OutputDataType,
decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4), decltype(c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4),
CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4, CGridDesc_N0_O0_N1_O1_N2_O2_O3_O4,
ElementwiseOp, // CElementwiseOperation ElementwiseOp, // CElementwiseOperation
...@@ -908,7 +909,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -908,7 +909,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_> template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
struct YDotYGrad_M_O_ struct YDotYGrad_M_O_
{ {
static constexpr index_t SrcScalarPerVector = 16 / sizeof(DataType); static constexpr index_t SrcScalarPerVector = 16 / sizeof(InputDataType);
static constexpr auto ThreadClusterLength_O = static constexpr auto ThreadClusterLength_O =
Number<BlockSliceLength_O_ / SrcScalarPerVector>{}; Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{}; static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
...@@ -1144,16 +1145,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1144,16 +1145,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename C0MatrixMask, typename C0MatrixMask,
typename VGradGridDescriptor_N_O, typename VGradGridDescriptor_N_O,
typename YGradGridDesc_M0_O_M1> typename YGradGridDesc_M0_O_M1>
__device__ static void Run(const DataType* __restrict__ p_q_grid, __device__ static void Run(const InputDataType* __restrict__ p_q_grid,
const DataType* __restrict__ p_k_grid, const InputDataType* __restrict__ p_k_grid,
unsigned short* __restrict__ p_z_grid, unsigned short* __restrict__ p_z_grid,
const DataType* __restrict__ p_v_grid, const InputDataType* __restrict__ p_v_grid,
const DataType* __restrict__ p_y_grid, const InputDataType* __restrict__ p_y_grid,
const FloatLSE* __restrict__ p_lse_grid, const FloatLSE* __restrict__ p_lse_grid,
const DataType* __restrict__ p_ygrad_grid, const InputDataType* __restrict__ p_ygrad_grid,
DataType* __restrict__ p_qgrad_grid, OutputDataType* __restrict__ p_qgrad_grid,
DataType* __restrict__ p_kgrad_grid, OutputDataType* __restrict__ p_kgrad_grid,
DataType* __restrict__ p_vgrad_grid, OutputDataType* __restrict__ p_vgrad_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op, const BElementwiseOperation& b_element_op,
...@@ -1646,7 +1647,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -1646,7 +1647,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// performs double duty for both y and ygrad // performs double duty for both y and ygrad
auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2< auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
DataType, InputDataType,
FloatGemmAcc, FloatGemmAcc,
YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock, YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
decltype(y_thread_desc_m0_m1_o0_o1), decltype(y_thread_desc_m0_m1_o0_o1),
...@@ -2257,7 +2258,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2 ...@@ -2257,7 +2258,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatCShuffle, // typename SrcData, FloatCShuffle, // typename SrcData,
DataType, // typename DstData, OutputDataType, // typename DstData,
decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
decltype(qgrad_grid_desc_mblock_mperblock_kblock_kperblock), decltype(qgrad_grid_desc_mblock_mperblock_kblock_kperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder, Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
......