Commit b15eecba authored by letaoqin

Merge branch 'mha-train-develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into mha-train-develop

parents 87d1e073 04c206da
@@ -513,8 +513,10 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
-        {}, // std::array<void*, 1> p_acc0_biases;
-        {}, // std::array<void*, 1> p_acc1_biases;
+        nullptr, // p_acc0_bias;
+        nullptr, // p_acc1_bias;
+        nullptr,
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,
@@ -558,8 +560,10 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
-        {}, // std::array<void*, 1> p_acc0_biases;
-        {}, // std::array<void*, 1> p_acc1_biases;
+        nullptr, // p_acc0_bias;
+        nullptr, // p_acc1_bias;
+        nullptr,
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,
...
@@ -518,8 +518,10 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
-        {}, // std::array<void*, 1> p_acc0_biases;
-        {}, // std::array<void*, 1> p_acc1_biases;
+        nullptr, // p_acc0_bias;
+        nullptr, // p_acc1_bias;
+        nullptr,
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,
@@ -564,8 +566,10 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
-        {}, // std::array<void*, 1> p_acc0_biases;
-        {}, // std::array<void*, 1> p_acc1_biases;
+        nullptr, // p_acc0_bias;
+        nullptr, // p_acc1_bias;
+        nullptr,
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,
...
@@ -597,8 +597,10 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
-        {}, // p_acc0_biases;
-        {}, // p_acc1_biases;
+        nullptr, // p_acc0_biases;
+        nullptr, // p_acc1_biases;
+        nullptr,
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,
@@ -686,8 +688,8 @@ int run(int argc, char* argv[])
         static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
         static_cast<ZDataType*>(z_fwd_device_buf.GetDeviceBuffer()),
         static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
-        {}, // std::array<void*, 1> p_acc0_biases;
-        {}, // std::array<void*, 1> p_acc1_biases;
+        nullptr, // p_acc0_bias;
+        nullptr, // p_acc1_bias;
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,
@@ -743,8 +745,10 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
-        {}, // std::array<void*, 1> p_acc0_biases;
-        {}, // std::array<void*, 1> p_acc1_biases;
+        nullptr, // p_acc0_bias;
+        nullptr, // p_acc1_bias;
+        nullptr,
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,
...
@@ -604,6 +604,8 @@ int run(int argc, char* argv[])
         p_vgrad,
         {}, // std::array<void*, 1> p_acc0_biases;
         {}, // std::array<void*, 1> p_acc1_biases;
+        {},
+        {},
         problem_descs,
         QKVElementOp{},
         QKVElementOp{},
@@ -650,6 +652,8 @@ int run(int argc, char* argv[])
         p_vgrad,
         {}, // std::array<void*, 1> p_acc0_biases;
         {}, // std::array<void*, 1> p_acc1_biases;
+        {},
+        {},
         problem_descs,
         QKVElementOp{},
         QKVElementOp{},
...
@@ -24,7 +24,7 @@ Kernel outputs:
 */

 #define USING_MASK 0
-#define DIM 128 // DIM should be a multiple of 8.
+#define DIM 32 // DIM should be a multiple of 8.

 #include <iostream>
 #include <numeric>
@@ -616,6 +616,8 @@ int run(int argc, char* argv[])
         p_vgrad,
         {}, // std::array<void*, 1> p_acc0_biases;
         {}, // std::array<void*, 1> p_acc1_biases;
+        {},
+        {},
         problem_descs,
         QKVElementOp{},
         QKVElementOp{},
@@ -663,6 +665,8 @@ int run(int argc, char* argv[])
         p_vgrad,
         {}, // std::array<void*, 1> p_acc0_biases;
         {}, // std::array<void*, 1> p_acc1_biases;
+        {},
+        {},
         problem_descs,
         QKVElementOp{},
         QKVElementOp{},
...
@@ -728,6 +728,8 @@ int run(int argc, char* argv[])
         p_vgrad,
         {}, // std::array<void*, 1> p_acc0_biases;
         {}, // std::array<void*, 1> p_acc1_biases;
+        {},
+        {},
         problem_descs_bwd,
         QKVElementOp{},
         QKVElementOp{},
@@ -815,6 +817,8 @@ int run(int argc, char* argv[])
         p_vgrad,
         {}, // std::array<void*, 1> p_acc0_biases;
         {}, // std::array<void*, 1> p_acc1_biases;
+        {},
+        {},
         problem_descs_bwd,
         QKVElementOp{},
         QKVElementOp{},
...
@@ -25,7 +25,7 @@ Kernel outputs:
 #define PRINT_HOST 0
 #define USING_MASK 0
-#define DIM 64 // DIM should be a multiple of 8.
+#define DIM 128 // DIM should be a multiple of 8.

 #include <iostream>
 #include <numeric>
@@ -57,6 +57,7 @@ using BF16 = ck::bhalf_t;
 using F32 = float;
 using U16 = unsigned short;
 using INT32 = int32_t;
+using U8 = uint8_t;

 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Scale = ck::tensor_operation::element_wise::Scale;
@@ -374,8 +375,8 @@ int run(int argc, char* argv[])
             ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
             : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]

-    std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M, N};
-    std::vector<ck::index_t> d_gs_ms_ns_strides =
+    std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
+    std::vector<ck::index_t> d0_gs_ms_ns_strides =
         input_permute
             ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // D layout [G0, M, G1, N]
             : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // D layout [G0, G1, M, N]
@@ -396,7 +397,7 @@ int run(int argc, char* argv[])
     Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
     Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
-    Tensor<Acc0BiasDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
+    Tensor<Acc0BiasDataType> d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides);
     Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
     Tensor<InputDataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
     Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
@@ -405,7 +406,7 @@ int run(int argc, char* argv[])
     std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
     std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
-    std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
+    std::cout << "d0_gs_ms_ns: " << d0_gs_ms_ns.mDesc << std::endl;
     std::cout << "z_gs_ms_ns: " << z_gs_ms_ns.mDesc << std::endl;
     std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
     std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
@@ -420,36 +421,35 @@ int run(int argc, char* argv[])
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
-        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-2, 2});
-        // d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
+        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-2, 2});
         break;
     case 2:
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
-        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5});
+        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5});
         break;
     case 3:
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-5, 5});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
-        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
+        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
         break;
     case 4:
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
-        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
+        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
         break;
     case 5:
         q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
-        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
+        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
         // dO dot O = [0; 1; 2; ...]
         break;
     case 6:
@@ -457,7 +457,7 @@ int run(int argc, char* argv[])
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
-        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
+        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
         // assume mnko = 256
         // P = softmax(QK) = 0.0039 * ones
         // O = P V = 0.0039 * ones
@@ -471,7 +471,7 @@ int run(int argc, char* argv[])
         k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
         ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0,g1, m, o]
-        d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
+        d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
         // assume mnko = 256
         // P = softmax(QK) = 0.0039 * ones
         // O = P V = 0.0039 * ones
@@ -485,7 +485,7 @@ int run(int argc, char* argv[])
     // qkv gradients have the same descriptor as with qkv
     DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
     DeviceMem k_device_buf(sizeof(InputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
-    DeviceMem d_device_buf(sizeof(Acc0BiasDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
+    DeviceMem d0_device_buf(sizeof(Acc0BiasDataType) * d0_gs_ms_ns.mDesc.GetElementSpaceSize());
     DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
     DeviceMem v_device_buf(sizeof(InputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
     DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
@@ -494,12 +494,14 @@ int run(int argc, char* argv[])
     DeviceMem kgrad_device_buf(sizeof(OutputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
     DeviceMem vgrad_device_buf(sizeof(OutputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
     DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
+    DeviceMem d0grad_device_buf(sizeof(Acc0BiasDataType) * d0_gs_ms_ns.mDesc.GetElementSpaceSize());

     q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
     k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
-    d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
+    d0_device_buf.ToDevice(d0_gs_ms_ns.mData.data());
     v_device_buf.ToDevice(v_gs_os_ns.mData.data());
     ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
+    z_device_buf.ToDevice(z_gs_ms_ns.mData.data());

     auto gemm = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();
@@ -516,8 +518,10 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
-        static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_bias;
+        static_cast<Acc0BiasDataType*>(d0_device_buf.GetDeviceBuffer()), // p_acc0_bias;
         nullptr, // p_acc1_bias;
+        static_cast<Acc0BiasDataType*>(d0grad_device_buf.GetDeviceBuffer()),
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,
@@ -529,10 +533,10 @@ int run(int argc, char* argv[])
         y_gs_ms_os_lengths,
         y_gs_ms_os_strides,
         lse_gs_ms_lengths,
-        d_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
-        d_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
+        d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
+        d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
         {}, // acc1_bias_gs_ms_os_lengths,
         {}, // acc1_bias_gs_ms_os_strides,
         QKVElementOp{},
         QKVElementOp{},
         Scale{alpha},
@@ -561,8 +565,10 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
-        static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_bias;
+        static_cast<Acc0BiasDataType*>(d0_device_buf.GetDeviceBuffer()), // p_acc0_bias;
         nullptr, // p_acc1_bias;
+        static_cast<Acc0BiasDataType*>(d0grad_device_buf.GetDeviceBuffer()),
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,
@@ -574,10 +580,10 @@ int run(int argc, char* argv[])
         y_gs_ms_os_lengths,
         y_gs_ms_os_strides,
         lse_gs_ms_lengths,
-        d_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
-        d_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
+        d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
+        d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
         {}, // acc1_bias_gs_ms_os_lengths,
         {}, // acc1_bias_gs_ms_os_strides,
         QKVElementOp{},
         QKVElementOp{},
         Scale{alpha},
@@ -599,7 +605,7 @@ int run(int argc, char* argv[])
         (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
          sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
          sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
-         sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N) *
+         sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N * size_t(2)) *
             BatchCount +
         sizeof(LSEDataType) * M * BatchCount;
@@ -618,7 +624,7 @@ int run(int argc, char* argv[])
     Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
     Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
-    Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N});
+    Tensor<Acc0BiasDataType> d0_g_m_n({G0 * G1, M, N});
     Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
     Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
     Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
@@ -640,13 +646,13 @@ int run(int argc, char* argv[])
     v_gs_os_ns.ForEach([&](auto& self, auto idx) {
         v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
     });
-    d_gs_ms_ns.ForEach([&](auto& self, auto idx) {
-        d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+    d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
+        d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
     });
     // run fwd again for y, cause z_g_m_n update
     run_attention_fwd_host(q_g_m_k,
                            k_g_n_k,
-                           d_g_m_n,
+                           d0_g_m_n,
                            v_g_n_o,
                            alpha,
                            s_g_m_n,
@@ -783,14 +789,19 @@ int run(int argc, char* argv[])
         Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
         Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
         Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
+        Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_host_result(d0_gs_ms_ns_lengths,
+                                                             d0_gs_ms_ns_strides);

         Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
         Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
         Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
+        Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_device_result(d0_gs_ms_ns_lengths,
+                                                               d0_gs_ms_ns_strides);

         qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data());
         kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data());
         vgrad_device_buf.FromDevice(vgrad_gs_os_ns_device_result.mData.data());
+        d0grad_device_buf.FromDevice(d0grad_gs_ms_ns_device_result.mData.data());

         // permute
         qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
@@ -818,6 +829,15 @@ int run(int argc, char* argv[])
             self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
         });

+        d0grad_gs_ms_ns_host_result.ForEach([&](auto& self, auto idx) {
+            const size_t& g0 = idx[0];
+            const size_t& g1 = idx[1];
+            const size_t g = g0 * G1 + g1;
+
+            self(idx) = sgrad_g_m_n(g, idx[2], idx[3]);
+        });
+
         std::cout << "Checking qgrad:\n";
         pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
                                      qgrad_gs_ms_ks_host_result.mData,
@@ -836,6 +856,12 @@ int run(int argc, char* argv[])
                                      "error",
                                      1e-2,
                                      1e-2);
+        std::cout << "Checking d0grad:\n";
+        pass &= ck::utils::check_err(d0grad_gs_ms_ns_device_result.mData,
+                                     d0grad_gs_ms_ns_host_result.mData,
+                                     "error",
+                                     1e-2,
+                                     1e-2);
     }

     return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
...
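Why the d0grad host reference above is just a relabeled copy of sgrad_g_m_n: an additive attention bias has the same gradient as the scores it is added to. In the notation of the dS comments in the device ops further below (a short sketch, assuming the bias enters the already-scaled scores, which is what the host reference implies):

    S = \alpha\, Q K^{\top} + D_0, \qquad P = \operatorname{softmax}(S), \qquad Y = P V

    dS_{ij} = P_{ij}\,\bigl(dP_{ij} - dY_i \cdot Y_i\bigr), \qquad \frac{\partial L}{\partial D_0} = \frac{\partial L}{\partial S} = dS

The same fact accounts for the num_byte change in this file: the kernel now reads D0 and writes dD0, so the sizeof(Acc0BiasDataType) * M * N traffic term is counted twice.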
@@ -333,6 +333,7 @@ int run(int argc, char* argv[])
     std::vector<const void*> p_lse;
     std::vector<void*> p_qgrad;
     std::vector<void*> p_kgrad;
+    std::vector<void*> p_d0grad;
     std::vector<void*> p_vgrad;
     std::vector<const void*> p_ygrad;
@@ -356,6 +357,7 @@ int run(int argc, char* argv[])
     std::vector<Tensor<LSEDataType>> lse_tensors;
     std::vector<Tensor<OutputDataType>> qgrad_tensors;
     std::vector<Tensor<OutputDataType>> kgrad_tensors;
+    std::vector<Tensor<Acc0BiasDataType>> d0grad_tensors;
     std::vector<Tensor<OutputDataType>> vgrad_tensors;
     std::vector<Tensor<InputDataType>> ygrad_tensors;
@@ -369,6 +371,7 @@ int run(int argc, char* argv[])
     std::vector<DeviceMemPtr> qgrad_tensors_device;
     std::vector<DeviceMemPtr> ygrad_tensors_device;
     std::vector<DeviceMemPtr> kgrad_tensors_device;
+    std::vector<DeviceMemPtr> d0grad_tensors_device;
     std::vector<DeviceMemPtr> vgrad_tensors_device;
     std::size_t group_count = 10;
     std::size_t flop = 0, num_byte = 0;
@@ -445,12 +448,13 @@ int run(int argc, char* argv[])
         int BatchCount = G0 * G1;
         flop += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
         // Q/K/V/Y, dQ/dK/dV/dY, LSE
-        num_byte += (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
-                     sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
-                     sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
-                     sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N) *
-                        BatchCount +
-                    sizeof(LSEDataType) * M * BatchCount;
+        num_byte +=
+            (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
+             sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
+             sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
+             sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N * size_t(2)) *
+                BatchCount +
+            sizeof(LSEDataType) * M * BatchCount;

         Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
         Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
@@ -600,6 +604,8 @@ int run(int argc, char* argv[])
             std::make_unique<DeviceMem>(sizeof(OutputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
         kgrad_tensors_device.emplace_back(
             std::make_unique<DeviceMem>(sizeof(OutputDataType) * k_gs_ns_ks.GetElementSpaceSize()));
+        d0grad_tensors_device.emplace_back(std::make_unique<DeviceMem>(
+            sizeof(Acc0BiasDataType) * d0_gs_ms_ns.GetElementSpaceSize()));
         vgrad_tensors_device.emplace_back(
             std::make_unique<DeviceMem>(sizeof(OutputDataType) * v_gs_os_ns.GetElementSpaceSize()));
         ygrad_tensors_device.emplace_back(
@@ -619,6 +625,7 @@ int run(int argc, char* argv[])
         p_y.push_back(y_tensors_device.back()->GetDeviceBuffer());
         p_lse.push_back(lse_tensors_device.back()->GetDeviceBuffer());
         p_kgrad.push_back(kgrad_tensors_device.back()->GetDeviceBuffer());
+        p_d0grad.push_back(d0grad_tensors_device.back()->GetDeviceBuffer());
         p_vgrad.push_back(vgrad_tensors_device.back()->GetDeviceBuffer());
         p_ygrad.push_back(ygrad_tensors_device.back()->GetDeviceBuffer());
         p_qgrad.push_back(qgrad_tensors_device.back()->GetDeviceBuffer());
@@ -636,6 +643,8 @@ int run(int argc, char* argv[])
         p_vgrad,
         p_d0,
         {},
+        p_d0grad,
+        {},
         problem_descs,
         QKVElementOp{},
         QKVElementOp{},
@@ -682,6 +691,8 @@ int run(int argc, char* argv[])
         p_vgrad,
         p_d0,
         {},
+        p_d0grad,
+        {},
         problem_descs,
         QKVElementOp{},
         QKVElementOp{},
@@ -732,6 +743,7 @@ int run(int argc, char* argv[])
         lse_tensors_device[i]->ToDevice(lse_tensors[i].data());
         qgrad_tensors_device[i]->SetZero();
         kgrad_tensors_device[i]->SetZero();
+        d0grad_tensors_device[i]->SetZero();
         vgrad_tensors_device[i]->SetZero();
     }
@@ -804,6 +816,8 @@ int run(int argc, char* argv[])
                                                           q_tensors[i].GetStrides());
         Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(),
                                                           k_tensors[i].GetStrides());
+        Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_host_result(d0_tensors[i].GetLengths(),
+                                                             d0_tensors[i].GetStrides());
         Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(),
                                                           v_tensors[i].GetStrides());
@@ -811,11 +825,14 @@ int run(int argc, char* argv[])
                                                             q_tensors[i].GetStrides());
         Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(),
                                                             k_tensors[i].GetStrides());
+        Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_device_result(d0_tensors[i].GetLengths(),
+                                                               d0_tensors[i].GetStrides());
         Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(),
                                                             v_tensors[i].GetStrides());

         qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data());
         kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data());
+        d0grad_tensors_device[i]->FromDevice(d0grad_gs_ms_ns_device_result.data());
         vgrad_tensors_device[i]->FromDevice(vgrad_gs_os_ns_device_result.data());

         // permute
         qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
@@ -834,6 +851,14 @@ int run(int argc, char* argv[])
             self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
         });

+        d0grad_gs_ms_ns_host_result.ForEach([&](auto& self, auto idx) {
+            const size_t& g0 = idx[0];
+            const size_t& g1 = idx[1];
+            const size_t g = g0 * G1 + g1;
+
+            self(idx) = sgrad_g_m_n(g, idx[2], idx[3]);
+        });
+
         vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
             const size_t& g0 = idx[0];
             const size_t& g1 = idx[1];
@@ -861,6 +886,12 @@ int run(int argc, char* argv[])
                                          "error",
                                          1e-2,
                                          1e-2);
+            std::cout << "Checking d0grad:\n";
+            pass &= ck::utils::check_err(d0grad_gs_ms_ns_device_result.mData,
+                                         d0grad_gs_ms_ns_host_result.mData,
+                                         "error",
+                                         1e-2,
+                                         1e-2);
         }
     }
...
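The dD0 = dS identity behind the host references above is easy to sanity-check outside of CK. Below is a minimal standalone sketch (plain C++, one query row, no dropout or mask; the sizes, values, and epsilon are arbitrary illustration choices, not taken from the examples) that compares the analytic bias gradient dD0_j = P_j * (dP_j - dY . Y) against a finite-difference estimate of L(D0) = <dY, Y(D0)>:

#include <cmath>
#include <cstdio>
#include <vector>

constexpr int N = 4;          // number of keys for the single query row
constexpr int O = 3;          // value / output dimension
constexpr int KDIM = 2;       // head (contraction) dimension
constexpr float alpha = 0.5f; // score scale

// One attention row: s_j = alpha * <q, K_j> + d0_j, p = softmax(s), y = p^T V.
std::vector<float> attention_row(const std::vector<float>& q,
                                 const std::vector<float>& k,
                                 const std::vector<float>& v,
                                 const std::vector<float>& d0,
                                 std::vector<float>* p_out = nullptr)
{
    std::vector<float> s(N), p(N), y(O, 0.0f);
    float smax = -1e30f, sum = 0.0f;
    for(int j = 0; j < N; ++j)
    {
        s[j] = d0[j];
        for(int t = 0; t < KDIM; ++t)
            s[j] += alpha * q[t] * k[j * KDIM + t];
        smax = std::fmax(smax, s[j]);
    }
    for(int j = 0; j < N; ++j)
    {
        p[j] = std::exp(s[j] - smax); // numerically stable softmax
        sum += p[j];
    }
    for(int j = 0; j < N; ++j)
        p[j] /= sum;
    for(int j = 0; j < N; ++j)
        for(int o = 0; o < O; ++o)
            y[o] += p[j] * v[j * O + o];
    if(p_out != nullptr)
        *p_out = p;
    return y;
}

int main()
{
    const std::vector<float> q{0.3f, -0.7f};
    const std::vector<float> k{1.0f, 0.2f, -1.0f, 0.8f, 0.5f, -0.3f, 0.0f, 1.1f}; // N x KDIM
    const std::vector<float> v{0.9f, -0.4f, 0.2f, 0.1f, 0.7f, -0.6f,
                               -0.8f, 0.3f, 0.5f, 0.4f, -0.2f, 1.0f};             // N x O
    const std::vector<float> d0{0.1f, -0.2f, 0.4f, 0.0f};                         // bias D0
    const std::vector<float> dy{1.0f, -2.0f, 0.5f};                               // upstream dY

    std::vector<float> p;
    const std::vector<float> y = attention_row(q, k, v, d0, &p);
    float dy_dot_y = 0.0f;
    for(int o = 0; o < O; ++o)
        dy_dot_y += dy[o] * y[o];

    for(int j = 0; j < N; ++j)
    {
        // Analytic: dP_j = <dY, V_j>, dD0_j = dS_j = P_j * (dP_j - dY . Y).
        float dp = 0.0f;
        for(int o = 0; o < O; ++o)
            dp += dy[o] * v[j * O + o];
        const float ds = p[j] * (dp - dy_dot_y);

        // Finite difference of L(D0) = <dY, Y(D0)> with respect to d0_j.
        const float eps = 1e-3f;
        std::vector<float> d0p = d0, d0m = d0;
        d0p[j] += eps;
        d0m[j] -= eps;
        const std::vector<float> yp = attention_row(q, k, v, d0p);
        const std::vector<float> ym = attention_row(q, k, v, d0m);
        float lp = 0.0f, lm = 0.0f;
        for(int o = 0; o < O; ++o)
        {
            lp += dy[o] * yp[o];
            lm += dy[o] * ym[o];
        }
        std::printf("j=%d  analytic dD0 = % .6f  finite diff = % .6f\n",
                    j, ds, (lp - lm) / (2.0f * eps));
    }
    return 0;
}

At float precision the two columns should agree to a few decimal places, the same kind of agreement the 1e-2 tolerances in check_err above are asking for.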
@@ -71,6 +71,7 @@ __global__ void
     ignore = p_z_grid;
     ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
     ignore = block_2_ctile_map;
+    ignore = num_gemm0_m_block_outer_loop;
     ignore = batch_count;
     ignore = compute_base_ptr_of_batch;
     ignore = seed;
@@ -135,7 +136,7 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
     static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_m_n_lengths,
                                         const std::vector<index_t>& z_gs_m_n_strides)
     {
-        return Transform::MakeCGridDescriptor_M_N(z_gs_m_n_lengths, z_gs_m_n_strides);
+        return Transform::MakeC0GridDescriptor_M_N(z_gs_m_n_lengths, z_gs_m_n_strides);
     }

     using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
...
@@ -123,6 +123,7 @@ __global__ void
             const InputDataType* __restrict__ p_ygrad_grid,
             OutputDataType* __restrict__ p_qgrad_grid,
             OutputDataType* __restrict__ p_kgrad_grid,
+            D0DataType* __restrict__ p_d0grad_grid,
             OutputDataType* __restrict__ p_vgrad_grid,
             const AElementwiseOperation a_element_op,
             const BElementwiseOperation b_element_op,
@@ -176,11 +177,19 @@ __global__ void
     const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;

     const D0DataType* tmp_p_d0_grid = nullptr;
+    D0DataType* tmp_p_d0grad_grid = nullptr;
     if constexpr(!is_same<D0DataType, void>::value)
     {
         const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
             static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
-        tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+        if(p_d0_grid != nullptr)
+        {
+            tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+        }
+        if(p_d0grad_grid != nullptr)
+        {
+            tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
+        }
     }
     if constexpr(Deterministic)
     {
@@ -197,6 +206,7 @@ __global__ void
                 p_ygrad_grid + c_batch_offset,
                 p_qgrad_grid + a_batch_offset,
                 p_kgrad_grid + b_batch_offset,
+                tmp_p_d0grad_grid,
                 p_vgrad_grid + b1_batch_offset,
                 p_shared,
                 a_element_op,
@@ -233,6 +243,7 @@ __global__ void
                 p_ygrad_grid + c_batch_offset,
                 p_qgrad_grid + a_batch_offset,
                 p_kgrad_grid + b_batch_offset,
+                tmp_p_d0grad_grid,
                 p_vgrad_grid + b1_batch_offset,
                 p_shared,
                 a_element_op,
@@ -266,6 +277,7 @@ __global__ void
     ignore = p_ygrad_grid;
     ignore = p_qgrad_grid;
     ignore = p_kgrad_grid;
+    ignore = p_d0grad_grid;
     ignore = p_vgrad_grid;
     ignore = a_element_op;
     ignore = b_element_op;
@@ -579,32 +591,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
     static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
                                         const std::vector<index_t>& z_gs_ms_ns_strides)
     {
-        return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
-    }
-
-    //
-    // dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
-    //
-
-    //
-    // dQ = alpha * dS * K
-    //
-    // QGrad in Gemm C position
-    static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths,
-                                            const std::vector<index_t>& q_gs_ms_ks_strides)
-    {
-        return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
-    }
-
-    //
-    // dK = alpha * dS^T * Q
-    //
-    // KGrad in Gemm C position
-    static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths,
-                                            const std::vector<index_t>& k_gs_ns_ks_strides)
-    {
-        return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
+        return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
     }

     static auto MakeLSEGridDescriptor_M(index_t MRaw)
@@ -635,7 +622,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
     static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
                                          const std::vector<index_t>& d_gs_ms_ns_strides)
     {
-        return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
+        return Transform::MakeC0GridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
     }

     static auto MakeDGridDescriptor_M(index_t MRaw)
@@ -665,7 +652,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
     using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
     using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
-    using D0GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
+    using D0GridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
     using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
     using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
     using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
@@ -673,7 +660,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
     using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
     using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
     using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
-    using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
+    using ZGridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
     using DYGridDesc_M_O = decltype(DTransform::MakeCGridDescriptor_M_N({}, {}));
     using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
@@ -858,6 +845,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                  OutputDataType* p_vgrad_grid,
                  const D0DataType* p_acc0_bias,
                  const D1DataType* p_acc1_bias,
+                 D0DataType* p_d0grad_grid,
+                 D1DataType* p_d1grad_grid,
                  const std::vector<index_t>& a_gs_ms_ks_lengths,
                  const std::vector<index_t>& a_gs_ms_ks_strides,
                  const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -894,6 +883,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
               p_qgrad_grid_{p_qgrad_grid},
               p_kgrad_grid_{p_kgrad_grid},
               p_vgrad_grid_{p_vgrad_grid},
+              p_d0grad_grid_{p_d0grad_grid},
               a_grid_desc_ak0_m_ak1_{
                   DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
               b_grid_desc_bk0_n_bk1_{
@@ -921,7 +911,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
               c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
                                                                       c_gs_ms_gemm1ns_strides)},
               z_grid_desc_g_m_n_{
-                  Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
+                  Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
               block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
               d_block_2_ctile_map_{
                   GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(d_y_grid_desc_m_o_)},
@@ -948,10 +938,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
               p_drop_{p_drop}
         {
             // TODO: implement bias addition
-            ignore = p_acc0_bias;
+            ignore = p_d1grad_grid;
             ignore = p_acc1_bias;
-            ignore = acc0_bias_gs_ms_ns_lengths;
-            ignore = acc0_bias_gs_ms_ns_strides;
             ignore = acc1_bias_gs_ms_gemm1ns_lengths;
             ignore = acc1_bias_gs_ms_gemm1ns_strides;
@@ -962,7 +950,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                 d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
                     GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);

-                d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
+                d0_grid_desc_g_m_n_ = Transform::MakeC0GridDescriptor_G_M_N(
                     acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);

                 d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]);
@@ -1030,6 +1018,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
         OutputDataType* p_qgrad_grid_;
         OutputDataType* p_kgrad_grid_;
         OutputDataType* p_vgrad_grid_;
+        D0DataType* p_d0grad_grid_;

         // tensor descriptor
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
@@ -1191,6 +1180,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                                           arg.p_ygrad_grid_,
                                           arg.p_qgrad_grid_,
                                           arg.p_kgrad_grid_,
+                                          arg.p_d0grad_grid_,
                                           arg.p_vgrad_grid_,
                                           arg.a_element_op_,
                                           arg.b_element_op_,
@@ -1342,6 +1332,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                              OutputDataType* p_vgrad_grid,
                              const D0DataType* p_acc0_bias,
                              const D1DataType* p_acc1_bias,
+                             D0DataType* p_d0grad_grid,
+                             D1DataType* p_d1grad_grid,
                              const std::vector<index_t>& a_gs_ms_ks_lengths,
                              const std::vector<index_t>& a_gs_ms_ks_strides,
                              const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -1380,6 +1372,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                         p_vgrad_grid,
                         p_acc0_bias,
                         p_acc1_bias,
+                        p_d0grad_grid,
+                        p_d1grad_grid,
                         a_gs_ms_ks_lengths,
                         a_gs_ms_ks_strides,
                         b_gs_ns_ks_lengths,
@@ -1422,6 +1416,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                         void* p_vgrad_grid,
                         const D0DataType* p_acc0_bias,
                         const D1DataType* p_acc1_bias,
+                        void* p_d0grad_grid,
+                        void* p_d1grad_grid,
                         const std::vector<index_t>& a_gs_ms_ks_lengths,
                         const std::vector<index_t>& a_gs_ms_ks_strides,
                         const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -1461,6 +1457,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
             static_cast<OutputDataType*>(p_vgrad_grid),
             static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
             static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
+            static_cast<D0DataType*>(p_d0grad_grid),
+            static_cast<D1DataType*>(p_d1grad_grid),
             a_gs_ms_ks_lengths,
             a_gs_ms_ks_strides,
             b_gs_ns_ks_lengths,
...
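Both Light_V1 (above) and Light_V2 (below) gate the new dD0 pointer the same way: if constexpr compiles the bias code out entirely when D0DataType is void, and the runtime null checks let the bias and its gradient be supplied independently. A minimal standalone illustration of that pattern (a toy function and toy types of my own choosing, not the CK kernel):

#include <cstdio>
#include <type_traits>

template <typename D0DataType>
void set_batch_bias_pointers(const D0DataType* p_d0_grid,
                             D0DataType* p_d0grad_grid,
                             long d0_batch_offset)
{
    const D0DataType* tmp_p_d0_grid = nullptr;
    D0DataType* tmp_p_d0grad_grid = nullptr;
    // Compile-time gate: with D0DataType = void, this whole block is
    // discarded and the instantiation carries no bias logic at all.
    if constexpr(!std::is_same<D0DataType, void>::value)
    {
        // Runtime gates: either pointer may independently be null, e.g. a
        // bias supplied in the forward sense but its gradient not requested.
        if(p_d0_grid != nullptr)
        {
            tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
        }
        if(p_d0grad_grid != nullptr)
        {
            tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
        }
    }
    std::printf("have d0: %d, have d0grad: %d\n",
                tmp_p_d0_grid != nullptr,
                tmp_p_d0grad_grid != nullptr);
}

int main()
{
    float d0[8] = {}, d0grad[8] = {};
    set_batch_bias_pointers<float>(d0, d0grad, 4); // bias and bias gradient
    set_batch_bias_pointers<float>(d0, nullptr, 4); // bias only, no dD0 requested
    set_batch_bias_pointers<void>(nullptr, nullptr, 0); // bias-free instantiation
    return 0;
}

This mirrors the call sites earlier in the commit: examples that do not want a bias gradient simply pass nullptr (batched API) or {} (grouped API), and a bias-free instantiation pays nothing for the branch.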
@@ -123,6 +123,7 @@ __global__ void
             const InputDataType* __restrict__ p_ygrad_grid,
             OutputDataType* __restrict__ p_qgrad_grid,
             OutputDataType* __restrict__ p_kgrad_grid,
+            D0DataType* __restrict__ p_d0grad_grid,
             OutputDataType* __restrict__ p_vgrad_grid,
             const AElementwiseOperation a_element_op,
             const BElementwiseOperation b_element_op,
@@ -176,11 +177,19 @@ __global__ void
     const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;

     const D0DataType* tmp_p_d0_grid = nullptr;
+    D0DataType* tmp_p_d0grad_grid = nullptr;
     if constexpr(!is_same<D0DataType, void>::value)
     {
         const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
             static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
-        tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+        if(p_d0_grid != nullptr)
+        {
+            tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+        }
+        if(p_d0grad_grid != nullptr)
+        {
+            tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
+        }
     }
@@ -198,6 +207,7 @@ __global__ void
                 p_ygrad_grid + c_batch_offset,
                 p_qgrad_grid + a_batch_offset,
                 p_kgrad_grid + b_batch_offset,
+                tmp_p_d0grad_grid,
                 p_vgrad_grid + b1_batch_offset,
                 p_shared,
                 a_element_op,
@@ -234,6 +244,7 @@ __global__ void
                 p_ygrad_grid + c_batch_offset,
                 p_qgrad_grid + a_batch_offset,
                 p_kgrad_grid + b_batch_offset,
+                tmp_p_d0grad_grid,
                 p_vgrad_grid + b1_batch_offset,
                 p_shared,
                 a_element_op,
@@ -267,6 +278,7 @@ __global__ void
     ignore = p_ygrad_grid;
     ignore = p_qgrad_grid;
     ignore = p_kgrad_grid;
+    ignore = p_d0grad_grid;
     ignore = p_vgrad_grid;
     ignore = a_element_op;
     ignore = b_element_op;
...@@ -587,39 +599,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -587,39 +599,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths, static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides) const std::vector<index_t>& d_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); return Transform::MakeC0GridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
} }
// Z in Gemm0 C position // Z in Gemm0 C position
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths, static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides) const std::vector<index_t>& z_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
//
//
// dQ = alpha * dS * K
//
// QGrad in Gemm C position
static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths,
const std::vector<index_t>& q_gs_ms_ks_strides)
{
return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
}
//
// dK = alpha * dS^T * Q
//
// KGrad in Gemm C position
static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths,
const std::vector<index_t>& k_gs_ns_ks_strides)
{
return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
} }
static auto MakeLSEGridDescriptor_M(index_t MRaw) static auto MakeLSEGridDescriptor_M(index_t MRaw)
...@@ -674,7 +661,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -674,7 +661,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using D0GridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1)); using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
...@@ -682,7 +669,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -682,7 +669,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using ZGridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using DYGridDesc_M_O = decltype(DTransform::MakeCGridDescriptor_M_N({}, {})); using DYGridDesc_M_O = decltype(DTransform::MakeCGridDescriptor_M_N({}, {}));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1)); using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
...@@ -874,6 +861,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -874,6 +861,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
OutputDataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_bias, const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias, const D1DataType* p_acc1_bias,
D0DataType* p_d0grad_grid,
D1DataType* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -910,6 +899,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -910,6 +899,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_qgrad_grid_{p_qgrad_grid}, p_qgrad_grid_{p_qgrad_grid},
p_kgrad_grid_{p_kgrad_grid}, p_kgrad_grid_{p_kgrad_grid},
p_vgrad_grid_{p_vgrad_grid}, p_vgrad_grid_{p_vgrad_grid},
p_d0grad_grid_{p_d0grad_grid},
a_grid_desc_ak0_m_ak1_{ a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
...@@ -936,7 +926,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -936,7 +926,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths, c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{ z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
d_block_2_ctile_map_{ d_block_2_ctile_map_{
GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(d_y_grid_desc_m_o_)}, GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(d_y_grid_desc_m_o_)},
...@@ -964,6 +954,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -964,6 +954,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{ {
// TODO: implement bias addition // TODO: implement bias addition
ignore = p_acc1_bias; ignore = p_acc1_bias;
ignore = p_d1grad_grid;
ignore = acc1_bias_gs_ms_gemm1ns_lengths; ignore = acc1_bias_gs_ms_gemm1ns_lengths;
ignore = acc1_bias_gs_ms_gemm1ns_strides; ignore = acc1_bias_gs_ms_gemm1ns_strides;
...@@ -974,7 +965,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -974,7 +965,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
d0_grid_desc_m0_n0_m1_m2_n1_m3_ = d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n); GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N( d0_grid_desc_g_m_n_ = Transform::MakeC0GridDescriptor_G_M_N(
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides); acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]); d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]);
...@@ -1042,6 +1033,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1042,6 +1033,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
OutputDataType* p_qgrad_grid_; OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_; OutputDataType* p_kgrad_grid_;
OutputDataType* p_vgrad_grid_; OutputDataType* p_vgrad_grid_;
D0DataType* p_d0grad_grid_;
// tensor descriptor // tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
...@@ -1207,6 +1199,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1207,6 +1199,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
arg.p_ygrad_grid_, arg.p_ygrad_grid_,
arg.p_qgrad_grid_, arg.p_qgrad_grid_,
arg.p_kgrad_grid_, arg.p_kgrad_grid_,
arg.p_d0grad_grid_,
arg.p_vgrad_grid_, arg.p_vgrad_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
...@@ -1374,6 +1367,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1374,6 +1367,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
OutputDataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_bias, const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias, const D1DataType* p_acc1_bias,
D0DataType* p_d0grad_grid,
D1DataType* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1412,6 +1407,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1412,6 +1407,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_vgrad_grid, p_vgrad_grid,
p_acc0_bias, p_acc0_bias,
p_acc1_bias, p_acc1_bias,
p_d0grad_grid,
p_d1grad_grid,
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
...@@ -1454,6 +1451,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1454,6 +1451,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
void* p_vgrad_grid, void* p_vgrad_grid,
const void* p_acc0_bias, const void* p_acc0_bias,
const void* p_acc1_bias, const void* p_acc1_bias,
void* p_d0grad_grid,
void* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1493,6 +1492,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1493,6 +1492,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static_cast<OutputDataType*>(p_vgrad_grid), static_cast<OutputDataType*>(p_vgrad_grid),
static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
static_cast<D0DataType*>(p_d0grad_grid),
static_cast<D1DataType*>(p_d1grad_grid),
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
......
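The V2 Argument above accepts p_d1grad_grid but immediately assigns it to ignore, since acc1-bias gradients are still TODO. For reference, a stand-alone sketch of that ignore idiom (the real code uses CK's own helper; ignore_t here is illustrative):

    // A write-only sink: assigning to it discards the value but keeps the
    // parameter referenced, silencing unused-parameter warnings.
    struct ignore_t
    {
        template <typename T>
        constexpr const ignore_t& operator=(T&&) const noexcept { return *this; }
    };
    inline constexpr ignore_t ignore{};

    void argument_ctor_tail(int* p_d1grad_grid)
    {
        ignore = p_d1grad_grid; // kept in the API, unused until dBias1 lands
    }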
...@@ -65,6 +65,7 @@ __global__ void ...@@ -65,6 +65,7 @@ __global__ void
const InputDataType* __restrict__ p_ygrad_grid, const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid, OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid, OutputDataType* __restrict__ p_kgrad_grid,
D0DataType* __restrict__ p_d0grad_grid,
OutputDataType* __restrict__ p_vgrad_grid, OutputDataType* __restrict__ p_vgrad_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
...@@ -120,11 +121,19 @@ __global__ void ...@@ -120,11 +121,19 @@ __global__ void
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded; const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
const D0DataType* tmp_p_d0_grid = nullptr; const D0DataType* tmp_p_d0_grid = nullptr;
D0DataType* tmp_p_d0grad_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
tmp_p_d0_grid = p_d0_grid + d0_batch_offset; if(p_d0_grid != nullptr)
{
tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
}
if(p_d0grad_grid != nullptr)
{
tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
}
} }
if constexpr(Deterministic) if constexpr(Deterministic)
{ {
...@@ -141,6 +150,7 @@ __global__ void ...@@ -141,6 +150,7 @@ __global__ void
p_ygrad_grid + c_batch_offset, p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset, p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset, p_kgrad_grid + b_batch_offset,
tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset, p_vgrad_grid + b1_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
...@@ -178,6 +188,7 @@ __global__ void ...@@ -178,6 +188,7 @@ __global__ void
p_ygrad_grid + c_batch_offset, p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset, p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset, p_kgrad_grid + b_batch_offset,
tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset, p_vgrad_grid + b1_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
...@@ -212,6 +223,7 @@ __global__ void ...@@ -212,6 +223,7 @@ __global__ void
ignore = p_ygrad_grid; ignore = p_ygrad_grid;
ignore = p_qgrad_grid; ignore = p_qgrad_grid;
ignore = p_kgrad_grid; ignore = p_kgrad_grid;
ignore = p_d0grad_grid;
ignore = p_vgrad_grid; ignore = p_vgrad_grid;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
...@@ -514,32 +526,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -514,32 +526,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths, static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides) const std::vector<index_t>& z_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
//
//
// dQ = alpha * dS * K
//
// QGrad in Gemm C position
static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths,
const std::vector<index_t>& q_gs_ms_ks_strides)
{
return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
}
//
// dK = alpha * dS^T * Q
//
// KGrad in Gemm C position
static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths,
const std::vector<index_t>& k_gs_ns_ks_strides)
{
return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
} }
static auto MakeLSEGridDescriptor_M(index_t MRaw) static auto MakeLSEGridDescriptor_M(index_t MRaw)
...@@ -570,12 +557,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -570,12 +557,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths, static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides) const std::vector<index_t>& d_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); return Transform::MakeC0GridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
} }
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using D0GridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1)); using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
...@@ -583,7 +570,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -583,7 +570,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using ZGridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {})); using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {})); using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
...@@ -755,6 +742,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -755,6 +742,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
OutputDataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_bias, const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias, const D1DataType* p_acc1_bias,
D0DataType* p_d0grad_grid,
D1DataType* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -790,6 +779,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -790,6 +779,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_qgrad_grid_{p_qgrad_grid}, p_qgrad_grid_{p_qgrad_grid},
p_kgrad_grid_{p_kgrad_grid}, p_kgrad_grid_{p_kgrad_grid},
p_vgrad_grid_{p_vgrad_grid}, p_vgrad_grid_{p_vgrad_grid},
p_d0grad_grid_{p_d0grad_grid},
a_grid_desc_ak0_m_ak1_{ a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
...@@ -814,7 +804,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -814,7 +804,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths, c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{ z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
y_grid_desc_mblock_mperblock_oblock_operblock_{}, y_grid_desc_mblock_mperblock_oblock_operblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
...@@ -839,10 +829,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -839,10 +829,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_drop_{p_drop} p_drop_{p_drop}
{ {
// TODO: implement bias addition // TODO: implement bias addition
ignore = p_acc0_bias; ignore = p_d1grad_grid;
ignore = p_acc1_bias; ignore = p_acc1_bias;
ignore = acc0_bias_gs_ms_ns_lengths;
ignore = acc0_bias_gs_ms_ns_strides;
ignore = acc1_bias_gs_ms_gemm1ns_lengths; ignore = acc1_bias_gs_ms_gemm1ns_lengths;
ignore = acc1_bias_gs_ms_gemm1ns_strides; ignore = acc1_bias_gs_ms_gemm1ns_strides;
...@@ -862,7 +850,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -862,7 +850,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
d0_grid_desc_m0_n0_m1_m2_n1_m3_ = d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n); GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N( d0_grid_desc_g_m_n_ = Transform::MakeC0GridDescriptor_G_M_N(
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides); acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]); d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]);
...@@ -926,6 +914,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -926,6 +914,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
OutputDataType* p_qgrad_grid_; OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_; OutputDataType* p_kgrad_grid_;
OutputDataType* p_vgrad_grid_; OutputDataType* p_vgrad_grid_;
D0DataType* p_d0grad_grid_;
// tensor descriptor // tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
...@@ -1049,6 +1038,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1049,6 +1038,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
arg.p_ygrad_grid_, arg.p_ygrad_grid_,
arg.p_qgrad_grid_, arg.p_qgrad_grid_,
arg.p_kgrad_grid_, arg.p_kgrad_grid_,
arg.p_d0grad_grid_,
arg.p_vgrad_grid_, arg.p_vgrad_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
...@@ -1200,6 +1190,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1200,6 +1190,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
OutputDataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_bias, const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias, const D1DataType* p_acc1_bias,
D0DataType* p_d0grad_grid,
D1DataType* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1237,6 +1229,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1237,6 +1229,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_vgrad_grid, p_vgrad_grid,
p_acc0_bias, p_acc0_bias,
p_acc1_bias, p_acc1_bias,
p_d0grad_grid,
p_d1grad_grid,
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
...@@ -1278,6 +1272,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1278,6 +1272,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
void* p_vgrad_grid, void* p_vgrad_grid,
const D0DataType* p_acc0_bias, const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias, const D1DataType* p_acc1_bias,
D0DataType* p_d0grad_grid,
D1DataType* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1316,6 +1312,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1316,6 +1312,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static_cast<OutputDataType*>(p_vgrad_grid), static_cast<OutputDataType*>(p_vgrad_grid),
static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
static_cast<D0DataType*>(p_d0grad_grid),
static_cast<D1DataType*>(p_d1grad_grid),
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
......
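A subtlety in the V1 hunk above: Argument stores p_d0grad_grid_ as a mutable D0DataType* because the kernel writes dBias through it, so the MakeArgument call site casts with static_cast<D0DataType*> rather than a const-qualified cast (a const-qualified cast would not even compile there). A minimal sketch with illustrative types:

    struct Argument
    {
        float* p_d0grad_grid_; // written by the kernel, hence non-const
        Argument(float* p_d0grad_grid) : p_d0grad_grid_{p_d0grad_grid} {}
    };

    void make_argument(float* p_d0grad_grid)
    {
        Argument ok{static_cast<float*>(p_d0grad_grid)}; // identity cast, fine
        // Argument bad{static_cast<const float*>(p_d0grad_grid)}; // error:
        // a const float* cannot initialize a float* parameter
    }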
...@@ -65,6 +65,7 @@ __global__ void ...@@ -65,6 +65,7 @@ __global__ void
const InputDataType* __restrict__ p_ygrad_grid, const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid, OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid, OutputDataType* __restrict__ p_kgrad_grid,
D0DataType* __restrict__ p_d0grad_grid,
OutputDataType* __restrict__ p_vgrad_grid, OutputDataType* __restrict__ p_vgrad_grid,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
...@@ -120,13 +121,21 @@ __global__ void ...@@ -120,13 +121,21 @@ __global__ void
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded; const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
const D0DataType* tmp_p_d0_grid = nullptr; const D0DataType* tmp_p_d0_grid = nullptr;
D0DataType* tmp_p_d0grad_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
tmp_p_d0_grid = p_d0_grid + d0_batch_offset; if(p_d0_grid != nullptr)
{
tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
}
if(p_d0grad_grid != nullptr)
{
tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
}
} }
if constexpr(Deterministic) if constexpr(Deterministic)
{ {
for(index_t i = 0; i < nblock; i++) for(index_t i = 0; i < nblock; i++)
...@@ -142,6 +151,7 @@ __global__ void ...@@ -142,6 +151,7 @@ __global__ void
p_ygrad_grid + c_batch_offset, p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset, p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset, p_kgrad_grid + b_batch_offset,
tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset, p_vgrad_grid + b1_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
...@@ -179,6 +189,7 @@ __global__ void ...@@ -179,6 +189,7 @@ __global__ void
p_ygrad_grid + c_batch_offset, p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset, p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset, p_kgrad_grid + b_batch_offset,
tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset, p_vgrad_grid + b1_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
...@@ -213,6 +224,7 @@ __global__ void ...@@ -213,6 +224,7 @@ __global__ void
ignore = p_ygrad_grid; ignore = p_ygrad_grid;
ignore = p_qgrad_grid; ignore = p_qgrad_grid;
ignore = p_kgrad_grid; ignore = p_kgrad_grid;
ignore = p_d0grad_grid;
ignore = p_vgrad_grid; ignore = p_vgrad_grid;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
...@@ -522,39 +534,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -522,39 +534,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths, static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides) const std::vector<index_t>& d_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); return Transform::MakeC0GridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
} }
// Z in Gemm0 C position // Z in Gemm0 C position
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths, static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides) const std::vector<index_t>& z_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
//
//
// dQ = alpha * dS * K
//
// QGrad in Gemm C position
static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths,
const std::vector<index_t>& q_gs_ms_ks_strides)
{
return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
}
//
// dK = alpha * dS^T * Q
//
// KGrad in Gemm C position
static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths,
const std::vector<index_t>& k_gs_ns_ks_strides)
{
return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
} }
static auto MakeLSEGridDescriptor_M(index_t MRaw) static auto MakeLSEGridDescriptor_M(index_t MRaw)
...@@ -584,7 +571,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -584,7 +571,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using D0GridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1)); using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
...@@ -592,7 +579,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -592,7 +579,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using ZGridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {})); using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {})); using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
...@@ -771,6 +758,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -771,6 +758,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
OutputDataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_bias, const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias, const D1DataType* p_acc1_bias,
D0DataType* p_d0grad_grid,
D1DataType* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -806,6 +795,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -806,6 +795,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_qgrad_grid_{p_qgrad_grid}, p_qgrad_grid_{p_qgrad_grid},
p_kgrad_grid_{p_kgrad_grid}, p_kgrad_grid_{p_kgrad_grid},
p_vgrad_grid_{p_vgrad_grid}, p_vgrad_grid_{p_vgrad_grid},
p_d0grad_grid_{p_d0grad_grid},
a_grid_desc_ak0_m_ak1_{ a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
...@@ -829,7 +819,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -829,7 +819,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths, c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{ z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
y_grid_desc_mblock_mperblock_oblock_operblock_{}, y_grid_desc_mblock_mperblock_oblock_operblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
...@@ -855,6 +845,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -855,6 +845,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{ {
// TODO: implement bias addition // TODO: implement bias addition
ignore = p_acc1_bias; ignore = p_acc1_bias;
ignore = p_d1grad_grid;
ignore = acc1_bias_gs_ms_gemm1ns_lengths; ignore = acc1_bias_gs_ms_gemm1ns_lengths;
ignore = acc1_bias_gs_ms_gemm1ns_strides; ignore = acc1_bias_gs_ms_gemm1ns_strides;
...@@ -875,7 +866,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -875,7 +866,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
d0_grid_desc_m0_n0_m1_m2_n1_m3_ = d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n); GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N( d0_grid_desc_g_m_n_ = Transform::MakeC0GridDescriptor_G_M_N(
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides); acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]); d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]);
...@@ -925,6 +916,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -925,6 +916,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0) std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0)
<< ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", " << ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", "
<< ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n'; << ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n';
std::cout << "d0_grid_desc_g_m_n_: " << d0_grid_desc_g_m_n_.GetLength(I0) << ", "
<< d0_grid_desc_g_m_n_.GetLength(I1) << ", "
<< d0_grid_desc_g_m_n_.GetLength(I2) << '\n';
} }
// pointers // pointers
...@@ -939,6 +934,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -939,6 +934,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
OutputDataType* p_qgrad_grid_; OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_; OutputDataType* p_kgrad_grid_;
OutputDataType* p_vgrad_grid_; OutputDataType* p_vgrad_grid_;
D0DataType* p_d0grad_grid_;
// tensor descriptor // tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
...@@ -1066,6 +1062,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1066,6 +1062,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
arg.p_ygrad_grid_, arg.p_ygrad_grid_,
arg.p_qgrad_grid_, arg.p_qgrad_grid_,
arg.p_kgrad_grid_, arg.p_kgrad_grid_,
arg.p_d0grad_grid_,
arg.p_vgrad_grid_, arg.p_vgrad_grid_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
...@@ -1233,6 +1230,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1233,6 +1230,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
OutputDataType* p_vgrad_grid, OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_bias, const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias, const D1DataType* p_acc1_bias,
D0DataType* p_d0grad_grid,
D1DataType* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1270,6 +1269,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1270,6 +1269,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_vgrad_grid, p_vgrad_grid,
p_acc0_bias, p_acc0_bias,
p_acc1_bias, p_acc1_bias,
p_d0grad_grid,
p_d1grad_grid,
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
...@@ -1311,6 +1312,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1311,6 +1312,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
void* p_vgrad_grid, void* p_vgrad_grid,
const void* p_acc0_bias, const void* p_acc0_bias,
const void* p_acc1_bias, const void* p_acc1_bias,
void* p_d0grad_grid,
void* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1349,6 +1352,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1349,6 +1352,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static_cast<OutputDataType*>(p_vgrad_grid), static_cast<OutputDataType*>(p_vgrad_grid),
static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
static_cast<D0DataType*>(p_d0grad_grid),
static_cast<D1DataType*>(p_d1grad_grid),
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
......
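The batched kernels above all funnel the optional D0 pointers through the same pattern: compute one per-batch offset (broadcast across the wavefront with __builtin_amdgcn_readfirstlane in the real code) and offset only non-null pointers, so both "bias disabled" and "dBias not requested" stay nullptr all the way into the gridwise gemm. A plain-C++ sketch of the pattern, with an illustrative helper name:

    // Offsets a grid pointer to its batch slice; null pointers stay null so
    // downstream code can keep using nullptr as the "feature off" signal.
    template <typename T>
    T* offset_or_null(T* p, long offset)
    {
        return p == nullptr ? nullptr : p + offset;
    }

    // mirrors tmp_p_d0_grid / tmp_p_d0grad_grid in the kernels above:
    // const D0DataType* tmp_p_d0_grid     = offset_or_null(p_d0_grid, d0_batch_offset);
    // D0DataType*       tmp_p_d0grad_grid = offset_or_null(p_d0grad_grid, d0_batch_offset);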
...@@ -162,13 +162,16 @@ __global__ void ...@@ -162,13 +162,16 @@ __global__ void
: arg_ptr[group_id].p_z_grid_ + z_batch_offset); : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
const D0DataType* tmp_p_d0_grid = nullptr; const D0DataType* tmp_p_d0_grid = nullptr;
D0DataType* tmp_p_d0grad_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
const long_index_t d0_batch_offset = const long_index_t d0_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>( __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
if(arg_ptr[group_id].p_d0_grid_ != nullptr)
tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset; tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
if(arg_ptr[group_id].p_d0grad_grid_ != nullptr)
tmp_p_d0grad_grid = arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
} }
if constexpr(Deterministic) if constexpr(Deterministic)
{ {
...@@ -185,6 +188,7 @@ __global__ void ...@@ -185,6 +188,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset, arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset, arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset, arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset, arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
...@@ -222,6 +226,7 @@ __global__ void ...@@ -222,6 +226,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset, arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset, arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset, arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset, arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
...@@ -540,7 +545,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -540,7 +545,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths, static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides) const std::vector<index_t>& z_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
} }
static auto MakeLSEGridDescriptor_M(index_t MRaw) static auto MakeLSEGridDescriptor_M(index_t MRaw)
...@@ -572,8 +577,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -572,8 +577,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths, return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
} }
static auto static auto
...@@ -581,8 +586,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -581,8 +586,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths, return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
} }
static auto MakeDGridDescriptor_M(index_t MRaw) static auto MakeDGridDescriptor_M(index_t MRaw)
...@@ -806,6 +811,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -806,6 +811,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const InputDataType* p_ygrad_grid_; const InputDataType* p_ygrad_grid_;
OutputDataType* p_qgrad_grid_; OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_; OutputDataType* p_kgrad_grid_;
D0DataType* p_d0grad_grid_;
OutputDataType* p_vgrad_grid_; OutputDataType* p_vgrad_grid_;
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
...@@ -878,6 +884,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -878,6 +884,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -911,7 +919,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -911,7 +919,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) && group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) || (group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) && ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
0 == p_acc1_bias_vec.size())) 0 == p_acc1_bias_vec.size() &&
(group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
p_d0grads.size() == 0) &&
0 == p_d1grads.size()))
{ {
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size"); throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
} }
...@@ -937,7 +948,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -937,7 +948,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]); const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]); auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]);
auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]); auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]);
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]); auto p_d0grad_grid =
(ck::type_convert<ck::index_t>(p_d0grads.size()) == group_count_)
? static_cast<D0DataType*>(p_d0grads[i])
: nullptr;
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
const auto& problem_desc = problem_desc_vec[i]; const auto& problem_desc = problem_desc_vec[i];
...@@ -983,7 +998,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -983,7 +998,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides); problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N( const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides); tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N( const auto z_grid_desc_g_m_n = Transform::MakeC0GridDescriptor_G_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides); problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K( const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1_gs_gemm1ns_gemm1ks_lengths, problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
...@@ -1054,6 +1069,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1054,6 +1069,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_ygrad_grid, p_ygrad_grid,
p_qgrad_grid, p_qgrad_grid,
p_kgrad_grid, p_kgrad_grid,
p_d0grad_grid,
p_vgrad_grid, p_vgrad_grid,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
...@@ -1370,6 +1386,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1370,6 +1386,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1392,6 +1410,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1392,6 +1410,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_Vgrads, p_Vgrads,
p_acc0_bias_vec, p_acc0_bias_vec,
p_acc1_bias_vec, p_acc1_bias_vec,
p_d0grads,
p_d1grads,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -1420,6 +1440,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1420,6 +1440,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1442,6 +1464,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1442,6 +1464,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_Vgrads, p_Vgrads,
p_acc0_bias_vec, // cast in struct Argument p_acc0_bias_vec, // cast in struct Argument
p_acc1_bias_vec, // cast in struct Argument p_acc1_bias_vec, // cast in struct Argument
p_d0grads,
p_d1grads,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
......
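The grouped path above adopts a convention for the new vectors: p_d0grads is either empty (dBias disabled for every group) or holds exactly one pointer per group, and p_d1grads must stay empty for now; an empty vector maps to a per-group nullptr. A stand-alone sketch of that selection, assuming the illustrative helper name group_ptr_or_null:

    #include <cstddef>
    #include <vector>

    // Returns the i-th per-group pointer, or nullptr when the caller passed
    // an empty vector to disable the output for all groups.
    template <typename T>
    T* group_ptr_or_null(const std::vector<void*>& vec,
                         std::size_t group_count,
                         std::size_t i)
    {
        return vec.size() == group_count ? static_cast<T*>(vec[i]) : nullptr;
    }

    // mirrors the Argument loop above:
    // auto p_d0grad_grid = group_ptr_or_null<D0DataType>(p_d0grads, group_count_, i);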
...@@ -160,13 +160,17 @@ __global__ void ...@@ -160,13 +160,17 @@ __global__ void
(arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset); : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
const D0DataType* tmp_p_d0_grid = nullptr; const D0DataType* tmp_p_d0_grid = nullptr;
D0DataType* tmp_p_d0grad_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
const long_index_t d0_batch_offset = const long_index_t d0_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>( __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset; if(arg_ptr[group_id].p_d0_grid_ != nullptr)
tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
if(arg_ptr[group_id].p_d0grad_grid_ != nullptr)
tmp_p_d0grad_grid = arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
} }
if constexpr(Deterministic) if constexpr(Deterministic)
...@@ -184,6 +188,7 @@ __global__ void ...@@ -184,6 +188,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset, arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset, arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset, arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset, arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
...@@ -221,6 +226,7 @@ __global__ void ...@@ -221,6 +226,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset, arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset, arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset, arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset, arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
...@@ -602,7 +608,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -602,7 +608,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths, static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides) const std::vector<index_t>& z_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
} }
static auto MakeLSEGridDescriptor_M(index_t MRaw) static auto MakeLSEGridDescriptor_M(index_t MRaw)
...@@ -634,8 +640,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -634,8 +640,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths, return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
} }
static auto static auto
...@@ -643,8 +649,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -643,8 +649,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths, return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
} }
static auto MakeDGridDescriptor_M(index_t MRaw) static auto MakeDGridDescriptor_M(index_t MRaw)
...@@ -682,7 +688,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -682,7 +688,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using D0GridDesc_G_M_N = decltype(MakeD0GridDescriptor_G_M_N({}, {})); using D0GridDesc_G_M_N = decltype(MakeD0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using ZGridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {})); using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {})); using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
...@@ -876,6 +882,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -876,6 +882,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const InputDataType* p_ygrad_grid_; const InputDataType* p_ygrad_grid_;
OutputDataType* p_qgrad_grid_; OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_; OutputDataType* p_kgrad_grid_;
D0DataType* p_d0grad_grid_;
OutputDataType* p_vgrad_grid_; OutputDataType* p_vgrad_grid_;
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
...@@ -948,6 +955,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -948,6 +955,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -981,7 +990,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -981,7 +990,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) && group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) || (group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) && ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
0 == p_acc1_bias_vec.size())) 0 == p_acc1_bias_vec.size() &&
(group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
ck::type_convert<ck::index_t>(p_d0grads.size() == 0)) &&
0 == p_d1grads.size()))
{ {
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size"); throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
} }
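// NOTE: editorial sketch, not part of this commit. The extended check above,
// modeled in isolation: optional per-group outputs must be supplied either
// for no group or for every group, and unsupported vectors must stay empty.
// validate_group_args is a hypothetical free function, not CK API.
#include <cstddef>
#include <stdexcept>
#include <vector>

inline void validate_group_args(std::size_t group_count,
                                const std::vector<void*>& p_d0grads,
                                const std::vector<void*>& p_d1grads)
{
    const bool d0_ok = p_d0grads.empty() || p_d0grads.size() == group_count;
    if(!d0_ok || !p_d1grads.empty())
        throw std::runtime_error("wrong! per-group pointer counts mismatch");
}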
...@@ -1007,7 +1019,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1007,7 +1019,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]); const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]); auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]);
auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]); auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]);
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]); auto p_d0grad_grid =
(ck::type_convert<ck::index_t>(p_d0grads.size()) == group_count_)
? static_cast<D0DataType*>(p_d0grads[i])
: nullptr;
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
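// NOTE: editorial sketch, not part of this commit. The ternary above follows
// a simple pattern: when the caller supplied no d0grad buffers, every group
// sees nullptr and the kernel skips the bias-gradient write. Generic form
// (hypothetical helper):
#include <cstddef>
#include <vector>

template <typename T>
T* optional_group_ptr(const std::vector<void*>& ptrs,
                      std::size_t group_count,
                      std::size_t i)
{
    return ptrs.size() == group_count ? static_cast<T*>(ptrs[i]) : nullptr;
}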
const auto& problem_desc = problem_desc_vec[i]; const auto& problem_desc = problem_desc_vec[i];
...@@ -1053,7 +1069,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1053,7 +1069,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides); problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N( const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides); tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N( const auto z_grid_desc_g_m_n = Transform::MakeC0GridDescriptor_G_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides); problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K( const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1_gs_gemm1ns_gemm1ks_lengths, problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
...@@ -1124,6 +1140,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1124,6 +1140,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_ygrad_grid, p_ygrad_grid,
p_qgrad_grid, p_qgrad_grid,
p_kgrad_grid, p_kgrad_grid,
p_d0grad_grid,
p_vgrad_grid, p_vgrad_grid,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
...@@ -1445,6 +1462,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1445,6 +1462,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1467,6 +1486,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1467,6 +1486,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_Vgrads, p_Vgrads,
p_acc0_bias_vec, p_acc0_bias_vec,
p_acc1_bias_vec, p_acc1_bias_vec,
p_d0grads,
p_d1grads,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -1495,6 +1516,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1495,6 +1516,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1517,6 +1540,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1517,6 +1540,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_Vgrads, p_Vgrads,
p_acc0_bias_vec, // cast in struct Argument p_acc0_bias_vec, // cast in struct Argument
p_acc1_bias_vec, // cast in struct Argument p_acc1_bias_vec, // cast in struct Argument
p_d0grads,
p_d1grads,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
......
...@@ -103,13 +103,17 @@ __global__ void ...@@ -103,13 +103,17 @@ __global__ void
: arg_ptr[group_id].p_z_grid_ + z_batch_offset); : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
const D0DataType* tmp_p_d0_grid = nullptr; const D0DataType* tmp_p_d0_grid = nullptr;
D0DataType* tmp_p_d0grad_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
const long_index_t d0_batch_offset = const long_index_t d0_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>( __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset; if(arg_ptr[group_id].p_d0_grid_ != nullptr)
tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
if(arg_ptr[group_id].p_d0grad_grid_)
tmp_p_d0grad_grid = arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
} }
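// NOTE: editorial sketch, not part of this commit. Host-side shape of the
// guard above: one base offset per batch index, applied only when the grid
// pointer exists; get_d0_base_offset stands in for
// compute_base_ptr_of_batch_.GetD0BasePtr and is illustrative only.
#include <cstdint>

inline int64_t get_d0_base_offset(int64_t g_idx, int64_t batch_stride)
{
    return g_idx * batch_stride;
}

template <typename T>
T* batch_ptr_or_null(T* p_grid, int64_t g_idx, int64_t batch_stride)
{
    return p_grid == nullptr ? nullptr
                             : p_grid + get_d0_base_offset(g_idx, batch_stride);
}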
if constexpr(Deterministic) if constexpr(Deterministic)
{ {
...@@ -126,6 +130,7 @@ __global__ void ...@@ -126,6 +130,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset, arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset, arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset, arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset, arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
...@@ -164,6 +169,7 @@ __global__ void ...@@ -164,6 +169,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset, arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset, arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset, arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset, arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
...@@ -471,7 +477,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -471,7 +477,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths, static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides) const std::vector<index_t>& z_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
} }
static auto MakeLSEGridDescriptor_M(index_t MRaw) static auto MakeLSEGridDescriptor_M(index_t MRaw)
...@@ -503,8 +509,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -503,8 +509,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths, return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
} }
static auto static auto
...@@ -512,8 +518,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -512,8 +518,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths, return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
} }
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
...@@ -526,7 +532,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -526,7 +532,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
using D0GridDesc_G_M_N = decltype(MakeD0GridDescriptor_G_M_N({}, {})); using D0GridDesc_G_M_N = decltype(MakeD0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using ZGridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {})); using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {})); using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
...@@ -696,6 +702,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -696,6 +702,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const InputDataType* p_ygrad_grid_; const InputDataType* p_ygrad_grid_;
OutputDataType* p_qgrad_grid_; OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_; OutputDataType* p_kgrad_grid_;
D0DataType* p_d0grad_grid_;
OutputDataType* p_vgrad_grid_; OutputDataType* p_vgrad_grid_;
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
...@@ -760,6 +767,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -760,6 +767,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -792,7 +801,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -792,7 +801,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) && group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) || (group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) && ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
0 == p_acc1_bias_vec.size())) 0 == p_acc1_bias_vec.size() &&
(group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
ck::type_convert<ck::index_t>(p_d0grads.size() == 0)) &&
0 == p_d1grads.size()))
{ {
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size"); throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
} }
...@@ -816,7 +828,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -816,7 +828,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]); const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]); auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]);
auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]); auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]);
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]); auto p_d0grad_grid =
(ck::type_convert<ck::index_t>(p_d0grads.size()) == group_count_)
? static_cast<D0DataType*>(p_d0grads[i])
: nullptr;
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
const auto& problem_desc = problem_desc_vec[i]; const auto& problem_desc = problem_desc_vec[i];
...@@ -862,7 +878,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -862,7 +878,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides); problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N( const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides); tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N( const auto z_grid_desc_g_m_n = Transform::MakeC0GridDescriptor_G_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides); problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K( const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1_gs_gemm1ns_gemm1ks_lengths, problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
...@@ -925,6 +941,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -925,6 +941,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_ygrad_grid, p_ygrad_grid,
p_qgrad_grid, p_qgrad_grid,
p_kgrad_grid, p_kgrad_grid,
p_d0grad_grid,
p_vgrad_grid, p_vgrad_grid,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
...@@ -1214,6 +1231,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1214,6 +1231,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1235,6 +1254,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1235,6 +1254,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_Vgrads, p_Vgrads,
p_acc0_bias_vec, p_acc0_bias_vec,
p_acc1_bias_vec, p_acc1_bias_vec,
p_d0grads,
p_d1grads,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -1262,6 +1283,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1262,6 +1283,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1283,6 +1306,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1283,6 +1306,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_Vgrads, p_Vgrads,
p_acc0_bias_vec, // cast in struct Argument p_acc0_bias_vec, // cast in struct Argument
p_acc1_bias_vec, // cast in struct Argument p_acc1_bias_vec, // cast in struct Argument
p_d0grads,
p_d1grads,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
......
...@@ -102,13 +102,16 @@ __global__ void ...@@ -102,13 +102,16 @@ __global__ void
(arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset); : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
const D0DataType* tmp_p_d0_grid = nullptr; const D0DataType* tmp_p_d0_grid = nullptr;
D0DataType* tmp_p_d0grad_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
const long_index_t d0_batch_offset = const long_index_t d0_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>( __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
if(arg_ptr[group_id].p_d0_grid_ != nullptr)
tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset; tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
if(arg_ptr[group_id].p_d0grad_grid_)
tmp_p_d0grad_grid = arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
} }
if constexpr(Deterministic) if constexpr(Deterministic)
...@@ -126,6 +129,7 @@ __global__ void ...@@ -126,6 +129,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset, arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset, arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset, arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset, arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
...@@ -164,6 +168,7 @@ __global__ void ...@@ -164,6 +168,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset, arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset, arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset, arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset, arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
...@@ -534,7 +539,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -534,7 +539,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths, static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides) const std::vector<index_t>& z_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
} }
static auto MakeLSEGridDescriptor_M(index_t MRaw) static auto MakeLSEGridDescriptor_M(index_t MRaw)
...@@ -566,8 +571,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -566,8 +571,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths, return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
} }
static auto static auto
...@@ -575,8 +580,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -575,8 +580,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides) const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths, return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides); acc0_bias_gs_ms_ns_strides);
} }
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
...@@ -589,7 +594,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -589,7 +594,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using D0GridDesc_G_M_N = decltype(MakeD0GridDescriptor_G_M_N({}, {})); using D0GridDesc_G_M_N = decltype(MakeD0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using ZGridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {})); using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {})); using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
...@@ -767,6 +772,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -767,6 +772,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const InputDataType* p_ygrad_grid_; const InputDataType* p_ygrad_grid_;
OutputDataType* p_qgrad_grid_; OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_; OutputDataType* p_kgrad_grid_;
D0DataType* p_d0grad_grid_;
OutputDataType* p_vgrad_grid_; OutputDataType* p_vgrad_grid_;
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
...@@ -831,6 +837,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -831,6 +837,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -863,7 +871,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -863,7 +871,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) && group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) || (group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) && ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
0 == p_acc1_bias_vec.size())) 0 == p_acc1_bias_vec.size() &&
(group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
ck::type_convert<ck::index_t>(p_d0grads.size() == 0)) &&
0 == p_d1grads.size()))
{ {
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size"); throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
} }
...@@ -887,7 +898,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -887,7 +898,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]); const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]); auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]);
auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]); auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]);
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]); auto p_d0grad_grid =
(ck::type_convert<ck::index_t>(p_d0grads.size()) == group_count_)
? static_cast<D0DataType*>(p_d0grads[i])
: nullptr;
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
const auto& problem_desc = problem_desc_vec[i]; const auto& problem_desc = problem_desc_vec[i];
...@@ -933,7 +948,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -933,7 +948,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides); problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N( const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides); tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N( const auto z_grid_desc_g_m_n = Transform::MakeC0GridDescriptor_G_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides); problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K( const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1_gs_gemm1ns_gemm1ks_lengths, problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
...@@ -996,6 +1011,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -996,6 +1011,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_ygrad_grid, p_ygrad_grid,
p_qgrad_grid, p_qgrad_grid,
p_kgrad_grid, p_kgrad_grid,
p_d0grad_grid,
p_vgrad_grid, p_vgrad_grid,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
...@@ -1290,6 +1306,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1290,6 +1306,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1311,6 +1329,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1311,6 +1329,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_Vgrads, p_Vgrads,
p_acc0_bias_vec, p_acc0_bias_vec,
p_acc1_bias_vec, p_acc1_bias_vec,
p_d0grads,
p_d1grads,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -1338,6 +1358,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1338,6 +1358,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1359,6 +1381,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1359,6 +1381,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_Vgrads, p_Vgrads,
p_acc0_bias_vec, // cast in struct Argument p_acc0_bias_vec, // cast in struct Argument
p_acc1_bias_vec, // cast in struct Argument p_acc1_bias_vec, // cast in struct Argument
p_d0grads,
p_d1grads,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
......
...@@ -119,6 +119,15 @@ struct GemmGemmPadder ...@@ -119,6 +119,15 @@ struct GemmGemmPadder
c_desc_mraw_nraw, make_tuple(MPerTile_, OPerTile_), Sequence<PadM, PadO>{}); c_desc_mraw_nraw, make_tuple(MPerTile_, OPerTile_), Sequence<PadM, PadO>{});
} }
// C0[M, N]: unlike C[M, Gemm1N] above (padded along M and O), pad along M and N
template <typename C0Desc_MRaw_NRaw>
__host__ __device__ constexpr auto
PadC0Descriptor_M_N(const C0Desc_MRaw_NRaw& c_desc_mraw_nraw) const
{
return PadTensorDescriptor(
c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
}
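// NOTE: editorial sketch, not part of this commit. When a Pad flag is set,
// PadTensorDescriptor rounds the raw extent up to the next tile multiple;
// the arithmetic in isolation:
#include <cassert>
#include <cstdint>

inline int64_t pad_to_multiple(int64_t raw, int64_t per_tile)
{
    assert(per_tile > 0);
    return (raw + per_tile - 1) / per_tile * per_tile; // pad_to_multiple(100, 64) == 128
}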
MPerTileType MPerTile_; MPerTileType MPerTile_;
NPerTileType NPerTile_; NPerTileType NPerTile_;
KPerTileType KPerTile_; KPerTileType KPerTile_;
......
...@@ -88,6 +88,10 @@ template <typename InputDataType, ...@@ -88,6 +88,10 @@ template <typename InputDataType,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
{ {
static_assert(KPerBlock == Gemm1NPerBlock);
static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0);
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
...@@ -1213,7 +1217,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1213,7 +1217,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
using D0GridDescriptor_M0_N0_M1_M2_N1_M3 = using D0GridDescriptor_M0_N0_M1_M2_N1_M3 =
remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>; remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>;
struct D0Loader struct D0Operator
{ {
template <typename DataType> template <typename DataType>
struct TypeTransform struct TypeTransform
...@@ -1229,17 +1233,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1229,17 +1233,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static constexpr index_t Size0 = 0; static constexpr index_t Size0 = 0;
static constexpr index_t Size = sizeof(ck::half_t); static constexpr index_t Size = sizeof(ck::half_t);
}; };
static constexpr index_t NThreadClusterLengths = MPerXdl; static constexpr index_t NThreadClusterLengths = 32;
static_assert(MPerXdl <= KPerBlock); static_assert(NPerXdl == 32);
static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock, static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock,
"D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock"); "D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock");
__host__ __device__ static constexpr auto GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3() __host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3()
{ {
// B1 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor_packed( return make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, D0M1, Number<NPerBlock>{}, D0M2)); make_tuple(I1, I1, I1, D0M1, Number<NPerBlock>{}, D0M2));
} }
__host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2() __host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2()
{ {
constexpr auto d0_raw_m0_n_m1 = constexpr auto d0_raw_m0_n_m1 =
make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<NPerBlock>{}, D0M2)); make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<NPerBlock>{}, D0M2));
...@@ -1254,15 +1257,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1254,15 +1257,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{})); make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
return d0_n0_n1_m0_m1_m2; return d0_n0_n1_m0_m1_m2;
} }
static constexpr auto d0_block_write_desc_m0_n0_m1_m2_n1_m3 = static constexpr auto d0_block_dst_desc_m0_n0_m1_m2_n1_m3 =
GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3(); GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3();
static constexpr auto d0_block_read_desc_n0_n1_m0_m1_m2 = static constexpr auto d0_block_src_desc_n0_n1_m0_m1_m2 =
GetD0BlockReadDescriptor_N0_N1_M0_M1_M2(); GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2();
static constexpr auto d0_thread_desc_ = static constexpr auto d0_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2)); make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));
static constexpr auto& d0grad_block_dst_desc_n0_n1_m0_m1_m2 =
    d0_block_src_desc_n0_n1_m0_m1_m2;
static constexpr auto& d0grad_block_src_desc_m0_n0_m1_m2_n1_m3 =
    d0_block_dst_desc_m0_n0_m1_m2_n1_m3;

using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1< using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock, ThisThreadBlock,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
...@@ -1273,34 +1281,77 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1273,34 +1281,77 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
1, 1,
BlockSize / NThreadClusterLengths, BlockSize / NThreadClusterLengths,
NThreadClusterLengths, NThreadClusterLengths,
1>, // ThreadClusterLengths 1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc
decltype(d0_block_write_desc_m0_n0_m1_m2_n1_m3), // DstDesc decltype(d0_block_dst_desc_m0_n0_m1_m2_n1_m3), // DstDesc
Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder
Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder
4, // SrcVectorDim 4, // SrcVectorDim
5, // DstVectorDim 5, // DstVectorDim
D0BlockTransferSrcScalarPerVector, // SrcScalarPerVector 4, // SrcScalarPerVector
4, // DstScalarPerVector 4, // DstScalarPerVector
1, 1,
1, 1,
true, true,
true, // DstResetCoord true, // DstResetCoord
1>; 1>;
using D0ThreadWiseCopy = using D0ThreadwiseCopyLdsToVgpr =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc decltype(d0_block_src_desc_n0_n1_m0_m1_m2), // SrcDesc
decltype(d0_thread_desc_), // DstDesc decltype(d0_thread_desc_), // DstDesc
Sequence<1, 1, 4, 1, 4>, // SliceLengths Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // DimAccessOrder Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
4, // SrcVectorDim 4, // SrcVectorDim
2, // SrcScalarPerVector 4, // SrcScalarPerVector
2>; 2>;
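// NOTE: editorial sketch, not part of this commit. The copy above now reads
// 4 scalars per vector along dim 4, whose slice length is also 4; vectorized
// access is only legal when the fastest slice length divides evenly:
constexpr bool vector_width_ok(int slice_len_fastest_dim, int scalar_per_vector)
{
    return slice_len_fastest_dim % scalar_per_vector == 0;
}
static_assert(vector_width_ok(4, 4), "SliceLengths[4] = 4 with SrcScalarPerVector = 4");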
using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
typename TypeTransform<D0DataType>::Type,
decltype(d0_thread_desc_),
decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
tensor_operation::element_wise::Scale, // CElementwiseOperation
Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // AccessOrder
4, // VectorDim
4, // ScalarPerVector
InMemoryDataOperationEnum::Set, // GlobalMemoryDataOperation
1, // DstScalarStrideInVector
true>;
using D0GradBlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
Sequence<1,
1,
1,
BlockSize / NThreadClusterLengths,
NThreadClusterLengths,
1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0grad_block_src_desc_m0_n0_m1_m2_n1_m3), // SrcDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // DstDesc
Sequence<0, 1, 2, 4, 3, 5>, // SrcDimAccessOrder
Sequence<0, 1, 2, 3, 5, 4>, // DstDimAccessOrder
5, // SrcVectorDim
4, // DstVectorDim
4, // SrcScalarPerVector
D0BlockTransferSrcScalarPerVector, // DstScalarPerVector
1,
1,
true,
true, // DstResetCoord
1>;
}; };
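// NOTE: editorial sketch, not part of this commit. The two new aliases run
// the bias load in reverse: per-thread dS fragments are scaled into the same
// LDS tile, then flushed to global as dBias. Host-side model of the two
// stages (flat buffers, hypothetical names):
#include <cstddef>
#include <vector>

void dbias_roundtrip(const std::vector<float>& ds_vgpr, // per-thread dS fragment
                     std::vector<float>& lds_tile,      // shared staging tile
                     std::vector<float>& dbias_global,  // output tile in global memory
                     float rp_dropout)
{
    for(std::size_t i = 0; i < ds_vgpr.size() && i < lds_tile.size(); ++i)
        lds_tile[i] = ds_vgpr[i] * rp_dropout; // VGPR -> LDS, Scale element-op
    for(std::size_t i = 0; i < lds_tile.size() && i < dbias_global.size(); ++i)
        dbias_global[i] = lds_tile[i]; // LDS -> global, InMemoryDataOperationEnum::Set
}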
struct SharedMemTrait struct SharedMemTrait
...@@ -1335,11 +1386,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1335,11 +1386,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
q_block_space_size_aligned.value; q_block_space_size_aligned.value;
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple( static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align); D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
static constexpr auto d0_block_space_offset = static constexpr auto d0_block_space_offset =
(k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value + (k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value) * q_block_space_size_aligned.value) *
sizeof(GemmDataType) / D0Loader::template TypeTransform<D0DataType>::Size; sizeof(GemmDataType) / D0Operator::template TypeTransform<D0DataType>::Size;
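// NOTE: editorial sketch, not part of this commit. math::integer_least_multiple
// as used above rounds a size up to the LDS alignment:
#include <cstdint>

constexpr int64_t integer_least_multiple_model(int64_t x, int64_t align)
{
    return (x + align - 1) / align * align; // integer_least_multiple_model(1000, 256) == 1024
}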
// LDS allocation for C shuffle in LDS // LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
...@@ -1356,7 +1407,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1356,7 +1407,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
sizeof(GemmDataType); sizeof(GemmDataType);
const index_t d0_bytes_end = const index_t d0_bytes_end =
(SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) * (SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
D0Loader::template TypeTransform<D0DataType>::Size0; D0Operator::template TypeTransform<D0DataType>::Size0;
const index_t c_block_bytes_end = const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle); SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
...@@ -1379,6 +1430,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1379,6 +1430,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const InputDataType* __restrict__ p_ygrad_grid, const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid, OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid, OutputDataType* __restrict__ p_kgrad_grid,
D0DataType* __restrict__ p_d0grad_grid,
OutputDataType* __restrict__ p_vgrad_grid, OutputDataType* __restrict__ p_vgrad_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op, const AElementwiseOperation& a_element_op,
...@@ -1846,17 +1898,31 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1846,17 +1898,31 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// gemm0 M loop // gemm0 M loop
index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1; index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
// D0 // D0
auto d0_block_copy_global_to_lds = typename D0Loader::D0BlockwiseCopy( auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0), make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}, tensor_operation::element_wise::PassThrough{},
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3, D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0), make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{}); tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy( auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0)); make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
auto& d0grad_grid_desc_m0_n0_m1_m2_n1_m3 = d0_grid_desc_m0_n0_m1_m2_n1_m3;
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
tensor_operation::element_wise::Scale{rp_dropout});
auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{},
d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
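// NOTE: editorial sketch, not part of this commit. The dBias writer mirrors
// the bias reader's slice-window motion: advance one M1 step per inner
// iteration, then rewind by (-1, 0, -D0M0, 0, 0, 0) so the next outer M-block
// starts from a fresh window. Modeled on plain indices:
struct WindowIdx
{
    int m0 = 0; // outer M block
    int m1 = 0; // inner M step
};

inline void advance_window(WindowIdx& w) { ++w.m1; } // per-iteration move

inline void rewind_window(WindowIdx& w, int d0m0) // post-loop reset
{
    w.m0 -= 1;
    w.m1 -= d0m0;
}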
if constexpr(Deterministic) if constexpr(Deterministic)
{ {
block_sync_lds(); block_sync_lds();
...@@ -1992,49 +2058,53 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1992,49 +2058,53 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// add bias // add bias
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
if(p_d0_grid != nullptr)
{
    static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();

    const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
        p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());

    auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
        static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
        D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());

    auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
        D0Operator::d0_thread_desc_.GetElementSpaceSize());

    static_for<0, D0M0, 1>{}([&](auto mr) {
        // load data to lds
        d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
                                            d0_grid_buf);

        d0_block_copy_global_to_lds.MoveSrcSliceWindow(
            d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));

        d0_block_copy_global_to_lds.RunWrite(
            D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
        block_sync_lds();
        // read data from lds
        d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_src_desc_n0_n1_m0_m1_m2,
                                       make_tuple(I0, I0, I0, I0, I0),
                                       d0_block_buf,
                                       D0Operator::d0_thread_desc_,
                                       make_tuple(I0, I0, I0, I0, I0),
                                       d0_thread_buf);

        // bias add
        static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
            constexpr index_t c_offset =
                c_thread_desc.CalculateOffset(make_tuple(mr, I0, i));

            s_slash_p_thread_buf(Number<c_offset>{}) +=
                ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
        });
    });

    d0_block_copy_global_to_lds.MoveSrcSliceWindow(
        d0_grid_desc_m0_n0_m1_m2_n1_m3,
        make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
}
} }
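// NOTE: editorial sketch, not part of this commit. Reference semantics of
// the bias add just performed: S := S + D0 elementwise over the tile,
// accumulated in the gemm accumulator type.
#include <cstddef>
#include <vector>

void add_bias_reference(std::vector<float>& s, const std::vector<float>& d0)
{
    for(std::size_t i = 0; i < s.size() && i < d0.size(); ++i)
        s[i] += d0[i];
}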
// P_i: = softmax(scalar * S_i:) // P_i: = softmax(scalar * S_i:)
...@@ -2125,6 +2195,45 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -2125,6 +2195,45 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
: y_dot_ygrad_thread_buf[Number<m>{}]); : y_dot_ygrad_thread_buf[Number<m>{}]);
}); });
// output bias grad
if constexpr(!is_same<D0DataType, void>::value)
{
if(p_d0grad_grid != nullptr)
{
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) {
d0grad_thread_copy_vgpr_to_lds.Run(
D0Operator::d0_thread_desc_,
make_tuple(mr, I0, I0, I0, I0),
sgrad_thread_buf,
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
d0grad_block_buf);
block_sync_lds();
// write data from lds to global
d0grad_block_copy_lds_to_global.Run(
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
d0grad_block_buf,
d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
d0grad_grid_buf,
I0);
d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
d0grad_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
});
d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
}
}
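// NOTE: editorial sketch, not part of this commit. For a score bias added
// before softmax, dBias equals dS, and this kernel keeps sgrad with dropout
// folded in, so the stored value is rp_dropout * sgrad (the Scale element-op
// above). Hypothetical host reference:
#include <cstddef>
#include <vector>

void dbias_reference(const std::vector<float>& sgrad,
                     float rp_dropout,
                     std::vector<float>& d0grad)
{
    d0grad.resize(sgrad.size());
    for(std::size_t i = 0; i < sgrad.size(); ++i)
        d0grad[i] = rp_dropout * sgrad[i];
}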
// gemm dV // gemm dV
// dV = P_drop^T * dY // dV = P_drop^T * dY
{ {
......
...@@ -96,6 +96,10 @@ template <typename InputDataType, ...@@ -96,6 +96,10 @@ template <typename InputDataType,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{ {
static_assert(Gemm1NPerBlock % KPerBlock == 0);
static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0);
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
...@@ -1292,7 +1296,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1292,7 +1296,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using D0GridDescriptor_M0_N0_M1_M2_N1_M3 = using D0GridDescriptor_M0_N0_M1_M2_N1_M3 =
remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>; remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>;
struct D0Loader struct D0Operator
{ {
template <typename DataType> template <typename DataType>
struct TypeTransform struct TypeTransform
...@@ -1312,13 +1316,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1312,13 +1316,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static_assert(NPerXdl == 32); static_assert(NPerXdl == 32);
static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock, static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock,
"D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock"); "D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock");
__host__ __device__ static constexpr auto GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3() __host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3()
{ {
// B1 matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor_packed( return make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, D0M1, Number<NPerBlock>{}, D0M2)); make_tuple(I1, I1, I1, D0M1, Number<NPerBlock>{}, D0M2));
} }
__host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2() __host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2()
{ {
constexpr auto d0_raw_m0_n_m1 = constexpr auto d0_raw_m0_n_m1 =
make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<NPerBlock>{}, D0M2)); make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<NPerBlock>{}, D0M2));
...@@ -1333,15 +1336,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1333,15 +1336,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{})); make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
return d0_n0_n1_m0_m1_m2; return d0_n0_n1_m0_m1_m2;
} }
static constexpr auto d0_block_write_desc_m0_n0_m1_m2_n1_m3 = static constexpr auto d0_block_dst_desc_m0_n0_m1_m2_n1_m3 =
GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3(); GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3();
static constexpr auto d0_block_read_desc_n0_n1_m0_m1_m2 = static constexpr auto d0_block_src_desc_n0_n1_m0_m1_m2 =
GetD0BlockReadDescriptor_N0_N1_M0_M1_M2(); GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2();
static constexpr auto d0_thread_desc_ = static constexpr auto d0_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2)); make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));
static constexpr auto& d0grad_block_dst_desc_n0_n1_m0_m1_m2 =
    d0_block_src_desc_n0_n1_m0_m1_m2;
static constexpr auto& d0grad_block_src_desc_m0_n0_m1_m2_n1_m3 =
    d0_block_dst_desc_m0_n0_m1_m2_n1_m3;

using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1< using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock, ThisThreadBlock,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough, tensor_operation::element_wise::PassThrough,
...@@ -1352,34 +1360,77 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1352,34 +1360,77 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
1, 1,
BlockSize / NThreadClusterLengths, BlockSize / NThreadClusterLengths,
NThreadClusterLengths, NThreadClusterLengths,
1>, // ThreadClusterLengths 1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData typename TypeTransform<D0DataType>::Type, // DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc
decltype(d0_block_write_desc_m0_n0_m1_m2_n1_m3), // DstDesc decltype(d0_block_dst_desc_m0_n0_m1_m2_n1_m3), // DstDesc
Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder
Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder
4, // SrcVectorDim 4, // SrcVectorDim
5, // DstVectorDim 5, // DstVectorDim
D0BlockTransferSrcScalarPerVector, // SrcScalarPerVector 4, // SrcScalarPerVector
4, // DstScalarPerVector 4, // DstScalarPerVector
1, 1,
1, 1,
true, true,
true, // DstResetCoord true, // DstResetCoord
1>; 1>;
-   using D0ThreadWiseCopy =
+   using D0ThreadwiseCopyLdsToVgpr =
        ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
                                         typename TypeTransform<D0DataType>::Type, // DstData
-                                        decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
+                                        decltype(d0_block_src_desc_n0_n1_m0_m1_m2), // SrcDesc
                                         decltype(d0_thread_desc_), // DstDesc
                                         Sequence<1, 1, 4, 1, 4>, // SliceLengths
                                         Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
                                         4, // SrcVectorDim
-                                        2, // SrcScalarPerVector
+                                        4, // SrcScalarPerVector
                                         2>;
+   using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
+       FloatGemmAcc,
+       typename TypeTransform<D0DataType>::Type,
+       decltype(d0_thread_desc_),
+       decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
+       tensor_operation::element_wise::Scale, // CElementwiseOperation
+       Sequence<1, 1, 4, 1, 4>, // SliceLengths
+       Sequence<0, 1, 2, 3, 4>, // AccessOrder
+       4, // VectorDim
+       4, // ScalarPerVector
+       InMemoryDataOperationEnum::Set, // GlobalMemoryDataOperation
+       1, // DstScalarStrideInVector
+       true>;
+   using D0GradBlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
+       ThisThreadBlock,
+       tensor_operation::element_wise::PassThrough,
+       tensor_operation::element_wise::PassThrough,
+       InMemoryDataOperationEnum::Set,
+       Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
+       Sequence<1,
+                1,
+                1,
+                BlockSize / NThreadClusterLengths,
+                NThreadClusterLengths,
+                1>, // ThreadClusterLengths
+       Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
+       typename TypeTransform<D0DataType>::Type, // SrcData
+       typename TypeTransform<D0DataType>::Type, // DstData
+       decltype(d0grad_block_src_desc_m0_n0_m1_m2_n1_m3), // SrcDesc
+       D0GridDescriptor_M0_N0_M1_M2_N1_M3, // DstDesc
+       Sequence<0, 1, 2, 4, 3, 5>, // SrcDimAccessOrder
+       Sequence<0, 1, 2, 3, 5, 4>, // DstDimAccessOrder
+       5, // SrcVectorDim
+       4, // DstVectorDim
+       4, // SrcScalarPerVector
+       D0BlockTransferSrcScalarPerVector, // DstScalarPerVector
+       1,
+       1,
+       true,
+       true, // DstResetCoord
+       1>;
    };
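Editor's note: the four copy types above form a symmetric round trip — the forward bias path stages D0 global→LDS (`d0_block_dst_desc_m0_n0_m1_m2_n1_m3`) then LDS→VGPR (`d0_block_src_desc_n0_n1_m0_m1_m2`), and the new grad path reuses those same two block descriptors with src/dst roles swapped to return dD0 VGPR→LDS→global, applying `Scale{rp_dropout}` on the way down. A minimal scalar sketch of that data flow in plain C++; `Tile`, the function names, and the flat indexing are illustrative stand-ins, not CK API:

```cpp
#include <vector>

using Tile = std::vector<float>; // stand-in for a descriptor-addressed buffer

// Forward bias path: global -> LDS -> VGPR, then S += D0 (layouts elided).
void bias_forward(const Tile& d0_global, Tile& lds, Tile& vgpr, Tile& s_acc)
{
    lds  = d0_global; // models D0BlockwiseCopyGlobalToLds
    vgpr = lds;       // models D0ThreadwiseCopyLdsToVgpr
    for(std::size_t i = 0; i < s_acc.size(); ++i)
        s_acc[i] += vgpr[i]; // the bias-add loop in Run()
}

// Backward bias-grad path: same two layouts, roles swapped.
void bias_backward(const Tile& sgrad_vgpr, Tile& lds, Tile& d0grad_global, float rp_dropout)
{
    lds.resize(sgrad_vgpr.size());
    for(std::size_t i = 0; i < sgrad_vgpr.size(); ++i)
        lds[i] = rp_dropout * sgrad_vgpr[i]; // models D0GradThreadwiseCopyVgprToLds + Scale
    d0grad_global = lds;                     // models D0GradBlockwiseCopyLdsToGlobal (Set)
}
```

Note how `d0grad_block_src_desc_m0_n0_m1_m2_n1_m3` is literally a reference to the forward path's LDS write descriptor: the grad store reads the tile back out of LDS in exactly the layout the forward load would have written it.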
    struct SharedMemTrait
@@ -1414,10 +1465,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
            math::integer_least_multiple(BlockSize, max_lds_align);
        static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
-           D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
+           D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
        static constexpr auto d0_block_space_offset =
            k_block_space_size_aligned.value * sizeof(GemmDataType) /
-           D0Loader::template TypeTransform<D0DataType>::Size;
+           D0Operator::template TypeTransform<D0DataType>::Size;
        // LDS allocation for C shuffle in LDS
        static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
@@ -1442,7 +1493,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
            sizeof(GemmDataType);
        const index_t d0_bytes_end =
            (SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
-           D0Loader::template TypeTransform<D0DataType>::Size0;
+           D0Operator::template TypeTransform<D0DataType>::Size0;
        const index_t c_block_bytes_end =
            SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
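Editor's note: the D0/dD0 tile shares one LDS carve-out, placed right after the K tile; the offset is the K tile's byte size re-expressed in D0-element units, and every size is rounded up to the LDS alignment. A small sketch of how those quantities compose — all concrete numbers below are illustrative assumptions, not CK values:

```cpp
#include <cstddef>

// Round x up to the next multiple of align (models math::integer_least_multiple).
constexpr std::size_t integer_least_multiple(std::size_t x, std::size_t align)
{
    return ((x + align - 1) / align) * align;
}

constexpr std::size_t max_lds_align  = 64;   // hypothetical alignment, in elements
constexpr std::size_t sizeof_gemm_t  = 2;    // e.g. fp16 GemmDataType
constexpr std::size_t sizeof_d0_t    = 2;    // e.g. fp16 D0DataType

constexpr std::size_t k_block_elems  = integer_least_multiple(4000, max_lds_align);
constexpr std::size_t d0_block_elems = integer_least_multiple(4096, max_lds_align);

// D0 tile starts after the K tile; the offset is counted in D0 elements.
constexpr std::size_t d0_block_space_offset = k_block_elems * sizeof_gemm_t / sizeof_d0_t;
constexpr std::size_t d0_bytes_end = (d0_block_space_offset + d0_block_elems) * sizeof_d0_t;
static_assert(d0_bytes_end > 0, "LDS budget sanity check");
```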
@@ -1470,6 +1521,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
        const InputDataType* __restrict__ p_ygrad_grid,
        OutputDataType* __restrict__ p_qgrad_grid,
        OutputDataType* __restrict__ p_kgrad_grid,
+       D0DataType* __restrict__ p_d0grad_grid,
        OutputDataType* __restrict__ p_vgrad_grid,
        void* __restrict__ p_shared,
        const AElementwiseOperation& a_element_op,
@@ -1967,17 +2019,31 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
        index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
        // D0
-       auto d0_block_copy_global_to_lds = typename D0Loader::D0BlockwiseCopy(
+       auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
            d0_grid_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{},
-           D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3,
+           D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(0, 0, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{});
-       auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
+       auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
            make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
+       auto& d0grad_grid_desc_m0_n0_m1_m2_n1_m3 = d0_grid_desc_m0_n0_m1_m2_n1_m3;
+       auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
+           D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
+           make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
+           tensor_operation::element_wise::Scale{rp_dropout});
+       auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
+           D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
+           make_multi_index(0, 0, 0, 0, 0, 0),
+           tensor_operation::element_wise::PassThrough{},
+           d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
+           make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
+           tensor_operation::element_wise::PassThrough{});
        if constexpr(Deterministic)
        {
            block_sync_lds();
@@ -2143,50 +2209,53 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
            // add bias
            if constexpr(!is_same<D0DataType, void>::value)
            {
-               static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();
+               if(p_d0_grid != nullptr)
+               {
+                   static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();
                    const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                        p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
                    auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                        static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
-                       D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                       D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
                    auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
-                       D0Loader::d0_thread_desc_.GetElementSpaceSize());
-                   ignore = d0_thread_buf;
+                       D0Operator::d0_thread_desc_.GetElementSpaceSize());
                    static_for<0, D0M0, 1>{}([&](auto mr) {
                        // load data to lds
                        d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
                                                            d0_grid_buf);
                        d0_block_copy_global_to_lds.MoveSrcSliceWindow(
                            d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
                        d0_block_copy_global_to_lds.RunWrite(
-                           D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
+                           D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
                        block_sync_lds();
                        // read data from lds
-                       d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_read_desc_n0_n1_m0_m1_m2,
+                       d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_src_desc_n0_n1_m0_m1_m2,
                                                       make_tuple(I0, I0, I0, I0, I0),
                                                       d0_block_buf,
-                                                      D0Loader::d0_thread_desc_,
+                                                      D0Operator::d0_thread_desc_,
                                                       make_tuple(I0, I0, I0, I0, I0),
                                                       d0_thread_buf);
                        // bias add
                        static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
                            constexpr index_t c_offset =
                                c_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
                            s_slash_p_thread_buf(Number<c_offset>{}) +=
                                ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
                        });
                    });
-               d0_block_copy_global_to_lds.MoveSrcSliceWindow(
-                   d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
+                   d0_block_copy_global_to_lds.MoveSrcSliceWindow(
+                       d0_grid_desc_m0_n0_m1_m2_n1_m3,
+                       make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
+               }
            }
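Editor's note: the slice window in the loop above advances one block along dimension 2 per iteration, then the final `MoveSrcSliceWindow(..., {-1, 0, -D0M0, ...})` rewinds it and steps the outer m-block down by one (the kernel walks `gemm0_m_block_outer_index` from high to low). A toy model of that index arithmetic — a 2-D stand-in for the 6-D multi-index, names illustrative:

```cpp
#include <cassert>

int main()
{
    constexpr int D0M0 = 4; // illustrative repeat count of the inner walk
    int m0 = 3, m1 = 0;     // stand-ins for dims 0 and 2 of the window origin

    for(int mr = 0; mr < D0M0; ++mr)
        m1 += 1;            // MoveSrcSliceWindow(..., {0, 0, +1, 0, 0, 0})

    m0 -= 1;                // MoveSrcSliceWindow(..., {-1, 0, -D0M0, 0, 0, 0})
    m1 -= D0M0;             // ...rewinds dim 2 so the next outer m-block starts clean

    assert(m0 == 2 && m1 == 0);
    return 0;
}
```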
            // P_i: = softmax(scalar * S_i:)
@@ -2393,6 +2462,46 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                    : y_dot_ygrad_thread_buf[Number<m>{}]);
            });
+           // output bias grad
+           if constexpr(!is_same<D0DataType, void>::value)
+           {
+               if(p_d0grad_grid != nullptr)
+               {
+                   auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                       p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                   auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+                       static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
+                       D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                   static_for<0, D0M0, 1>{}([&](auto mr) {
+                       d0grad_thread_copy_vgpr_to_lds.Run(
+                           D0Operator::d0_thread_desc_,
+                           make_tuple(mr, I0, I0, I0, I0),
+                           sgrad_thread_buf,
+                           D0Operator::d0_block_src_desc_n0_n1_m0_m1_m2,
+                           d0grad_block_buf);
+                       block_sync_lds();
+                       // write data from lds to global
+                       d0grad_block_copy_lds_to_global.Run(
+                           D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
+                           d0grad_block_buf,
+                           d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
+                           d0grad_grid_buf,
+                           I0);
+                       d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
+                           d0grad_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
+                   });
+                   d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
+                       d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
+                       make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
+               }
+           }
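Editor's note: the grad write-back above is gated twice — at compile time on whether a bias type exists at all (`is_same<D0DataType, void>`), and at runtime on whether the caller actually passed a dD0 buffer. A minimal sketch of that double gate in plain C++; the function name and structure are illustrative, not the kernel's API:

```cpp
#include <type_traits>

template <typename D0DataType>
void maybe_write_bias_grad(D0DataType* p_d0grad_grid /* may be nullptr */)
{
    // Compile-time gate: the whole block vanishes when there is no bias type.
    if constexpr(!std::is_same<D0DataType, void>::value)
    {
        // Runtime gate: callers that want Q/K/V grads but no dD0 pass nullptr.
        if(p_d0grad_grid != nullptr)
        {
            // ... stage dD0 through LDS and write it out, as the kernel does above ...
        }
    }
}
```

Numerically the staged write amounts to dD0 = rp_dropout * dS per element: `Scale{rp_dropout}` is applied on the VGPR→LDS hop, and the LDS→global copy is a plain `InMemoryDataOperationEnum::Set` overwrite.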
            SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
                                                            s_blockwise_gemm.GetWaveIdx()[I1]);
...