Commit b15eecba authored by letaoqin's avatar letaoqin

Merge branch 'mha-train-develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into mha-train-develop
parents 87d1e073 04c206da
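The changes below thread a gradient output for the first attention bias (d0grad, together with a d1grad slot that the device operators currently ignore) through the MHA backward examples and the Qloop backward device operators: every MakeArgument/MakeArgumentPointer gains a p_d0grad_grid/p_d1grad_grid pair right after the existing p_acc0_bias/p_acc1_bias pointers, and the kernels receive a per-batch offset d0grad pointer. As a sketch of the math the new output relies on, in the notation of the descriptor comments that appear in the diff, and assuming the bias is added to the already-scaled scores (S = alpha * Q K^T + D0, P = softmax(S), Y = P V, dropout omitted):

    dP_i_j  = dY_i dot V_j                       (dP = dY * V^T)
    dS_i_j  = P_i_j .* (dP_i_j - dY_i dot Y_i)
    dD0_i_j = dS_i_j                             (S depends on D0 through an identity Jacobian)
    dQ = alpha * dS * K,   dK = alpha * dS^T * Q,   dV = P^T * dY

This is why the host references below copy sgrad_g_m_n directly into the d0grad result tensors.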
......@@ -513,8 +513,10 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
nullptr, // p_acc0_bias;
nullptr, // p_acc1_bias;
nullptr, // p_d0grad_grid;
nullptr, // p_d1grad_grid;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
......@@ -558,8 +560,10 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
nullptr, // p_acc0_bias;
nullptr, // p_acc1_bias;
nullptr, // p_d0grad_grid;
nullptr, // p_d1grad_grid;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
......
......@@ -518,8 +518,10 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
nullptr, // p_acc0_bias;
nullptr, // p_acc1_bias;
nullptr, // p_d0grad_grid;
nullptr, // p_d1grad_grid;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
......@@ -564,8 +566,10 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
nullptr, // p_acc0_bias;
nullptr, // p_acc1_bias;
nullptr, // p_d0grad_grid;
nullptr, // p_d1grad_grid;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
......
......@@ -597,8 +597,10 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // p_acc0_biases;
{}, // p_acc1_biases;
nullptr, // p_acc0_biases;
nullptr, // p_acc1_biases;
nullptr, // p_d0grad_grid;
nullptr, // p_d1grad_grid;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
......@@ -686,8 +688,8 @@ int run(int argc, char* argv[])
static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_fwd_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
nullptr, // p_acc0_bias;
nullptr, // p_acc1_bias;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
......@@ -743,8 +745,10 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
nullptr, // p_acc0_bias;
nullptr, // p_acc1_bias;
nullptr, // p_d0grad_grid;
nullptr, // p_d1grad_grid;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
......
......@@ -604,6 +604,8 @@ int run(int argc, char* argv[])
p_vgrad,
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
{}, // p_d0grads;
{}, // p_d1grads;
problem_descs,
QKVElementOp{},
QKVElementOp{},
......@@ -650,6 +652,8 @@ int run(int argc, char* argv[])
p_vgrad,
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
{}, // p_d0grads;
{}, // p_d1grads;
problem_descs,
QKVElementOp{},
QKVElementOp{},
......
......@@ -24,7 +24,7 @@ Kernel outputs:
*/
#define USING_MASK 0
#define DIM 128 // DIM should be a multiple of 8.
#define DIM 32 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
......@@ -616,6 +616,8 @@ int run(int argc, char* argv[])
p_vgrad,
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
{}, // p_d0grads;
{}, // p_d1grads;
problem_descs,
QKVElementOp{},
QKVElementOp{},
......@@ -663,6 +665,8 @@ int run(int argc, char* argv[])
p_vgrad,
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
{}, // p_d0grads;
{}, // p_d1grads;
problem_descs,
QKVElementOp{},
QKVElementOp{},
......
......@@ -728,6 +728,8 @@ int run(int argc, char* argv[])
p_vgrad,
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
{}, // p_d0grads;
{}, // p_d1grads;
problem_descs_bwd,
QKVElementOp{},
QKVElementOp{},
......@@ -815,6 +817,8 @@ int run(int argc, char* argv[])
p_vgrad,
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
{}, // p_d0grads;
{}, // p_d1grads;
problem_descs_bwd,
QKVElementOp{},
QKVElementOp{},
......
......@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
#define DIM 64 // DIM should be a multiple of 8.
#define DIM 128 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
......@@ -57,6 +57,7 @@ using BF16 = ck::bhalf_t;
using F32 = float;
using U16 = unsigned short;
using INT32 = int32_t;
using U8 = uint8_t;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
......@@ -374,8 +375,8 @@ int run(int argc, char* argv[])
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d_gs_ms_ns_strides =
std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d0_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // D layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // D layout [G0, G1, M, N]
......@@ -396,7 +397,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<Acc0BiasDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
Tensor<Acc0BiasDataType> d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides);
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<InputDataType> v_gs_os_ns(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
......@@ -405,7 +406,7 @@ int run(int argc, char* argv[])
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl;
std::cout << "d0_gs_ms_ns: " << d0_gs_ms_ns.mDesc << std::endl;
std::cout << "z_gs_ms_ns: " << z_gs_ms_ns.mDesc << std::endl;
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
......@@ -420,36 +421,35 @@ int run(int argc, char* argv[])
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-2, 2});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-2, 2});
// d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2<Acc0BiasDataType>{-2, 2});
break;
case 2:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_3<InputDataType>{-0.5, 0.5});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3<Acc0BiasDataType>{-0.5, 0.5});
break;
case 3:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<InputDataType>{-5, 5});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
break;
case 4:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
break;
case 5:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
// dO dot O = [0; 1; 2; ...]
break;
case 6:
......@@ -457,7 +457,7 @@ int run(int argc, char* argv[])
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
......@@ -471,7 +471,7 @@ int run(int argc, char* argv[])
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0,g1, m, o]
d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
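// (sketch of the arithmetic behind the two comments above, assuming the generated Q/K make
//  every row of the score matrix constant: softmax over N = 256 equal logits is uniform,
//  so P_i_j = 1/256 ~= 0.0039, and with V diagonal, O = P * V simply copies P.)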
......@@ -485,7 +485,7 @@ int run(int argc, char* argv[])
// qkv gradients have the same descriptor as with qkv
DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem k_device_buf(sizeof(InputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(Acc0BiasDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(Acc0BiasDataType) * d0_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem z_device_buf(sizeof(ZDataType) * z_gs_ms_ns.mDesc.GetElementSpaceSize());
DeviceMem v_device_buf(sizeof(InputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
......@@ -494,12 +494,14 @@ int run(int argc, char* argv[])
DeviceMem kgrad_device_buf(sizeof(OutputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem vgrad_device_buf(sizeof(OutputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem d0grad_device_buf(sizeof(Acc0BiasDataType) * d0_gs_ms_ns.mDesc.GetElementSpaceSize());
q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
d_device_buf.ToDevice(d_gs_ms_ns.mData.data());
d0_device_buf.ToDevice(d0_gs_ms_ns.mData.data());
v_device_buf.ToDevice(v_gs_os_ns.mData.data());
ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
z_device_buf.ToDevice(z_gs_ms_ns.mData.data());
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
......@@ -516,8 +518,10 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_bias;
static_cast<Acc0BiasDataType*>(d0_device_buf.GetDeviceBuffer()), // p_acc0_bias;
nullptr, // p_acc1_bias;
static_cast<Acc0BiasDataType*>(d0grad_device_buf.GetDeviceBuffer()), // p_d0grad_grid;
nullptr, // p_d1grad_grid;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
......@@ -529,8 +533,8 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths,
y_gs_ms_os_strides,
lse_gs_ms_lengths,
d_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
d_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
{}, // acc1_bias_gs_ms_os_lengths,
{}, // acc1_bias_gs_ms_os_strides,
QKVElementOp{},
......@@ -561,8 +565,10 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_bias;
static_cast<Acc0BiasDataType*>(d0_device_buf.GetDeviceBuffer()), // p_acc0_bias;
nullptr, // p_acc1_bias;
static_cast<Acc0BiasDataType*>(d0grad_device_buf.GetDeviceBuffer()), // p_d0grad_grid;
nullptr, // p_d1grad_grid;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
......@@ -574,8 +580,8 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths,
y_gs_ms_os_strides,
lse_gs_ms_lengths,
d_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
d_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
{}, // acc1_bias_gs_ms_os_lengths,
{}, // acc1_bias_gs_ms_os_strides,
QKVElementOp{},
......@@ -599,7 +605,7 @@ int run(int argc, char* argv[])
(sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N) *
sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N * size_t(2)) *
BatchCount +
sizeof(LSEDataType) * M * BatchCount;
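// (the size_t(2) factor on the Acc0BiasDataType term presumably accounts for the bias tensor
//  now being both read (d0) and written (d0grad) once per batch.)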
......@@ -618,7 +624,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N});
Tensor<Acc0BiasDataType> d0_g_m_n({G0 * G1, M, N});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
......@@ -640,13 +646,13 @@ int run(int argc, char* argv[])
v_gs_os_ns.ForEach([&](auto& self, auto idx) {
v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
d_gs_ms_ns.ForEach([&](auto& self, auto idx) {
d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
// run fwd again for y, cause z_g_m_n update
run_attention_fwd_host(q_g_m_k,
k_g_n_k,
d_g_m_n,
d0_g_m_n,
v_g_n_o,
alpha,
s_g_m_n,
......@@ -783,14 +789,19 @@ int run(int argc, char* argv[])
Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_host_result(d0_gs_ms_ns_lengths,
d0_gs_ms_ns_strides);
Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides);
Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_device_result(d0_gs_ms_ns_lengths,
d0_gs_ms_ns_strides);
qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data());
kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data());
vgrad_device_buf.FromDevice(vgrad_gs_os_ns_device_result.mData.data());
d0grad_device_buf.FromDevice(d0grad_gs_ms_ns_device_result.mData.data());
// permute
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
......@@ -818,6 +829,15 @@ int run(int argc, char* argv[])
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
});
d0grad_gs_ms_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = sgrad_g_m_n(g, idx[2], idx[3]);
});
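// (the reference for d0grad is simply sgrad_g_m_n: assuming the bias is added directly to the
//  scaled scores, S = alpha * Q * K^T + D0, the Jacobian of S with respect to D0 is the
//  identity, so dD0 = dS element-wise; no reduction is needed because d0 has the full
//  [G0, G1, M, N] shape.)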
std::cout << "Checking qgrad:\n";
pass &= ck::utils::check_err(qgrad_gs_ms_ks_device_result.mData,
qgrad_gs_ms_ks_host_result.mData,
......@@ -836,6 +856,12 @@ int run(int argc, char* argv[])
"error",
1e-2,
1e-2);
std::cout << "Checking d0grad:\n";
pass &= ck::utils::check_err(d0grad_gs_ms_ns_device_result.mData,
d0grad_gs_ms_ns_host_result.mData,
"error",
1e-2,
1e-2);
}
return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
......
......@@ -333,6 +333,7 @@ int run(int argc, char* argv[])
std::vector<const void*> p_lse;
std::vector<void*> p_qgrad;
std::vector<void*> p_kgrad;
std::vector<void*> p_d0grad;
std::vector<void*> p_vgrad;
std::vector<const void*> p_ygrad;
......@@ -356,6 +357,7 @@ int run(int argc, char* argv[])
std::vector<Tensor<LSEDataType>> lse_tensors;
std::vector<Tensor<OutputDataType>> qgrad_tensors;
std::vector<Tensor<OutputDataType>> kgrad_tensors;
std::vector<Tensor<Acc0BiasDataType>> d0grad_tensors;
std::vector<Tensor<OutputDataType>> vgrad_tensors;
std::vector<Tensor<InputDataType>> ygrad_tensors;
......@@ -369,6 +371,7 @@ int run(int argc, char* argv[])
std::vector<DeviceMemPtr> qgrad_tensors_device;
std::vector<DeviceMemPtr> ygrad_tensors_device;
std::vector<DeviceMemPtr> kgrad_tensors_device;
std::vector<DeviceMemPtr> d0grad_tensors_device;
std::vector<DeviceMemPtr> vgrad_tensors_device;
std::size_t group_count = 10;
std::size_t flop = 0, num_byte = 0;
......@@ -445,10 +448,11 @@ int run(int argc, char* argv[])
int BatchCount = G0 * G1;
flop += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
// Q/K/V/Y, dQ/dK/dV/dY, LSE
num_byte += (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
num_byte +=
(sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N) *
sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N * size_t(2)) *
BatchCount +
sizeof(LSEDataType) * M * BatchCount;
......@@ -600,6 +604,8 @@ int run(int argc, char* argv[])
std::make_unique<DeviceMem>(sizeof(OutputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
kgrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(OutputDataType) * k_gs_ns_ks.GetElementSpaceSize()));
d0grad_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(Acc0BiasDataType) * d0_gs_ms_ns.GetElementSpaceSize()));
vgrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(OutputDataType) * v_gs_os_ns.GetElementSpaceSize()));
ygrad_tensors_device.emplace_back(
......@@ -619,6 +625,7 @@ int run(int argc, char* argv[])
p_y.push_back(y_tensors_device.back()->GetDeviceBuffer());
p_lse.push_back(lse_tensors_device.back()->GetDeviceBuffer());
p_kgrad.push_back(kgrad_tensors_device.back()->GetDeviceBuffer());
p_d0grad.push_back(d0grad_tensors_device.back()->GetDeviceBuffer());
p_vgrad.push_back(vgrad_tensors_device.back()->GetDeviceBuffer());
p_ygrad.push_back(ygrad_tensors_device.back()->GetDeviceBuffer());
p_qgrad.push_back(qgrad_tensors_device.back()->GetDeviceBuffer());
......@@ -636,6 +643,8 @@ int run(int argc, char* argv[])
p_vgrad,
p_d0,
{},
p_d0grad, // p_d0grads;
{}, // p_d1grads;
problem_descs,
QKVElementOp{},
QKVElementOp{},
......@@ -682,6 +691,8 @@ int run(int argc, char* argv[])
p_vgrad,
p_d0,
{},
p_d0grad, // p_d0grads;
{}, // p_d1grads;
problem_descs,
QKVElementOp{},
QKVElementOp{},
......@@ -732,6 +743,7 @@ int run(int argc, char* argv[])
lse_tensors_device[i]->ToDevice(lse_tensors[i].data());
qgrad_tensors_device[i]->SetZero();
kgrad_tensors_device[i]->SetZero();
d0grad_tensors_device[i]->SetZero();
vgrad_tensors_device[i]->SetZero();
}
......@@ -804,6 +816,8 @@ int run(int argc, char* argv[])
q_tensors[i].GetStrides());
Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(),
k_tensors[i].GetStrides());
Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_host_result(d0_tensors[i].GetLengths(),
d0_tensors[i].GetStrides());
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(),
v_tensors[i].GetStrides());
......@@ -811,11 +825,14 @@ int run(int argc, char* argv[])
q_tensors[i].GetStrides());
Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(),
k_tensors[i].GetStrides());
Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_device_result(d0_tensors[i].GetLengths(),
d0_tensors[i].GetStrides());
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(),
v_tensors[i].GetStrides());
qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data());
kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data());
d0grad_tensors_device[i]->FromDevice(d0grad_gs_ms_ns_device_result.data());
vgrad_tensors_device[i]->FromDevice(vgrad_gs_os_ns_device_result.data());
// permute
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
......@@ -834,6 +851,14 @@ int run(int argc, char* argv[])
self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
});
d0grad_gs_ms_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = sgrad_g_m_n(g, idx[2], idx[3]);
});
vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
......@@ -861,6 +886,12 @@ int run(int argc, char* argv[])
"error",
1e-2,
1e-2);
std::cout << "Checking d0grad:\n";
pass &= ck::utils::check_err(d0grad_gs_ms_ns_device_result.mData,
d0grad_gs_ms_ns_host_result.mData,
"error",
1e-2,
1e-2);
}
}
......
......@@ -71,6 +71,7 @@ __global__ void
ignore = p_z_grid;
ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
ignore = block_2_ctile_map;
ignore = num_gemm0_m_block_outer_loop;
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
ignore = seed;
......@@ -135,7 +136,7 @@ struct DeviceBatchedDropout : public ck::tensor_operation::device::BaseOperator
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_m_n_lengths,
const std::vector<index_t>& z_gs_m_n_strides)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_m_n_lengths, z_gs_m_n_strides);
return Transform::MakeC0GridDescriptor_M_N(z_gs_m_n_lengths, z_gs_m_n_strides);
}
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
......
......@@ -123,6 +123,7 @@ __global__ void
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
D0DataType* __restrict__ p_d0grad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
......@@ -176,12 +177,20 @@ __global__ void
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
const D0DataType* tmp_p_d0_grid = nullptr;
D0DataType* tmp_p_d0grad_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
if(p_d0_grid != nullptr)
{
tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
}
if(p_d0grad_grid != nullptr)
{
tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
}
}
if constexpr(Deterministic)
{
for(index_t i = 0; i < nblock; i++)
......@@ -197,6 +206,7 @@ __global__ void
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
......@@ -233,6 +243,7 @@ __global__ void
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
......@@ -266,6 +277,7 @@ __global__ void
ignore = p_ygrad_grid;
ignore = p_qgrad_grid;
ignore = p_kgrad_grid;
ignore = p_d0grad_grid;
ignore = p_vgrad_grid;
ignore = a_element_op;
ignore = b_element_op;
......@@ -579,32 +591,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
//
//
// dQ = alpha * dS * K
//
// QGrad in Gemm C position
static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths,
const std::vector<index_t>& q_gs_ms_ks_strides)
{
return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
}
//
// dK = alpha * dS^T * Q
//
// KGrad in Gemm C position
static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths,
const std::vector<index_t>& k_gs_ns_ks_strides)
{
return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
static auto MakeLSEGridDescriptor_M(index_t MRaw)
......@@ -635,7 +622,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
}
static auto MakeDGridDescriptor_M(index_t MRaw)
......@@ -665,7 +652,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
......@@ -673,7 +660,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using DYGridDesc_M_O = decltype(DTransform::MakeCGridDescriptor_M_N({}, {}));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
......@@ -858,6 +845,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
D0DataType* p_d0grad_grid,
D1DataType* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -894,6 +883,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_qgrad_grid_{p_qgrad_grid},
p_kgrad_grid_{p_kgrad_grid},
p_vgrad_grid_{p_vgrad_grid},
p_d0grad_grid_{p_d0grad_grid},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{
......@@ -921,7 +911,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
d_block_2_ctile_map_{
GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(d_y_grid_desc_m_o_)},
......@@ -948,10 +938,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_drop_{p_drop}
{
// TODO: implement bias addition
ignore = p_acc0_bias;
ignore = p_d1grad_grid;
ignore = p_acc1_bias;
ignore = acc0_bias_gs_ms_ns_lengths;
ignore = acc0_bias_gs_ms_ns_strides;
ignore = acc1_bias_gs_ms_gemm1ns_lengths;
ignore = acc1_bias_gs_ms_gemm1ns_strides;
......@@ -962,7 +950,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
d0_grid_desc_g_m_n_ = Transform::MakeC0GridDescriptor_G_M_N(
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]);
......@@ -1030,6 +1018,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_;
OutputDataType* p_vgrad_grid_;
D0DataType* p_d0grad_grid_;
// tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
......@@ -1191,6 +1180,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
arg.p_ygrad_grid_,
arg.p_qgrad_grid_,
arg.p_kgrad_grid_,
arg.p_d0grad_grid_,
arg.p_vgrad_grid_,
arg.a_element_op_,
arg.b_element_op_,
......@@ -1342,6 +1332,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
D0DataType* p_d0grad_grid,
D1DataType* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -1380,6 +1372,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_vgrad_grid,
p_acc0_bias,
p_acc1_bias,
p_d0grad_grid,
p_d1grad_grid,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
......@@ -1422,6 +1416,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
void* p_vgrad_grid,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
void* p_d0grad_grid,
void* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -1461,6 +1457,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static_cast<OutputDataType*>(p_vgrad_grid),
static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
static_cast<D0DataType*>(p_d0grad_grid),
static_cast<D1DataType*>(p_d1grad_grid),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
......
......@@ -123,6 +123,7 @@ __global__ void
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
D0DataType* __restrict__ p_d0grad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
......@@ -176,12 +177,20 @@ __global__ void
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
const D0DataType* tmp_p_d0_grid = nullptr;
D0DataType* tmp_p_d0grad_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
if(p_d0_grid != nullptr)
{
tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
}
if(p_d0grad_grid != nullptr)
{
tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
}
}
if constexpr(Deterministic)
{
......@@ -198,6 +207,7 @@ __global__ void
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
......@@ -234,6 +244,7 @@ __global__ void
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
......@@ -267,6 +278,7 @@ __global__ void
ignore = p_ygrad_grid;
ignore = p_qgrad_grid;
ignore = p_kgrad_grid;
ignore = p_d0grad_grid;
ignore = p_vgrad_grid;
ignore = a_element_op;
ignore = b_element_op;
......@@ -587,39 +599,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
}
// Z in Gemm0 C position
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
//
//
// dQ = alpha * dS * K
//
// QGrad in Gemm C position
static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths,
const std::vector<index_t>& q_gs_ms_ks_strides)
{
return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
}
//
// dK = alpha * dS^T * Q
//
// KGrad in Gemm C position
static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths,
const std::vector<index_t>& k_gs_ns_ks_strides)
{
return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
static auto MakeLSEGridDescriptor_M(index_t MRaw)
......@@ -674,7 +661,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
......@@ -682,7 +669,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using DYGridDesc_M_O = decltype(DTransform::MakeCGridDescriptor_M_N({}, {}));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
......@@ -874,6 +861,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
D0DataType* p_d0grad_grid,
D1DataType* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -910,6 +899,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_qgrad_grid_{p_qgrad_grid},
p_kgrad_grid_{p_kgrad_grid},
p_vgrad_grid_{p_vgrad_grid},
p_d0grad_grid_{p_d0grad_grid},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{
......@@ -936,7 +926,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
d_block_2_ctile_map_{
GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(d_y_grid_desc_m_o_)},
......@@ -964,6 +954,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{
// TODO: implement bias addition
ignore = p_acc1_bias;
ignore = p_d1grad_grid;
ignore = acc1_bias_gs_ms_gemm1ns_lengths;
ignore = acc1_bias_gs_ms_gemm1ns_strides;
......@@ -974,7 +965,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
d0_grid_desc_g_m_n_ = Transform::MakeC0GridDescriptor_G_M_N(
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]);
......@@ -1042,6 +1033,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_;
OutputDataType* p_vgrad_grid_;
D0DataType* p_d0grad_grid_;
// tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
......@@ -1207,6 +1199,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
arg.p_ygrad_grid_,
arg.p_qgrad_grid_,
arg.p_kgrad_grid_,
arg.p_d0grad_grid_,
arg.p_vgrad_grid_,
arg.a_element_op_,
arg.b_element_op_,
......@@ -1374,6 +1367,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
D0DataType* p_d0grad_grid,
D1DataType* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -1412,6 +1407,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_vgrad_grid,
p_acc0_bias,
p_acc1_bias,
p_d0grad_grid,
p_d1grad_grid,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
......@@ -1454,6 +1451,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
void* p_vgrad_grid,
const void* p_acc0_bias,
const void* p_acc1_bias,
void* p_d0grad_grid,
void* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -1493,6 +1492,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static_cast<OutputDataType*>(p_vgrad_grid),
static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
static_cast<D0DataType*>(p_d0grad_grid),
static_cast<D1DataType*>(p_d1grad_grid),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
......
......@@ -65,6 +65,7 @@ __global__ void
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
D0DataType* __restrict__ p_d0grad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
......@@ -120,12 +121,20 @@ __global__ void
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
const D0DataType* tmp_p_d0_grid = nullptr;
D0DataType* tmp_p_d0grad_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
if(p_d0_grid != nullptr)
{
tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
}
if(p_d0grad_grid != nullptr)
{
tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
}
}
if constexpr(Deterministic)
{
for(index_t i = 0; i < nblock; i++)
......@@ -141,6 +150,7 @@ __global__ void
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
......@@ -178,6 +188,7 @@ __global__ void
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
......@@ -212,6 +223,7 @@ __global__ void
ignore = p_ygrad_grid;
ignore = p_qgrad_grid;
ignore = p_kgrad_grid;
ignore = p_d0grad_grid;
ignore = p_vgrad_grid;
ignore = a_element_op;
ignore = b_element_op;
......@@ -514,32 +526,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
//
//
// dQ = alpha * dS * K
//
// QGrad in Gemm C position
static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths,
const std::vector<index_t>& q_gs_ms_ks_strides)
{
return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
}
//
// dK = alpha * dS^T * Q
//
// KGrad in Gemm C position
static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths,
const std::vector<index_t>& k_gs_ns_ks_strides)
{
return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
static auto MakeLSEGridDescriptor_M(index_t MRaw)
......@@ -570,12 +557,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
......@@ -583,7 +570,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
......@@ -755,6 +742,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
D0DataType* p_d0grad_grid,
D1DataType* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -790,6 +779,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_qgrad_grid_{p_qgrad_grid},
p_kgrad_grid_{p_kgrad_grid},
p_vgrad_grid_{p_vgrad_grid},
p_d0grad_grid_{p_d0grad_grid},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{
......@@ -814,7 +804,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
y_grid_desc_mblock_mperblock_oblock_operblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
a_element_op_{a_element_op},
......@@ -839,10 +829,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_drop_{p_drop}
{
// TODO: implement bias addition
ignore = p_acc0_bias;
ignore = p_d1grad_grid;
ignore = p_acc1_bias;
ignore = acc0_bias_gs_ms_ns_lengths;
ignore = acc0_bias_gs_ms_ns_strides;
ignore = acc1_bias_gs_ms_gemm1ns_lengths;
ignore = acc1_bias_gs_ms_gemm1ns_strides;
......@@ -862,7 +850,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
d0_grid_desc_g_m_n_ = Transform::MakeC0GridDescriptor_G_M_N(
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]);
......@@ -926,6 +914,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_;
OutputDataType* p_vgrad_grid_;
D0DataType* p_d0grad_grid_;
// tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
......@@ -1049,6 +1038,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
arg.p_ygrad_grid_,
arg.p_qgrad_grid_,
arg.p_kgrad_grid_,
arg.p_d0grad_grid_,
arg.p_vgrad_grid_,
arg.a_element_op_,
arg.b_element_op_,
......@@ -1200,6 +1190,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
D0DataType* p_d0grad_grid,
D1DataType* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -1237,6 +1229,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_vgrad_grid,
p_acc0_bias,
p_acc1_bias,
p_d0grad_grid,
p_d1grad_grid,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
......@@ -1278,6 +1272,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
void* p_vgrad_grid,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
D0DataType* p_d0grad_grid,
D1DataType* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -1316,6 +1312,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static_cast<OutputDataType*>(p_vgrad_grid),
static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
static_cast<D0DataType*>(p_d0grad_grid),
static_cast<D1DataType*>(p_d1grad_grid),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
......
......@@ -65,6 +65,7 @@ __global__ void
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
D0DataType* __restrict__ p_d0grad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
......@@ -120,13 +121,21 @@ __global__ void
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
const D0DataType* tmp_p_d0_grid = nullptr;
D0DataType* tmp_p_d0grad_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
if(p_d0_grid != nullptr)
{
tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
}
if(p_d0grad_grid != nullptr)
{
tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
}
}
if constexpr(Deterministic)
{
for(index_t i = 0; i < nblock; i++)
......@@ -142,6 +151,7 @@ __global__ void
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
......@@ -179,6 +189,7 @@ __global__ void
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
......@@ -213,6 +224,7 @@ __global__ void
ignore = p_ygrad_grid;
ignore = p_qgrad_grid;
ignore = p_kgrad_grid;
ignore = p_d0grad_grid;
ignore = p_vgrad_grid;
ignore = a_element_op;
ignore = b_element_op;
......@@ -522,39 +534,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
}
// Z in Gemm0 C position
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
//
//
// dQ = alpha * dS * K
//
// QGrad in Gemm C position
static auto MakeQGradGridDescriptor_M_K(const std::vector<index_t>& q_gs_ms_ks_lengths,
const std::vector<index_t>& q_gs_ms_ks_strides)
{
return Transform::MakeCGridDescriptor_M_N(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
}
//
// dK = alpha * dS^T * Q
//
// KGrad in Gemm C position
static auto MakeKGradGridDescriptor_N_K(const std::vector<index_t>& k_gs_ns_ks_lengths,
const std::vector<index_t>& k_gs_ns_ks_strides)
{
return Transform::MakeCGridDescriptor_M_N(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
static auto MakeLSEGridDescriptor_M(index_t MRaw)
......@@ -584,7 +571,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
......@@ -592,7 +579,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
......@@ -771,6 +758,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
D0DataType* p_d0grad_grid,
D1DataType* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -806,6 +795,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_qgrad_grid_{p_qgrad_grid},
p_kgrad_grid_{p_kgrad_grid},
p_vgrad_grid_{p_vgrad_grid},
p_d0grad_grid_{p_d0grad_grid},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{
......@@ -829,7 +819,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
y_grid_desc_mblock_mperblock_oblock_operblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
a_element_op_{a_element_op},
......@@ -855,6 +845,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
// TODO: implement bias addition
ignore = p_acc1_bias;
ignore = p_d1grad_grid;
ignore = acc1_bias_gs_ms_gemm1ns_lengths;
ignore = acc1_bias_gs_ms_gemm1ns_strides;
......@@ -875,7 +866,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
d0_grid_desc_g_m_n_ = Transform::MakeC0GridDescriptor_G_M_N(
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]);
......@@ -925,6 +916,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0)
<< ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", "
<< ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n';
std::cout << "d0_grid_desc_g_m_n_: " << d0_grid_desc_g_m_n_.GetLength(I0) << ", "
<< d0_grid_desc_g_m_n_.GetLength(I1) << ", "
<< d0_grid_desc_g_m_n_.GetLength(I2) << '\n';
}
// pointers
......@@ -939,6 +934,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_;
OutputDataType* p_vgrad_grid_;
D0DataType* p_d0grad_grid_;
// tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
......@@ -1066,6 +1062,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
arg.p_ygrad_grid_,
arg.p_qgrad_grid_,
arg.p_kgrad_grid_,
arg.p_d0grad_grid_,
arg.p_vgrad_grid_,
arg.a_element_op_,
arg.b_element_op_,
......@@ -1233,6 +1230,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
D0DataType* p_d0grad_grid,
D1DataType* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -1270,6 +1269,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_vgrad_grid,
p_acc0_bias,
p_acc1_bias,
p_d0grad_grid,
p_d1grad_grid,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
......@@ -1311,6 +1312,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
void* p_vgrad_grid,
const void* p_acc0_bias,
const void* p_acc1_bias,
void* p_d0grad_grid,
void* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -1349,6 +1352,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static_cast<OutputDataType*>(p_vgrad_grid),
static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
static_cast<D0DataType*>(p_d0grad_grid),
static_cast<D1DataType*>(p_d1grad_grid),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
......
......@@ -162,13 +162,16 @@ __global__ void
: arg_ptr[group_id].p_z_grid_ + z_batch_offset);
const D0DataType* tmp_p_d0_grid = nullptr;
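// per-batch pointer for the optional bias-gradient (dD0) output; stays nullptr
// when D0DataType is void or the caller did not pass a dBias buffer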
D0DataType* tmp_p_d0grad_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
if(arg_ptr[group_id].p_d0_grid_ != nullptr)
tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
if(arg_ptr[group_id].p_d0grad_grid_)
tmp_p_d0grad_grid = arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
}
if constexpr(Deterministic)
{
......@@ -185,6 +188,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared,
a_element_op,
......@@ -222,6 +226,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared,
a_element_op,
......@@ -540,7 +545,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
static auto MakeLSEGridDescriptor_M(index_t MRaw)
......@@ -572,7 +577,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
......@@ -581,7 +586,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
......@@ -806,6 +811,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const InputDataType* p_ygrad_grid_;
OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_;
D0DataType* p_d0grad_grid_;
OutputDataType* p_vgrad_grid_;
// tensor descriptors for block/thread-wise copy
......@@ -878,6 +884,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -911,7 +919,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
0 == p_acc1_bias_vec.size()))
0 == p_acc1_bias_vec.size() &&
(group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
ck::type_convert<ck::index_t>(p_d0grads.size() == 0)) &&
0 == p_d1grads.size()))
{
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
}
......@@ -937,6 +948,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]);
auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]);
auto p_d0grad_grid =
(ck::type_convert<ck::index_t>(p_d0grads.size()) == group_count_)
? static_cast<D0DataType*>(p_d0grads[i])
: nullptr;
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
const auto& problem_desc = problem_desc_vec[i];
......@@ -983,7 +998,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
const auto z_grid_desc_g_m_n = Transform::MakeC0GridDescriptor_G_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
......@@ -1054,6 +1069,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_ygrad_grid,
p_qgrad_grid,
p_kgrad_grid,
p_d0grad_grid,
p_vgrad_grid,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
......@@ -1370,6 +1386,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1392,6 +1410,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_Vgrads,
p_acc0_bias_vec,
p_acc1_bias_vec,
p_d0grads,
p_d1grads,
problem_desc_vec,
a_element_op,
b_element_op,
......@@ -1420,6 +1440,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1442,6 +1464,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_Vgrads,
p_acc0_bias_vec, // cast in struct Argument
p_acc1_bias_vec, // cast in struct Argument
p_d0grads,
p_d1grads,
problem_desc_vec,
a_element_op,
b_element_op,
......
......@@ -160,13 +160,17 @@ __global__ void
(arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset);
const D0DataType* tmp_p_d0_grid = nullptr;
D0DataType* tmp_p_d0grad_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
if(arg_ptr[group_id].p_d0_grid_ != nullptr)
tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
if(arg_ptr[group_id].p_d0grad_grid_)
tmp_p_d0grad_grid = arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
}
if constexpr(Deterministic)
......@@ -184,6 +188,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared,
a_element_op,
......@@ -221,6 +226,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared,
a_element_op,
......@@ -602,7 +608,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
static auto MakeLSEGridDescriptor_M(index_t MRaw)
......@@ -634,7 +640,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
......@@ -643,7 +649,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
......@@ -682,7 +688,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using D0GridDesc_G_M_N = decltype(MakeD0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
......@@ -876,6 +882,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const InputDataType* p_ygrad_grid_;
OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_;
D0DataType* p_d0grad_grid_;
OutputDataType* p_vgrad_grid_;
// tensor descriptors for block/thread-wise copy
......@@ -948,6 +955,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -981,7 +990,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
0 == p_acc1_bias_vec.size()))
0 == p_acc1_bias_vec.size() &&
(group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
ck::type_convert<ck::index_t>(p_d0grads.size() == 0)) &&
0 == p_d1grads.size()))
{
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
}
......@@ -1007,6 +1019,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]);
auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]);
auto p_d0grad_grid =
(ck::type_convert<ck::index_t>(p_d0grads.size()) == group_count_)
? static_cast<D0DataType*>(p_d0grads[i])
: nullptr;
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
const auto& problem_desc = problem_desc_vec[i];
......@@ -1053,7 +1069,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
const auto z_grid_desc_g_m_n = Transform::MakeC0GridDescriptor_G_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
......@@ -1124,6 +1140,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_ygrad_grid,
p_qgrad_grid,
p_kgrad_grid,
p_d0grad_grid,
p_vgrad_grid,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
......@@ -1445,6 +1462,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1467,6 +1486,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_Vgrads,
p_acc0_bias_vec,
p_acc1_bias_vec,
p_d0grads,
p_d1grads,
problem_desc_vec,
a_element_op,
b_element_op,
......@@ -1495,6 +1516,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1517,6 +1540,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_Vgrads,
p_acc0_bias_vec, // cast in struct Argument
p_acc1_bias_vec, // cast in struct Argument
p_d0grads,
p_d1grads,
problem_desc_vec,
a_element_op,
b_element_op,
......
......@@ -103,13 +103,17 @@ __global__ void
: arg_ptr[group_id].p_z_grid_ + z_batch_offset);
const D0DataType* tmp_p_d0_grid = nullptr;
D0DataType* tmp_p_d0grad_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
if(arg_ptr[group_id].p_d0_grid_ != nullptr)
tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
if(arg_ptr[group_id].p_d0grad_grid_)
tmp_p_d0grad_grid = arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
}
if constexpr(Deterministic)
{
......@@ -126,6 +130,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared,
a_element_op,
......@@ -164,6 +169,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared,
a_element_op,
......@@ -471,7 +477,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
static auto MakeLSEGridDescriptor_M(index_t MRaw)
......@@ -503,7 +509,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
......@@ -512,7 +518,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
......@@ -526,7 +532,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
using D0GridDesc_G_M_N = decltype(MakeD0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
......@@ -696,6 +702,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const InputDataType* p_ygrad_grid_;
OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_;
D0DataType* p_d0grad_grid_;
OutputDataType* p_vgrad_grid_;
// tensor descriptors for block/thread-wise copy
......@@ -760,6 +767,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -792,7 +801,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
0 == p_acc1_bias_vec.size()))
0 == p_acc1_bias_vec.size() &&
(group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
ck::type_convert<ck::index_t>(p_d0grads.size() == 0)) &&
0 == p_d1grads.size()))
{
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
}
......@@ -816,6 +828,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]);
auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]);
auto p_d0grad_grid =
(ck::type_convert<ck::index_t>(p_d0grads.size()) == group_count_)
? static_cast<D0DataType*>(p_d0grads[i])
: nullptr;
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
const auto& problem_desc = problem_desc_vec[i];
......@@ -862,7 +878,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
const auto z_grid_desc_g_m_n = Transform::MakeC0GridDescriptor_G_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
......@@ -925,6 +941,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_ygrad_grid,
p_qgrad_grid,
p_kgrad_grid,
p_d0grad_grid,
p_vgrad_grid,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
......@@ -1214,6 +1231,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1235,6 +1254,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_Vgrads,
p_acc0_bias_vec,
p_acc1_bias_vec,
p_d0grads,
p_d1grads,
problem_desc_vec,
a_element_op,
b_element_op,
......@@ -1262,6 +1283,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1283,6 +1306,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_Vgrads,
p_acc0_bias_vec, // cast in struct Argument
p_acc1_bias_vec, // cast in struct Argument
p_d0grads,
p_d1grads,
problem_desc_vec,
a_element_op,
b_element_op,
......
......@@ -102,13 +102,16 @@ __global__ void
(arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset);
const D0DataType* tmp_p_d0_grid = nullptr;
D0DataType* tmp_p_d0grad_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
if(arg_ptr[group_id].p_d0_grid_ != nullptr)
tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
if(arg_ptr[group_id].p_d0grad_grid_)
tmp_p_d0grad_grid = arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
}
if constexpr(Deterministic)
......@@ -126,6 +129,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared,
a_element_op,
......@@ -164,6 +168,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared,
a_element_op,
......@@ -534,7 +539,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
static auto MakeLSEGridDescriptor_M(index_t MRaw)
......@@ -566,7 +571,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
......@@ -575,7 +580,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
......@@ -589,7 +594,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using D0GridDesc_G_M_N = decltype(MakeD0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
......@@ -767,6 +772,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const InputDataType* p_ygrad_grid_;
OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_;
D0DataType* p_d0grad_grid_;
OutputDataType* p_vgrad_grid_;
// tensor descriptors for block/thread-wise copy
......@@ -831,6 +837,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -863,7 +871,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
0 == p_acc1_bias_vec.size()))
0 == p_acc1_bias_vec.size() &&
(group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
ck::type_convert<ck::index_t>(p_d0grads.size() == 0)) &&
0 == p_d1grads.size()))
{
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
}
......@@ -887,6 +898,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]);
auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]);
auto p_d0grad_grid =
(ck::type_convert<ck::index_t>(p_d0grads.size()) == group_count_)
? static_cast<D0DataType*>(p_d0grads[i])
: nullptr;
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
const auto& problem_desc = problem_desc_vec[i];
......@@ -933,7 +948,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
const auto z_grid_desc_g_m_n = Transform::MakeC0GridDescriptor_G_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
......@@ -996,6 +1011,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_ygrad_grid,
p_qgrad_grid,
p_kgrad_grid,
p_d0grad_grid,
p_vgrad_grid,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
......@@ -1290,6 +1306,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1311,6 +1329,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_Vgrads,
p_acc0_bias_vec,
p_acc1_bias_vec,
p_d0grads,
p_d1grads,
problem_desc_vec,
a_element_op,
b_element_op,
......@@ -1338,6 +1358,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1359,6 +1381,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_Vgrads,
p_acc0_bias_vec, // cast in struct Argument
p_acc1_bias_vec, // cast in struct Argument
p_d0grads,
p_d1grads,
problem_desc_vec,
a_element_op,
b_element_op,
......
......@@ -119,6 +119,15 @@ struct GemmGemmPadder
c_desc_mraw_nraw, make_tuple(MPerTile_, OPerTile_), Sequence<PadM, PadO>{});
}
// C[M, Gemm1N] = C[M, N]
template <typename C0Desc_MRaw_NRaw>
__host__ __device__ constexpr auto
PadC0Descriptor_M_N(const C0Desc_MRaw_NRaw& c_desc_mraw_nraw) const
{
return PadTensorDescriptor(
c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
}
MPerTileType MPerTile_;
NPerTileType NPerTile_;
KPerTileType KPerTile_;
......
......@@ -88,6 +88,10 @@ template <typename InputDataType,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
{
static_assert(KPerBlock == Gemm1NPerBlock);
static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0);
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
......@@ -1213,7 +1217,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
using D0GridDescriptor_M0_N0_M1_M2_N1_M3 =
remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>;
struct D0Loader
struct D0Operator
{
template <typename DataType>
struct TypeTransform
......@@ -1229,17 +1233,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static constexpr index_t Size0 = 0;
static constexpr index_t Size = sizeof(ck::half_t);
};
static constexpr index_t NThreadClusterLengths = MPerXdl;
static_assert(MPerXdl <= KPerBlock);
static constexpr index_t NThreadClusterLengths = 32;
static_assert(NPerXdl == 32);
static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock,
"D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock");
__host__ __device__ static constexpr auto GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3()
__host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3()
{
// D0 tile in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, D0M1, Number<NPerBlock>{}, D0M2));
}
__host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2()
__host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2()
{
constexpr auto d0_raw_m0_n_m1 =
make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<NPerBlock>{}, D0M2));
......@@ -1254,15 +1257,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
return d0_n0_n1_m0_m1_m2;
}
static constexpr auto d0_block_write_desc_m0_n0_m1_m2_n1_m3 =
GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3();
static constexpr auto d0_block_read_desc_n0_n1_m0_m1_m2 =
GetD0BlockReadDescriptor_N0_N1_M0_M1_M2();
static constexpr auto d0_block_dst_desc_m0_n0_m1_m2_n1_m3 =
GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3();
static constexpr auto d0_block_src_desc_n0_n1_m0_m1_m2 =
GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2();
static constexpr auto d0_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));
using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
static constexpr auto& d0grad_block_dst_desc_n0_n1_m0_m1_m2 =
d0_block_src_desc_n0_n1_m0_m1_m2;
static constexpr auto& d0grad_block_src_desc_m0_n0_m1_m2_n1_m3 =
d0_block_dst_desc_m0_n0_m1_m2_n1_m3;
using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
......@@ -1278,12 +1286,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc
decltype(d0_block_write_desc_m0_n0_m1_m2_n1_m3), // DstDesc
decltype(d0_block_dst_desc_m0_n0_m1_m2_n1_m3), // DstDesc
Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder
Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder
4, // SrcVectorDim
5, // DstVectorDim
D0BlockTransferSrcScalarPerVector, // SrcScalarPerVector
4, // SrcScalarPerVector
4, // DstScalarPerVector
1,
1,
......@@ -1291,16 +1299,59 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
true, // DstResetCoord
1>;
using D0ThreadWiseCopy =
using D0ThreadwiseCopyLdsToVgpr =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
decltype(d0_block_src_desc_n0_n1_m0_m1_m2), // SrcDesc
decltype(d0_thread_desc_), // DstDesc
Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
4, // SrcVectorDim
2, // SrcScalarPerVector
4, // SrcScalarPerVector
2>;
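// dD0 store path (reverse of the D0 load path above): thread-wise copy from the
// accumulator VGPRs into LDS, then block-wise copy from LDS out to global memory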
using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
typename TypeTransform<D0DataType>::Type,
decltype(d0_thread_desc_),
decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
tensor_operation::element_wise::Scale, // CElementwiseOperation
Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // AccessOrder
4, // VectorDim
4, // ScalarPerVector
InMemoryDataOperationEnum::Set, // GlobalMemoryDataOperation
1, // DstScalarStrideInVector
true>;
using D0GradBlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
Sequence<1,
1,
1,
BlockSize / NThreadClusterLengths,
NThreadClusterLengths,
1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0grad_block_src_desc_m0_n0_m1_m2_n1_m3), // SrcDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // DstDesc
Sequence<0, 1, 2, 4, 3, 5>, // SrcDimAccessOrder
Sequence<0, 1, 2, 3, 5, 4>, // DstDimAccessOrder
5, // SrcVectorDim
4, // DstVectorDim
4, // SrcScalarPerVector
D0BlockTransferSrcScalarPerVector, // DstScalarPerVector
1,
1,
true,
true, // DstResetCoord
1>;
};
struct SharedMemTrait
......@@ -1335,11 +1386,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
q_block_space_size_aligned.value;
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
static constexpr auto d0_block_space_offset =
(k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value) *
sizeof(GemmDataType) / D0Loader::template TypeTransform<D0DataType>::Size;
sizeof(GemmDataType) / D0Operator::template TypeTransform<D0DataType>::Size;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
......@@ -1356,7 +1407,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
sizeof(GemmDataType);
const index_t d0_bytes_end =
(SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
D0Loader::template TypeTransform<D0DataType>::Size0;
D0Operator::template TypeTransform<D0DataType>::Size0;
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
......@@ -1379,6 +1430,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
D0DataType* __restrict__ p_d0grad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
......@@ -1846,17 +1898,31 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// gemm0 M loop
index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
// D0
auto d0_block_copy_global_to_lds = typename D0Loader::D0BlockwiseCopy(
auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{},
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3,
D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
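// dD0 has the same lengths/strides as D0, so its grid descriptor is reused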
auto& d0grad_grid_desc_m0_n0_m1_m2_n1_m3 = d0_grid_desc_m0_n0_m1_m2_n1_m3;
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
tensor_operation::element_wise::Scale{rp_dropout});
auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{},
d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
if constexpr(Deterministic)
{
block_sync_lds();
......@@ -1991,6 +2057,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// add bias
if constexpr(!is_same<D0DataType, void>::value)
{
if(p_d0_grid != nullptr)
{
static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();
......@@ -1999,10 +2067,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Loader::d0_thread_desc_.GetElementSpaceSize());
D0Operator::d0_thread_desc_.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) {
// load data to lds
......@@ -2013,13 +2081,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
d0_block_copy_global_to_lds.RunWrite(
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
block_sync_lds();
// read data from lds
d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_read_desc_n0_n1_m0_m1_m2,
d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_src_desc_n0_n1_m0_m1_m2,
make_tuple(I0, I0, I0, I0, I0),
d0_block_buf,
D0Loader::d0_thread_desc_,
D0Operator::d0_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
d0_thread_buf);
......@@ -2034,7 +2102,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
});
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
}
}
// P_i: = softmax(scalar * S_i:)
......@@ -2125,6 +2195,45 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
: y_dot_ygrad_thread_buf[Number<m>{}]);
});
// output bias grad
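// write dD0: copy sgrad from VGPRs into LDS (rescaled by rp_dropout via the
// thread-wise copy's elementwise op), then block-copy the LDS tiles to global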
if constexpr(!is_same<D0DataType, void>::value)
{
if(p_d0grad_grid != nullptr)
{
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) {
d0grad_thread_copy_vgpr_to_lds.Run(
D0Operator::d0_thread_desc_,
make_tuple(mr, I0, I0, I0, I0),
sgrad_thread_buf,
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
d0grad_block_buf);
block_sync_lds();
// write data from lds to global
d0grad_block_copy_lds_to_global.Run(
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
d0grad_block_buf,
d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
d0grad_grid_buf,
I0);
d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
d0grad_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
});
d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
}
}
// gemm dV
// dV = P_drop^T * dY
{
......
......@@ -96,6 +96,10 @@ template <typename InputDataType,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{
static_assert(Gemm1NPerBlock % KPerBlock == 0);
static_assert(MPerBlock % Gemm1KPerBlock == 0);
static_assert(NPerBlock % Gemm2KPerBlock == 0);
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
......@@ -1292,7 +1296,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using D0GridDescriptor_M0_N0_M1_M2_N1_M3 =
remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>;
struct D0Loader
struct D0Operator
{
template <typename DataType>
struct TypeTransform
......@@ -1312,13 +1316,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static_assert(NPerXdl == 32);
static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock,
"D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock");
__host__ __device__ static constexpr auto GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3()
__host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3()
{
// D0 tile in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, D0M1, Number<NPerBlock>{}, D0M2));
}
__host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2()
__host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2()
{
constexpr auto d0_raw_m0_n_m1 =
make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<NPerBlock>{}, D0M2));
......@@ -1333,15 +1336,20 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
return d0_n0_n1_m0_m1_m2;
}
static constexpr auto d0_block_write_desc_m0_n0_m1_m2_n1_m3 =
GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3();
static constexpr auto d0_block_read_desc_n0_n1_m0_m1_m2 =
GetD0BlockReadDescriptor_N0_N1_M0_M1_M2();
static constexpr auto d0_block_dst_desc_m0_n0_m1_m2_n1_m3 =
GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3();
static constexpr auto d0_block_src_desc_n0_n1_m0_m1_m2 =
GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2();
static constexpr auto d0_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));
using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
static constexpr auto& d0grad_block_dst_desc_n0_n1_m0_m1_m2 =
d0_block_src_desc_n0_n1_m0_m1_m2;
static constexpr auto& d0grad_block_src_desc_m0_n0_m1_m2_n1_m3 =
d0_block_dst_desc_m0_n0_m1_m2_n1_m3;
using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
......@@ -1357,12 +1365,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc
decltype(d0_block_write_desc_m0_n0_m1_m2_n1_m3), // DstDesc
decltype(d0_block_dst_desc_m0_n0_m1_m2_n1_m3), // DstDesc
Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder
Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder
4, // SrcVectorDim
5, // DstVectorDim
D0BlockTransferSrcScalarPerVector, // SrcScalarPerVector
4, // SrcScalarPerVector
4, // DstScalarPerVector
1,
1,
......@@ -1370,16 +1378,59 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
true, // DstResetCoord
1>;
using D0ThreadWiseCopy =
using D0ThreadwiseCopyLdsToVgpr =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
decltype(d0_block_src_desc_n0_n1_m0_m1_m2), // SrcDesc
decltype(d0_thread_desc_), // DstDesc
Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
4, // SrcVectorDim
2, // SrcScalarPerVector
4, // SrcScalarPerVector
2>;
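// dD0 store path (reverse of the D0 load path above): thread-wise copy from the
// accumulator VGPRs into LDS, then block-wise copy from LDS out to global memory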
using D0GradThreadwiseCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
typename TypeTransform<D0DataType>::Type,
decltype(d0_thread_desc_),
decltype(d0grad_block_dst_desc_n0_n1_m0_m1_m2),
tensor_operation::element_wise::Scale, // CElementwiseOperation
Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // AccessOrder
4, // VectorDim
4, // ScalarPerVector
InMemoryDataOperationEnum::Set, // GlobalMemoryDataOperation
1, // DstScalarStrideInVector
true>;
using D0GradBlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
Sequence<1,
1,
1,
BlockSize / NThreadClusterLengths,
NThreadClusterLengths,
1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0grad_block_src_desc_m0_n0_m1_m2_n1_m3), // SrcDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // DstDesc
Sequence<0, 1, 2, 4, 3, 5>, // SrcDimAccessOrder
Sequence<0, 1, 2, 3, 5, 4>, // DstDimAccessOrder
5, // SrcVectorDim
4, // DstVectorDim
4, // SrcScalarPerVector
D0BlockTransferSrcScalarPerVector, // DstScalarPerVector
1,
1,
true,
true, // DstResetCoord
1>;
};
struct SharedMemTrait
......@@ -1414,10 +1465,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
static constexpr auto d0_block_space_offset =
k_block_space_size_aligned.value * sizeof(GemmDataType) /
D0Loader::template TypeTransform<D0DataType>::Size;
D0Operator::template TypeTransform<D0DataType>::Size;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
......@@ -1442,7 +1493,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
sizeof(GemmDataType);
const index_t d0_bytes_end =
(SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
D0Loader::template TypeTransform<D0DataType>::Size0;
D0Operator::template TypeTransform<D0DataType>::Size0;
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
......@@ -1470,6 +1521,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
D0DataType* __restrict__ p_d0grad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
......@@ -1967,17 +2019,31 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
// D0
auto d0_block_copy_global_to_lds = typename D0Loader::D0BlockwiseCopy(
auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{},
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3,
D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
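// dD0 has the same lengths/strides as D0, so its grid descriptor is reused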
auto& d0grad_grid_desc_m0_n0_m1_m2_n1_m3 = d0_grid_desc_m0_n0_m1_m2_n1_m3;
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0GradThreadwiseCopyVgprToLds(
D0Operator::d0grad_block_dst_desc_n0_n1_m0_m1_m2,
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
tensor_operation::element_wise::Scale{rp_dropout});
auto d0grad_block_copy_lds_to_global = typename D0Operator::D0GradBlockwiseCopyLdsToGlobal(
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{},
d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
if constexpr(Deterministic)
{
block_sync_lds();
......@@ -2142,6 +2208,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// add bias
if constexpr(!is_same<D0DataType, void>::value)
{
if(p_d0_grid != nullptr)
{
static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();
......@@ -2150,11 +2218,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Loader::d0_thread_desc_.GetElementSpaceSize());
ignore = d0_thread_buf;
D0Operator::d0_thread_desc_.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) {
// load data to lds
......@@ -2165,13 +2232,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
d0_block_copy_global_to_lds.RunWrite(
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
D0Operator::d0_block_dst_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
block_sync_lds();
// read data from lds
d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_read_desc_n0_n1_m0_m1_m2,
d0_thread_copy_lds_to_vgpr.Run(D0Operator::d0_block_src_desc_n0_n1_m0_m1_m2,
make_tuple(I0, I0, I0, I0, I0),
d0_block_buf,
D0Loader::d0_thread_desc_,
D0Operator::d0_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
d0_thread_buf);
......@@ -2186,7 +2253,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
});
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
}
}
// P_i: = softmax(scalar * S_i:)
......@@ -2393,6 +2462,46 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
: y_dot_ygrad_thread_buf[Number<m>{}]);
});
// output bias grad
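// write dD0: copy sgrad from VGPRs into LDS (rescaled by rp_dropout via the
// thread-wise copy's elementwise op), then block-copy the LDS tiles to global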
if constexpr(!is_same<D0DataType, void>::value)
{
if(p_d0grad_grid != nullptr)
{
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0grad_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) {
d0grad_thread_copy_vgpr_to_lds.Run(
D0Operator::d0_thread_desc_,
make_tuple(mr, I0, I0, I0, I0),
sgrad_thread_buf,
D0Operator::d0_block_src_desc_n0_n1_m0_m1_m2,
d0grad_block_buf);
block_sync_lds();
// write data from lds to global
d0grad_block_copy_lds_to_global.Run(
D0Operator::d0grad_block_src_desc_m0_n0_m1_m2_n1_m3,
d0grad_block_buf,
d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
d0grad_grid_buf,
I0);
d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
d0grad_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
});
d0grad_block_copy_lds_to_global.MoveDstSliceWindow(
d0grad_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
}
}
SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
s_blockwise_gemm.GetWaveIdx()[I1]);
......