Unverified Commit 4033f5df authored by Dan Yao's avatar Dan Yao Committed by GitHub
Browse files

Merge pull request #949 from ROCmSoftwarePlatform/mha-train-mqagqa

Grouped Query Attention/Multi Query Attention
parents 18be6bc9 957d5dee
...@@ -269,14 +269,15 @@ int run(int argc, char* argv[]) ...@@ -269,14 +269,15 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape // Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]) // y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
ck::index_t M = 512; ck::index_t M = 512;
ck::index_t N = 512; ck::index_t N = 512;
ck::index_t K = DIM; ck::index_t K = DIM;
ck::index_t O = DIM; ck::index_t O = DIM;
ck::index_t G0 = 4; ck::index_t G0 = 4;
ck::index_t G1 = 6; ck::index_t G1Q = 6; // h_q
ck::index_t G1KV = 6; // h_kv
bool input_permute = false; bool input_permute = false;
bool output_permute = false; bool output_permute = false;
...@@ -295,32 +296,33 @@ int run(int argc, char* argv[]) ...@@ -295,32 +296,33 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 13) else if(argc == 14)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
K = std::stoi(argv[6]); K = std::stoi(argv[6]);
O = std::stoi(argv[7]); O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]); G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]); G1Q = std::stoi(argv[9]);
G1KV = std::stoi(argv[10]);
p_drop = std::stof(argv[10]); p_drop = std::stof(argv[11]);
input_permute = std::stoi(argv[11]); input_permute = std::stoi(argv[12]);
output_permute = std::stoi(argv[12]); output_permute = std::stoi(argv[13]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n"); printf("arg4 to 10: M, N, K, O, G0, G1Q, G1KV\n");
printf("arg10: scale (alpha)\n"); printf("arg11: p_drop\n");
printf("arg11 to 12: input / output permute\n"); printf("arg12 to 13: input / output permute\n");
exit(0); exit(0);
} }
...@@ -337,7 +339,8 @@ int run(int argc, char* argv[]) ...@@ -337,7 +339,8 @@ int run(int argc, char* argv[])
std::cout << "K: " << K << std::endl; std::cout << "K: " << K << std::endl;
std::cout << "O: " << O << std::endl; std::cout << "O: " << O << std::endl;
std::cout << "G0: " << G0 << std::endl; std::cout << "G0: " << G0 << std::endl;
std::cout << "G1: " << G1 << std::endl; std::cout << "G1Q: " << G1Q << std::endl;
std::cout << "G1KV: " << G1KV << std::endl;
std::cout << "alpha: " << alpha << std::endl; std::cout << "alpha: " << alpha << std::endl;
std::cout << "input_permute: " << input_permute << std::endl; std::cout << "input_permute: " << input_permute << std::endl;
std::cout << "output_permute: " << output_permute << std::endl; std::cout << "output_permute: " << output_permute << std::endl;
...@@ -345,45 +348,57 @@ int run(int argc, char* argv[]) ...@@ -345,45 +348,57 @@ int run(int argc, char* argv[])
std::cout << "seed: " << seed << std::endl; std::cout << "seed: " << seed << std::endl;
std::cout << "offset: " << offset << std::endl; std::cout << "offset: " << offset << std::endl;
const ck::index_t BatchCount = G0 * G1; const ck::index_t BatchCount = G0 * G1Q;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides = std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K] ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // Q layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K] : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // Q layout [G0, G1Q, M, K]
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K}; std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides = std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K] ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // K layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K] : std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // K layout [G0, G1KV, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N}; std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides = std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O] ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // V layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O] : std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // V layout [G0, G1KV, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O}; std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides = std::vector<ck::index_t> y_gs_ms_os_strides =
output_permute output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O] ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // Y layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O] : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // Y layout [G0, G1Q, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N}; std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides = std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N] : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1Q, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1Q * K, K, G1Q * K, 1} // KGrad layout [G0, N, G1Q, K]
: std::vector<ck::index_t>{G1Q * N * K, N * K, K, 1}; // KGrad layout [G0, G1Q, N, K]
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1Q, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1Q * O, O, 1, G1Q * O} // VGrad layout [G0, N, G1Q, O]
: std::vector<ck::index_t>{G1Q * N * O, N * O, 1, O}; // VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...) // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...))) // = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...))) // = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^ // ^^^^^^^^^^^^^^^^^^^^^
// LSE // LSE
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M}; std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M] std::vector<ck::index_t> lse_gs_ms_strides{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
...@@ -392,6 +407,8 @@ int run(int argc, char* argv[]) ...@@ -392,6 +407,8 @@ int run(int argc, char* argv[])
Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks(kgrad_gs_ns_ks_lengths, kgrad_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns(vgrad_gs_os_ns_lengths, vgrad_gs_os_ns_strides);
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl; std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl; std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
...@@ -399,6 +416,8 @@ int run(int argc, char* argv[]) ...@@ -399,6 +416,8 @@ int run(int argc, char* argv[])
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl; std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl; std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl; std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
std::cout << "kgrad_gs_ns_ks: " << kgrad_gs_ns_ks.mDesc << std::endl;
std::cout << "vgrad_gs_os_ns: " << vgrad_gs_os_ns.mDesc << std::endl;
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0}); z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
switch(init_method) switch(init_method)
...@@ -432,14 +451,14 @@ int run(int argc, char* argv[]) ...@@ -432,14 +451,14 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1q, m, o]
// dO dot O = [0; 1; 2; ...] // dO dot O = [0; 1; 2; ...]
break; break;
case 6: case 6:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1q, m, o]
// assume mnko = 256 // assume mnko = 256
// P = softmax(QK) = 0.0039 * ones // P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones // O = P V = 0.0039 * ones
...@@ -452,7 +471,8 @@ int run(int argc, char* argv[]) ...@@ -452,7 +471,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1q, m, o]
// assume mnko = 256 // assume mnko = 256
// P = softmax(QK) = 0.0039 * ones // P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones // O = P V = 0.0039 * ones
...@@ -474,11 +494,21 @@ int run(int argc, char* argv[]) ...@@ -474,11 +494,21 @@ int run(int argc, char* argv[])
Tensor<LSEDataType> lse_g_m({BatchCount, M}); Tensor<LSEDataType> lse_g_m({BatchCount, M});
q_gs_ms_ks.ForEach( q_gs_ms_ks.ForEach(
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx); });
k_gs_ns_ks.ForEach( k_g_n_k.ForEach([&](auto& self, auto idx) {
[&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); const size_t& g0 = idx[0] / G1Q;
v_gs_os_ns.ForEach( const size_t& g1q = idx[0] % G1Q;
[&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); }); const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = k_gs_ns_ks(g0, g1kv, idx[1], idx[2]);
});
v_g_n_o.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = v_gs_os_ns(g0, g1kv, idx[2], idx[1]);
});
// qkv gradients have the same descriptor as with qkv // qkv gradients have the same descriptor as with qkv
DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
...@@ -488,8 +518,8 @@ int run(int argc, char* argv[]) ...@@ -488,8 +518,8 @@ int run(int argc, char* argv[])
DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize()); DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize()); DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
DeviceMem qgrad_device_buf(sizeof(OutputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem qgrad_device_buf(sizeof(OutputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem kgrad_device_buf(sizeof(OutputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize()); DeviceMem kgrad_device_buf(sizeof(OutputDataType) * kgrad_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem vgrad_device_buf(sizeof(OutputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize()); DeviceMem vgrad_device_buf(sizeof(OutputDataType) * vgrad_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize()); DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
q_device_buf.ToDevice(q_gs_ms_ks.mData.data()); q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
...@@ -513,8 +543,8 @@ int run(int argc, char* argv[]) ...@@ -513,8 +543,8 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
nullptr, // p_acc0_bias; nullptr, // p_acc0_bias;
nullptr, // p_acc1_bias; nullptr, // p_acc1_bias;
nullptr, nullptr,
nullptr, nullptr,
q_gs_ms_ks_lengths, q_gs_ms_ks_lengths,
...@@ -528,6 +558,10 @@ int run(int argc, char* argv[]) ...@@ -528,6 +558,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
kgrad_gs_ns_ks_lengths,
kgrad_gs_ns_ks_strides,
vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...@@ -560,8 +594,8 @@ int run(int argc, char* argv[]) ...@@ -560,8 +594,8 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()), static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
nullptr, // p_acc0_bias; nullptr, // p_acc0_bias;
nullptr, // p_acc1_bias; nullptr, // p_acc1_bias;
nullptr, nullptr,
nullptr, nullptr,
q_gs_ms_ks_lengths, q_gs_ms_ks_lengths,
...@@ -575,6 +609,10 @@ int run(int argc, char* argv[]) ...@@ -575,6 +609,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
kgrad_gs_ns_ks_lengths,
kgrad_gs_ns_ks_strides,
vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...@@ -614,7 +652,7 @@ int run(int argc, char* argv[]) ...@@ -614,7 +652,7 @@ int run(int argc, char* argv[])
// copy z matirx data form device // copy z matirx data form device
z_device_buf.FromDevice(z_gs_ms_ns.mData.data()); z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
z_gs_ms_ns.ForEach( z_gs_ms_ns.ForEach(
[&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); [&](auto& self, auto idx) { z_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx); });
// std::cout << "z_g_m_n ref:\n" << z_g_m_n; // std::cout << "z_g_m_n ref:\n" << z_g_m_n;
bool pass = true; bool pass = true;
...@@ -634,10 +672,10 @@ int run(int argc, char* argv[]) ...@@ -634,10 +672,10 @@ int run(int argc, char* argv[])
p_dropout_in_uint8_t, p_dropout_in_uint8_t,
rp_dropout); rp_dropout);
y_gs_ms_os.ForEach([&](auto& self, auto idx) { y_gs_ms_os.ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); self(idx) = y_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]);
}); });
lse_gs_ms.ForEach( lse_gs_ms.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); }); [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1Q + idx[1], idx[2]); });
y_device_buf.ToDevice(y_gs_ms_os.mData.data()); y_device_buf.ToDevice(y_gs_ms_os.mData.data());
lse_device_buf.ToDevice(lse_gs_ms.mData.data()); lse_device_buf.ToDevice(lse_gs_ms.mData.data());
...@@ -655,7 +693,7 @@ int run(int argc, char* argv[]) ...@@ -655,7 +693,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> ygrad_dot_y_g_m({BatchCount, M}); Tensor<InputDataType> ygrad_dot_y_g_m({BatchCount, M});
ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) { ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); ygrad_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
#if PRINT_HOST #if PRINT_HOST
...@@ -757,12 +795,16 @@ int run(int argc, char* argv[]) ...@@ -757,12 +795,16 @@ int run(int argc, char* argv[])
#endif #endif
Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(kgrad_gs_ns_ks_lengths,
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides); kgrad_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides);
Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(kgrad_gs_ns_ks_lengths,
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides); kgrad_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides);
qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data()); qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data());
kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data()); kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data());
...@@ -770,26 +812,26 @@ int run(int argc, char* argv[]) ...@@ -770,26 +812,26 @@ int run(int argc, char* argv[])
// permute // permute
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) { qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = qgrad_g_m_k(g, idx[2], idx[3]); self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
}); });
kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) { kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = kgrad_g_n_k(g, idx[2], idx[3]); self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
}); });
vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) { vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]); self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
}); });
......
...@@ -270,14 +270,15 @@ int run(int argc, char* argv[]) ...@@ -270,14 +270,15 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape // Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]) // y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
ck::index_t M = 512; ck::index_t M = 512;
ck::index_t N = 512; ck::index_t N = 512;
ck::index_t K = DIM; ck::index_t K = DIM;
ck::index_t O = DIM; ck::index_t O = DIM;
ck::index_t G0 = 4; ck::index_t G0 = 4;
ck::index_t G1 = 6; ck::index_t G1Q = 6; // h_q
ck::index_t G1KV = 6; // h_kv
bool input_permute = false; bool input_permute = false;
bool output_permute = false; bool output_permute = false;
...@@ -296,32 +297,33 @@ int run(int argc, char* argv[]) ...@@ -296,32 +297,33 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 13) else if(argc == 14)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
K = std::stoi(argv[6]); K = std::stoi(argv[6]);
O = std::stoi(argv[7]); O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]); G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]); G1Q = std::stoi(argv[9]);
G1KV = std::stoi(argv[10]);
p_drop = std::stof(argv[10]); p_drop = std::stof(argv[11]);
input_permute = std::stoi(argv[11]); input_permute = std::stoi(argv[12]);
output_permute = std::stoi(argv[12]); output_permute = std::stoi(argv[13]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n"); printf("arg4 to 10: M, N, K, O, G0, G1Q, G1KV\n");
printf("arg10: scale (alpha)\n"); printf("arg11: p_drop\n");
printf("arg11 to 12: input / output permute\n"); printf("arg12 to 13: input / output permute\n");
exit(0); exit(0);
} }
...@@ -338,7 +340,8 @@ int run(int argc, char* argv[]) ...@@ -338,7 +340,8 @@ int run(int argc, char* argv[])
std::cout << "K: " << K << std::endl; std::cout << "K: " << K << std::endl;
std::cout << "O: " << O << std::endl; std::cout << "O: " << O << std::endl;
std::cout << "G0: " << G0 << std::endl; std::cout << "G0: " << G0 << std::endl;
std::cout << "G1: " << G1 << std::endl; std::cout << "G1Q: " << G1Q << std::endl;
std::cout << "G1KV: " << G1KV << std::endl;
std::cout << "alpha: " << alpha << std::endl; std::cout << "alpha: " << alpha << std::endl;
std::cout << "input_permute: " << input_permute << std::endl; std::cout << "input_permute: " << input_permute << std::endl;
std::cout << "output_permute: " << output_permute << std::endl; std::cout << "output_permute: " << output_permute << std::endl;
...@@ -346,45 +349,57 @@ int run(int argc, char* argv[]) ...@@ -346,45 +349,57 @@ int run(int argc, char* argv[])
std::cout << "seed: " << seed << std::endl; std::cout << "seed: " << seed << std::endl;
std::cout << "offset: " << offset << std::endl; std::cout << "offset: " << offset << std::endl;
const ck::index_t BatchCount = G0 * G1; const ck::index_t BatchCount = G0 * G1Q;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides = std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K] ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // Q layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K] : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // Q layout [G0, G1Q, M, K]
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K}; std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides = std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K] ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // K layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K] : std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // K layout [G0, G1KV, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N}; std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides = std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O] ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // V layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O] : std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // V layout [G0, G1KV, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O}; std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides = std::vector<ck::index_t> y_gs_ms_os_strides =
output_permute output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O] ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // Y layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O] : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // Y layout [G0, G1Q, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N}; std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides = std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N] : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1Q, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1Q * K, K, G1Q * K, 1} // KGrad layout [G0, N, G1Q, K]
: std::vector<ck::index_t>{G1Q * N * K, N * K, K, 1}; // KGrad layout [G0, G1Q, N, K]
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1Q, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1Q * O, O, 1, G1Q * O} // VGrad layout [G0, N, G1Q, O]
: std::vector<ck::index_t>{G1Q * N * O, N * O, 1, O}; // VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...) // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...))) // = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...))) // = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^ // ^^^^^^^^^^^^^^^^^^^^^
// LSE // LSE
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M}; std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M] std::vector<ck::index_t> lse_gs_ms_strides{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
...@@ -394,6 +409,8 @@ int run(int argc, char* argv[]) ...@@ -394,6 +409,8 @@ int run(int argc, char* argv[])
Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
Tensor<DDataType> d_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<DDataType> d_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks(kgrad_gs_ns_ks_lengths, kgrad_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns(vgrad_gs_os_ns_lengths, vgrad_gs_os_ns_strides);
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl; std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl; std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
...@@ -402,6 +419,8 @@ int run(int argc, char* argv[]) ...@@ -402,6 +419,8 @@ int run(int argc, char* argv[])
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl; std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl; std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
std::cout << "d_gs_ms_os: " << d_gs_ms.mDesc << std::endl; std::cout << "d_gs_ms_os: " << d_gs_ms.mDesc << std::endl;
std::cout << "kgrad_gs_ns_ks: " << kgrad_gs_ns_ks.mDesc << std::endl;
std::cout << "vgrad_gs_os_ns: " << vgrad_gs_os_ns.mDesc << std::endl;
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0}); z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
switch(init_method) switch(init_method)
...@@ -435,14 +454,14 @@ int run(int argc, char* argv[]) ...@@ -435,14 +454,14 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1q, m, o]
// dO dot O = [0; 1; 2; ...] // dO dot O = [0; 1; 2; ...]
break; break;
case 6: case 6:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1q, m, o]
// assume mnko = 256 // assume mnko = 256
// P = softmax(QK) = 0.0039 * ones // P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones // O = P V = 0.0039 * ones
...@@ -455,7 +474,8 @@ int run(int argc, char* argv[]) ...@@ -455,7 +474,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1q, m, o]
// assume mnko = 256 // assume mnko = 256
// P = softmax(QK) = 0.0039 * ones // P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones // O = P V = 0.0039 * ones
...@@ -477,11 +497,21 @@ int run(int argc, char* argv[]) ...@@ -477,11 +497,21 @@ int run(int argc, char* argv[])
Tensor<LSEDataType> lse_g_m({BatchCount, M}); Tensor<LSEDataType> lse_g_m({BatchCount, M});
q_gs_ms_ks.ForEach( q_gs_ms_ks.ForEach(
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); [&](auto& self, auto idx) { q_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx); });
k_gs_ns_ks.ForEach( k_g_n_k.ForEach([&](auto& self, auto idx) {
[&](auto& self, auto idx) { k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); const size_t& g0 = idx[0] / G1Q;
v_gs_os_ns.ForEach( const size_t& g1q = idx[0] % G1Q;
[&](auto& self, auto idx) { v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); }); const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = k_gs_ns_ks(g0, g1kv, idx[1], idx[2]);
});
v_g_n_o.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = v_gs_os_ns(g0, g1kv, idx[2], idx[1]);
});
// qkv gradients have the same descriptor as with qkv // qkv gradients have the same descriptor as with qkv
DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem q_device_buf(sizeof(InputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
...@@ -491,8 +521,8 @@ int run(int argc, char* argv[]) ...@@ -491,8 +521,8 @@ int run(int argc, char* argv[])
DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize()); DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize()); DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
DeviceMem qgrad_device_buf(sizeof(OutputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem qgrad_device_buf(sizeof(OutputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem kgrad_device_buf(sizeof(OutputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize()); DeviceMem kgrad_device_buf(sizeof(OutputDataType) * kgrad_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem vgrad_device_buf(sizeof(OutputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize()); DeviceMem vgrad_device_buf(sizeof(OutputDataType) * vgrad_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize()); DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms.mDesc.GetElementSpaceSize()); DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms.mDesc.GetElementSpaceSize());
...@@ -533,6 +563,10 @@ int run(int argc, char* argv[]) ...@@ -533,6 +563,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
kgrad_gs_ns_ks_lengths,
kgrad_gs_ns_ks_strides,
vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...@@ -581,6 +615,10 @@ int run(int argc, char* argv[]) ...@@ -581,6 +615,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
kgrad_gs_ns_ks_lengths,
kgrad_gs_ns_ks_strides,
vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...@@ -620,7 +658,7 @@ int run(int argc, char* argv[]) ...@@ -620,7 +658,7 @@ int run(int argc, char* argv[])
// copy z matirx data form device // copy z matirx data form device
z_device_buf.FromDevice(z_gs_ms_ns.mData.data()); z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
z_gs_ms_ns.ForEach( z_gs_ms_ns.ForEach(
[&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); }); [&](auto& self, auto idx) { z_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx); });
// std::cout << "z_g_m_n ref:\n" << z_g_m_n; // std::cout << "z_g_m_n ref:\n" << z_g_m_n;
bool pass = true; bool pass = true;
...@@ -640,10 +678,10 @@ int run(int argc, char* argv[]) ...@@ -640,10 +678,10 @@ int run(int argc, char* argv[])
p_dropout_in_uint8_t, p_dropout_in_uint8_t,
rp_dropout); rp_dropout);
y_gs_ms_os.ForEach([&](auto& self, auto idx) { y_gs_ms_os.ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); self(idx) = y_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]);
}); });
lse_gs_ms.ForEach( lse_gs_ms.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); }); [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1Q + idx[1], idx[2]); });
y_device_buf.ToDevice(y_gs_ms_os.mData.data()); y_device_buf.ToDevice(y_gs_ms_os.mData.data());
lse_device_buf.ToDevice(lse_gs_ms.mData.data()); lse_device_buf.ToDevice(lse_gs_ms.mData.data());
...@@ -661,7 +699,7 @@ int run(int argc, char* argv[]) ...@@ -661,7 +699,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> ygrad_dot_y_g_m({BatchCount, M}); Tensor<InputDataType> ygrad_dot_y_g_m({BatchCount, M});
ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) { ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); ygrad_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
#if PRINT_HOST #if PRINT_HOST
...@@ -763,12 +801,16 @@ int run(int argc, char* argv[]) ...@@ -763,12 +801,16 @@ int run(int argc, char* argv[])
#endif #endif
Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(kgrad_gs_ns_ks_lengths,
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides); kgrad_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides);
Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(kgrad_gs_ns_ks_lengths,
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides); kgrad_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides);
qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data()); qgrad_device_buf.FromDevice(qgrad_gs_ms_ks_device_result.mData.data());
kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data()); kgrad_device_buf.FromDevice(kgrad_gs_ns_ks_device_result.mData.data());
...@@ -776,26 +818,26 @@ int run(int argc, char* argv[]) ...@@ -776,26 +818,26 @@ int run(int argc, char* argv[])
// permute // permute
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) { qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = qgrad_g_m_k(g, idx[2], idx[3]); self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
}); });
kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) { kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = kgrad_g_n_k(g, idx[2], idx[3]); self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
}); });
vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) { vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]); self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
}); });
......
...@@ -299,14 +299,15 @@ int run(int argc, char* argv[]) ...@@ -299,14 +299,15 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape // Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]) // y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
ck::index_t N = 500; // 512 ck::index_t N = 500; // 512
ck::index_t M = 500; // 512 ck::index_t M = 500; // 512
ck::index_t K = DIM; ck::index_t K = DIM;
ck::index_t O = DIM; ck::index_t O = DIM;
ck::index_t G0 = 4; // 54 ck::index_t G0 = 4;
ck::index_t G1 = 6; // 16 ck::index_t G1Q = 6; // h_q
ck::index_t G1KV = 6; // h_kv
bool input_permute = false; bool input_permute = false;
bool output_permute = false; bool output_permute = false;
...@@ -325,32 +326,33 @@ int run(int argc, char* argv[]) ...@@ -325,32 +326,33 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 13) else if(argc == 14)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
K = std::stoi(argv[6]); K = std::stoi(argv[6]);
O = std::stoi(argv[7]); O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]); G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]); G1Q = std::stoi(argv[9]);
G1KV = std::stoi(argv[10]);
p_drop = std::stof(argv[10]); p_drop = std::stof(argv[11]);
input_permute = std::stoi(argv[11]); input_permute = std::stoi(argv[12]);
output_permute = std::stoi(argv[12]); output_permute = std::stoi(argv[13]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n"); printf("arg4 to 10: M, N, K, O, G0, G1Q, G1KV\n");
printf("arg10: scale (alpha)\n"); printf("arg11: p_drop\n");
printf("arg11 to 12: input / output permute\n"); printf("arg12 to 13: input / output permute\n");
exit(0); exit(0);
} }
...@@ -367,7 +369,8 @@ int run(int argc, char* argv[]) ...@@ -367,7 +369,8 @@ int run(int argc, char* argv[])
std::cout << "K: " << K << std::endl; std::cout << "K: " << K << std::endl;
std::cout << "O: " << O << std::endl; std::cout << "O: " << O << std::endl;
std::cout << "G0: " << G0 << std::endl; std::cout << "G0: " << G0 << std::endl;
std::cout << "G1: " << G1 << std::endl; std::cout << "G1Q: " << G1Q << std::endl;
std::cout << "G1KV: " << G1KV << std::endl;
std::cout << "alpha: " << alpha << std::endl; std::cout << "alpha: " << alpha << std::endl;
std::cout << "input_permute: " << input_permute << std::endl; std::cout << "input_permute: " << input_permute << std::endl;
std::cout << "output_permute: " << output_permute << std::endl; std::cout << "output_permute: " << output_permute << std::endl;
...@@ -375,45 +378,57 @@ int run(int argc, char* argv[]) ...@@ -375,45 +378,57 @@ int run(int argc, char* argv[])
std::cout << "seed: " << seed << std::endl; std::cout << "seed: " << seed << std::endl;
std::cout << "offset: " << offset << std::endl; std::cout << "offset: " << offset << std::endl;
const ck::index_t BatchCount = G0 * G1; const ck::index_t BatchCount = G0 * G1Q;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides = std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K] ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // Q layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K] : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // Q layout [G0, G1Q, M, K]
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K}; std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides = std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K] ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // K layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K] : std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // K layout [G0, G1KV, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N}; std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides = std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O] ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // V layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O] : std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // V layout [G0, G1KV, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O}; std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides = std::vector<ck::index_t> y_gs_ms_os_strides =
output_permute output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O] ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // Y layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O] : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // Y layout [G0, G1Q, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N}; std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides = std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N] : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1Q, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1Q * K, K, G1Q * K, 1} // KGrad layout [G0, N, G1Q, K]
: std::vector<ck::index_t>{G1Q * N * K, N * K, K, 1}; // KGrad layout [G0, G1Q, N, K]
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1Q, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1Q * O, O, 1, G1Q * O} // VGrad layout [G0, N, G1Q, O]
: std::vector<ck::index_t>{G1Q * N * O, N * O, 1, O}; // VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...) // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...))) // = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...))) // = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^ // ^^^^^^^^^^^^^^^^^^^^^
// LSE // LSE
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M}; std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M] std::vector<ck::index_t> lse_gs_ms_strides{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
...@@ -424,8 +439,10 @@ int run(int argc, char* argv[]) ...@@ -424,8 +439,10 @@ int run(int argc, char* argv[])
Tensor<InputDataType> y_gs_ms_os_device_result(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> y_gs_ms_os_device_result(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(kgrad_gs_ns_ks_lengths,
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides); kgrad_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides);
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl; std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl; std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
...@@ -467,14 +484,14 @@ int run(int argc, char* argv[]) ...@@ -467,14 +484,14 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1q, m, o]
// dO dot O = [0; 1; 2; ...] // dO dot O = [0; 1; 2; ...]
break; break;
case 6: case 6:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1q, m, o]
// assume mnko = 256 // assume mnko = 256
// P = softmax(QK) = 0.0039 * ones // P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones // O = P V = 0.0039 * ones
...@@ -487,7 +504,8 @@ int run(int argc, char* argv[]) ...@@ -487,7 +504,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1q, m, o]
// assume mnko = 256 // assume mnko = 256
// P = softmax(QK) = 0.0039 * ones // P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones // O = P V = 0.0039 * ones
...@@ -612,6 +630,10 @@ int run(int argc, char* argv[]) ...@@ -612,6 +630,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
kgrad_gs_ns_ks_lengths,
kgrad_gs_ns_ks_strides,
vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...@@ -656,8 +678,10 @@ int run(int argc, char* argv[]) ...@@ -656,8 +678,10 @@ int run(int argc, char* argv[])
Tensor<InputDataType> y_gs_ms_os_host_result(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> y_gs_ms_os_host_result(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(kgrad_gs_ns_ks_lengths,
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides); kgrad_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides);
Tensor<InputDataType> q_g_m_k({BatchCount, M, K}); Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
Tensor<InputDataType> k_g_n_k({BatchCount, N, K}); Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
...@@ -760,6 +784,10 @@ int run(int argc, char* argv[]) ...@@ -760,6 +784,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
kgrad_gs_ns_ks_lengths,
kgrad_gs_ns_ks_strides,
vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
...@@ -793,16 +821,24 @@ int run(int argc, char* argv[]) ...@@ -793,16 +821,24 @@ int run(int argc, char* argv[])
} }
q_gs_ms_ks.ForEach([&](auto& self, auto idx) { q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); q_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
k_gs_ns_ks.ForEach([&](auto& self, auto idx) { k_g_n_k.ForEach([&](auto& self, auto idx) {
k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = k_gs_ns_ks(g0, g1kv, idx[1], idx[2]);
}); });
v_gs_os_ns.ForEach([&](auto& self, auto idx) { v_g_n_o.ForEach([&](auto& self, auto idx) {
v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = v_gs_os_ns(g0, g1kv, idx[2], idx[1]);
}); });
z_fwd_gs_ms_ns.ForEach([&](auto& self, auto idx) { z_fwd_gs_ms_ns.ForEach([&](auto& self, auto idx) {
z_fwd_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); z_fwd_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
run_attention_fwd_host(q_g_m_k, run_attention_fwd_host(q_g_m_k,
...@@ -819,10 +855,10 @@ int run(int argc, char* argv[]) ...@@ -819,10 +855,10 @@ int run(int argc, char* argv[])
rp_dropout); rp_dropout);
ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) { ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); ygrad_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
z_bwd_gs_ms_ns.ForEach([&](auto& self, auto idx) { z_bwd_gs_ms_ns.ForEach([&](auto& self, auto idx) {
z_bwd_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); z_bwd_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
#if PRINT_HOST #if PRINT_HOST
...@@ -925,42 +961,42 @@ int run(int argc, char* argv[]) ...@@ -925,42 +961,42 @@ int run(int argc, char* argv[])
// permute // permute
y_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) { y_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = y_g_m_o(g, idx[2], idx[3]); self(idx) = y_g_m_o(g, idx[2], idx[3]);
}); });
lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) { lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = lse_g_m(g, idx[2]); self(idx) = lse_g_m(g, idx[2]);
}); });
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) { qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = qgrad_g_m_k(g, idx[2], idx[3]); self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
}); });
kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) { kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = kgrad_g_n_k(g, idx[2], idx[3]); self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
}); });
vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) { vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]); self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
}); });
......
...@@ -268,10 +268,11 @@ int run(int argc, char* argv[]) ...@@ -268,10 +268,11 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape // Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]) // y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
float alpha = 1.f / std::sqrt(DIM); float alpha = 1.f / std::sqrt(DIM);
float p_drop = 0.0; float p_drop = 0.0;
int h_ratio = 1; // G1Q / G1KV
bool input_permute = true; bool input_permute = true;
bool output_permute = true; bool output_permute = true;
...@@ -289,25 +290,26 @@ int run(int argc, char* argv[]) ...@@ -289,25 +290,26 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 7) else if(argc == 8)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
p_drop = std::stof(argv[4]); p_drop = std::stof(argv[4]);
h_ratio = std::stof(argv[5]);
input_permute = std::stoi(argv[5]); input_permute = std::stoi(argv[6]);
output_permute = std::stoi(argv[6]); output_permute = std::stoi(argv[7]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n"); printf("arg4: p_drop\n");
printf("arg10: scale (alpha)\n"); printf("arg5: h_ratio\n");
printf("arg11 to 12: input / output permute\n"); printf("arg6 to 7: input / output permute\n");
exit(0); exit(0);
} }
...@@ -367,49 +369,65 @@ int run(int argc, char* argv[]) ...@@ -367,49 +369,65 @@ int run(int argc, char* argv[])
std::size_t flop = 0, num_byte = 0; std::size_t flop = 0, num_byte = 0;
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
int M = 128 * (rand() % 8) + (rand() % 128); int M = 128 * (rand() % 8) + (rand() % 128);
int N = 128 * (rand() % 8) + (rand() % 128); int N = 128 * (rand() % 8) + (rand() % 128);
int K = DIM; int K = DIM;
int O = DIM; int O = DIM;
int G0 = rand() % 4 + 1; int G0 = rand() % 4 + 1;
int G1 = rand() % 4 + 1; int G1KV = rand() % 4 + 1;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K}; int G1Q = G1KV * h_ratio;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides = std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K] ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // Q layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K] : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // Q layout [G0, G1Q, M, K]
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K}; std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides = std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K] ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1}
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K] // K layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // K layout [G0, G1KV, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N}; std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides = std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O] ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O}
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O] // V layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // V layout [G0, G1KV, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O}; std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides = std::vector<ck::index_t> y_gs_ms_os_strides =
output_permute output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O] ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // Y layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O] : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // Y layout [G0, G1Q, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N}; std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides = std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N] : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1Q, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
input_permute ? std::vector<ck::index_t>{N * G1Q * K, K, G1Q * K, 1}
// KGrad layout [G0, N, G1Q, K]
: std::vector<ck::index_t>{
G1Q * N * K, N * K, K, 1}; // KGrad layout [G0, G1Q, N, K]
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1Q, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_strides =
input_permute ? std::vector<ck::index_t>{N * G1Q * O, O, 1, G1Q * O}
// VGrad layout [G0, N, G1Q, O]
: std::vector<ck::index_t>{
G1Q * N * O, N * O, 1, O}; // VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
// pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...) // pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...))) // = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...))) // = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^ // ^^^^^^^^^^^^^^^^^^^^^
// LSE // LSE
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M}; std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M] std::vector<ck::index_t> lse_gs_ms_strides{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
problem_descs.push_back({ problem_descs.push_back({
q_gs_ms_ks_lengths, q_gs_ms_ks_lengths,
q_gs_ms_ks_strides, q_gs_ms_ks_strides,
...@@ -423,13 +441,17 @@ int run(int argc, char* argv[]) ...@@ -423,13 +441,17 @@ int run(int argc, char* argv[])
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
lse_gs_ms_strides, lse_gs_ms_strides,
kgrad_gs_ns_ks_lengths,
kgrad_gs_ns_ks_strides,
vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides}, {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
}); });
int BatchCount = G0 * G1; int BatchCount = G0 * G1Q;
flop += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount; flop += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
// Q/K/V/Y, dQ/dK/dV/dY, LSE // Q/K/V/Y, dQ/dK/dV/dY, LSE
num_byte += (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N + num_byte += (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
...@@ -446,6 +468,8 @@ int run(int argc, char* argv[]) ...@@ -446,6 +468,8 @@ int run(int argc, char* argv[])
Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks(kgrad_gs_ns_ks_lengths, kgrad_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns(vgrad_gs_os_ns_lengths, vgrad_gs_os_ns_strides);
if(i < 4) if(i < 4)
{ {
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl; std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
...@@ -454,6 +478,8 @@ int run(int argc, char* argv[]) ...@@ -454,6 +478,8 @@ int run(int argc, char* argv[])
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl; std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl; std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl; std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
std::cout << "kgrad_gs_ns_ks: " << kgrad_gs_ns_ks.mDesc << std::endl;
std::cout << "vgrad_gs_os_ns: " << vgrad_gs_os_ns.mDesc << std::endl;
} }
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0}); z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
switch(init_method) switch(init_method)
...@@ -487,14 +513,16 @@ int run(int argc, char* argv[]) ...@@ -487,14 +513,16 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_Sequential<2>{}); // dy[g0, g1q, m, o]
// dO dot O = [0; 1; 2; ...] // dO dot O = [0; 1; 2; ...]
break; break;
case 6: case 6:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_Sequential<3>{}); // dy[g0, g1q, m, o]
// assume mnko = 256 // assume mnko = 256
// P = softmax(QK) = 0.0039 * ones // P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones // O = P V = 0.0039 * ones
...@@ -508,7 +536,7 @@ int run(int argc, char* argv[]) ...@@ -508,7 +536,7 @@ int run(int argc, char* argv[])
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue( ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o] GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1q, m, o]
// assume mnko = 256 // assume mnko = 256
// P = softmax(QK) = 0.0039 * ones // P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones // O = P V = 0.0039 * ones
...@@ -529,13 +557,21 @@ int run(int argc, char* argv[]) ...@@ -529,13 +557,21 @@ int run(int argc, char* argv[])
Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N}); Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
q_gs_ms_ks.ForEach([&](auto& self, auto idx) { q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); q_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
k_gs_ns_ks.ForEach([&](auto& self, auto idx) { k_g_n_k.ForEach([&](auto& self, auto idx) {
k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / h_ratio;
self(idx) = k_gs_ns_ks(g0, g1kv, idx[1], idx[2]);
}); });
v_gs_os_ns.ForEach([&](auto& self, auto idx) { v_g_n_o.ForEach([&](auto& self, auto idx) {
v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / h_ratio;
self(idx) = v_gs_os_ns(g0, g1kv, idx[2], idx[1]);
}); });
q_g_m_ks.push_back(q_g_m_k); q_g_m_ks.push_back(q_g_m_k);
...@@ -554,6 +590,8 @@ int run(int argc, char* argv[]) ...@@ -554,6 +590,8 @@ int run(int argc, char* argv[])
z_tensors.push_back(z_gs_ms_ns); z_tensors.push_back(z_gs_ms_ns);
lse_tensors.push_back(lse_gs_ms); lse_tensors.push_back(lse_gs_ms);
ygrad_tensors.push_back(ygrad_gs_ms_os); ygrad_tensors.push_back(ygrad_gs_ms_os);
kgrad_tensors.push_back(kgrad_gs_ns_ks);
vgrad_tensors.push_back(vgrad_gs_os_ns);
q_tensors_device.emplace_back( q_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(InputDataType) * q_gs_ms_ks.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(InputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
k_tensors_device.emplace_back( k_tensors_device.emplace_back(
...@@ -568,10 +606,10 @@ int run(int argc, char* argv[]) ...@@ -568,10 +606,10 @@ int run(int argc, char* argv[])
std::make_unique<DeviceMem>(sizeof(LSEDataType) * lse_gs_ms.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(LSEDataType) * lse_gs_ms.GetElementSpaceSize()));
qgrad_tensors_device.emplace_back( qgrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(OutputDataType) * q_gs_ms_ks.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(OutputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
kgrad_tensors_device.emplace_back( kgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(
std::make_unique<DeviceMem>(sizeof(OutputDataType) * k_gs_ns_ks.GetElementSpaceSize())); sizeof(OutputDataType) * kgrad_gs_ns_ks.GetElementSpaceSize()));
vgrad_tensors_device.emplace_back( vgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(
std::make_unique<DeviceMem>(sizeof(OutputDataType) * v_gs_os_ns.GetElementSpaceSize())); sizeof(OutputDataType) * vgrad_gs_os_ns.GetElementSpaceSize()));
ygrad_tensors_device.emplace_back( ygrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(InputDataType) * y_gs_ms_os.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(InputDataType) * y_gs_ms_os.GetElementSpaceSize()));
q_tensors_device.back()->ToDevice(q_gs_ms_ks.data()); q_tensors_device.back()->ToDevice(q_gs_ms_ks.data());
...@@ -674,11 +712,11 @@ int run(int argc, char* argv[]) ...@@ -674,11 +712,11 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
int G1 = v_tensors[i].GetLengths()[1]; int G1Q = q_tensors[i].GetLengths()[1];
// copy z matirx data form device // copy z matirx data form device
z_tensors_device[i]->FromDevice(z_tensors[i].mData.data()); z_tensors_device[i]->FromDevice(z_tensors[i].mData.data());
z_tensors[i].ForEach([&](auto& self, auto idx) { z_tensors[i].ForEach([&](auto& self, auto idx) {
z_g_m_ns[i](idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); z_g_m_ns[i](idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
run_attention_fwd_host(q_g_m_ks[i], run_attention_fwd_host(q_g_m_ks[i],
k_g_n_ks[i], k_g_n_ks[i],
...@@ -694,11 +732,11 @@ int run(int argc, char* argv[]) ...@@ -694,11 +732,11 @@ int run(int argc, char* argv[])
rp_dropout); rp_dropout);
y_tensors[i].ForEach([&](auto& self, auto idx) { y_tensors[i].ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_os[i](idx[0] * G1 + idx[1], idx[2], idx[3]); self(idx) = y_g_m_os[i](idx[0] * G1Q + idx[1], idx[2], idx[3]);
}); });
y_tensors_device[i]->ToDevice(y_tensors[i].data()); y_tensors_device[i]->ToDevice(y_tensors[i].data());
lse_tensors[i].ForEach([&](auto& self, auto idx) { lse_tensors[i].ForEach([&](auto& self, auto idx) {
self(idx) = lse_g_ms[i](idx[0] * G1 + idx[1], idx[2]); self(idx) = lse_g_ms[i](idx[0] * G1Q + idx[1], idx[2]);
}); });
lse_tensors_device[i]->ToDevice(lse_tensors[i].data()); lse_tensors_device[i]->ToDevice(lse_tensors[i].data());
qgrad_tensors_device[i]->SetZero(); qgrad_tensors_device[i]->SetZero();
...@@ -711,13 +749,13 @@ int run(int argc, char* argv[]) ...@@ -711,13 +749,13 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
int G0 = v_tensors[i].GetLengths()[0]; int G0 = q_tensors[i].GetLengths()[0];
int G1 = v_tensors[i].GetLengths()[1]; int G1Q = q_tensors[i].GetLengths()[1];
int O = v_tensors[i].GetLengths()[2]; int O = v_tensors[i].GetLengths()[2];
int N = v_tensors[i].GetLengths()[3]; int N = v_tensors[i].GetLengths()[3];
int M = q_tensors[i].GetLengths()[2]; int M = q_tensors[i].GetLengths()[2];
int K = q_tensors[i].GetLengths()[3]; int K = q_tensors[i].GetLengths()[3];
int BatchCount = G0 * G1; int BatchCount = G0 * G1Q;
Tensor<OutputDataType> qgrad_g_m_k({BatchCount, M, K}); Tensor<OutputDataType> qgrad_g_m_k({BatchCount, M, K});
Tensor<OutputDataType> kgrad_g_n_k({BatchCount, N, K}); Tensor<OutputDataType> kgrad_g_n_k({BatchCount, N, K});
Tensor<OutputDataType> vgrad_g_n_o({BatchCount, N, O}); Tensor<OutputDataType> vgrad_g_n_o({BatchCount, N, O});
...@@ -727,7 +765,7 @@ int run(int argc, char* argv[]) ...@@ -727,7 +765,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> ygrad_g_m_o({BatchCount, M, O}); Tensor<InputDataType> ygrad_g_m_o({BatchCount, M, O});
ygrad_tensors[i].ForEach([&](auto& self, auto idx) { ygrad_tensors[i].ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); ygrad_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
auto ref_gemm0_grad = ReferenceGemm0GradInstance{}; auto ref_gemm0_grad = ReferenceGemm0GradInstance{};
auto ref_gemm0_grad_invoker = ref_gemm0_grad.MakeInvoker(); auto ref_gemm0_grad_invoker = ref_gemm0_grad.MakeInvoker();
...@@ -770,43 +808,43 @@ int run(int argc, char* argv[]) ...@@ -770,43 +808,43 @@ int run(int argc, char* argv[])
Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(), Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(),
q_tensors[i].GetStrides()); q_tensors[i].GetStrides());
Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(), Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(kgrad_tensors[i].GetLengths(),
k_tensors[i].GetStrides()); kgrad_tensors[i].GetStrides());
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(), Tensor<OutputDataType> vgrad_gs_os_ns_host_result(vgrad_tensors[i].GetLengths(),
v_tensors[i].GetStrides()); vgrad_tensors[i].GetStrides());
Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(), Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(),
q_tensors[i].GetStrides()); q_tensors[i].GetStrides());
Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(), Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(kgrad_tensors[i].GetLengths(),
k_tensors[i].GetStrides()); kgrad_tensors[i].GetStrides());
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(), Tensor<OutputDataType> vgrad_gs_os_ns_device_result(vgrad_tensors[i].GetLengths(),
v_tensors[i].GetStrides()); vgrad_tensors[i].GetStrides());
qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data()); qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data());
kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data()); kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data());
vgrad_tensors_device[i]->FromDevice(vgrad_gs_os_ns_device_result.data()); vgrad_tensors_device[i]->FromDevice(vgrad_gs_os_ns_device_result.data());
// permute // permute
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) { qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = qgrad_g_m_k(g, idx[2], idx[3]); self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
}); });
kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) { kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = kgrad_g_n_k(g, idx[2], idx[3]); self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
}); });
vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) { vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]); self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
}); });
......
...@@ -24,7 +24,7 @@ Kernel outputs: ...@@ -24,7 +24,7 @@ Kernel outputs:
*/ */
#define USING_MASK 0 #define USING_MASK 0
#define DIM 32 // DIM should be a multiple of 8. #define DIM 128 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -269,10 +269,11 @@ int run(int argc, char* argv[]) ...@@ -269,10 +269,11 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape // Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]) // y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
float alpha = 1.f / std::sqrt(DIM); float alpha = 1.f / std::sqrt(DIM);
float p_drop = 0.0; float p_drop = 0.0;
int h_ratio = 1; // G1Q / G1KV
bool input_permute = true; bool input_permute = true;
bool output_permute = true; bool output_permute = true;
...@@ -290,25 +291,26 @@ int run(int argc, char* argv[]) ...@@ -290,25 +291,26 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 7) else if(argc == 8)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
p_drop = std::stof(argv[4]); p_drop = std::stof(argv[4]);
h_ratio = std::stof(argv[5]);
input_permute = std::stoi(argv[5]); input_permute = std::stoi(argv[6]);
output_permute = std::stoi(argv[6]); output_permute = std::stoi(argv[7]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n"); printf("arg4: p_drop\n");
printf("arg10: scale (alpha)\n"); printf("arg5: h_ratio\n");
printf("arg11 to 12: input / output permute\n"); printf("arg6 to 7: input / output permute\n");
exit(0); exit(0);
} }
...@@ -371,49 +373,65 @@ int run(int argc, char* argv[]) ...@@ -371,49 +373,65 @@ int run(int argc, char* argv[])
std::size_t flop = 0, num_byte = 0; std::size_t flop = 0, num_byte = 0;
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
int M = 128 * (rand() % 8) + (rand() % 128); int M = 128 * (rand() % 8) + (rand() % 128);
int N = 128 * (rand() % 8) + (rand() % 128); int N = 128 * (rand() % 8) + (rand() % 128);
int K = DIM; int K = DIM;
int O = DIM; int O = DIM;
int G0 = rand() % 4 + 1; int G0 = rand() % 4 + 1;
int G1 = rand() % 4 + 1; int G1KV = rand() % 4 + 1;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K}; int G1Q = G1KV * h_ratio;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides = std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K] ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // Q layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K] : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // Q layout [G0, G1Q, M, K]
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K}; std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides = std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K] ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1}
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K] // K layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // K layout [G0, G1KV, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N}; std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides = std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O] ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O}
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O] // V layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // V layout [G0, G1KV, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O}; std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides = std::vector<ck::index_t> y_gs_ms_os_strides =
output_permute output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O] ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // Y layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O] : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // Y layout [G0, G1Q, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N}; std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides = std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N] : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1Q, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
input_permute ? std::vector<ck::index_t>{N * G1Q * K, K, G1Q * K, 1}
// KGrad layout [G0, N, G1Q, K]
: std::vector<ck::index_t>{
G1Q * N * K, N * K, K, 1}; // KGrad layout [G0, G1Q, N, K]
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1Q, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_strides =
input_permute ? std::vector<ck::index_t>{N * G1Q * O, O, 1, G1Q * O}
// VGrad layout [G0, N, G1Q, O]
: std::vector<ck::index_t>{
G1Q * N * O, N * O, 1, O}; // VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
// pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...) // pass Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...))) // = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...))) // = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^ // ^^^^^^^^^^^^^^^^^^^^^
// LSE // LSE
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M}; std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M] std::vector<ck::index_t> lse_gs_ms_strides{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
problem_descs.push_back({ problem_descs.push_back({
q_gs_ms_ks_lengths, q_gs_ms_ks_lengths,
q_gs_ms_ks_strides, q_gs_ms_ks_strides,
...@@ -427,13 +445,17 @@ int run(int argc, char* argv[]) ...@@ -427,13 +445,17 @@ int run(int argc, char* argv[])
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
lse_gs_ms_strides, lse_gs_ms_strides,
kgrad_gs_ns_ks_lengths,
kgrad_gs_ns_ks_strides,
vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides,
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides}, {}, // std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths}, {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides}, {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
}); });
int BatchCount = G0 * G1; int BatchCount = G0 * G1Q;
flop += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount; flop += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
// Q/K/V/Y, dQ/dK/dV/dY, LSE // Q/K/V/Y, dQ/dK/dV/dY, LSE
num_byte += (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N + num_byte += (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
...@@ -451,6 +473,8 @@ int run(int argc, char* argv[]) ...@@ -451,6 +473,8 @@ int run(int argc, char* argv[])
Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
Tensor<DDataType> d_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<DDataType> d_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks(kgrad_gs_ns_ks_lengths, kgrad_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns(vgrad_gs_os_ns_lengths, vgrad_gs_os_ns_strides);
if(i < 4) if(i < 4)
{ {
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl; std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
...@@ -460,6 +484,8 @@ int run(int argc, char* argv[]) ...@@ -460,6 +484,8 @@ int run(int argc, char* argv[])
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl; std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl; std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
std::cout << "d_gs_ms_os: " << d_gs_ms.mDesc << std::endl; std::cout << "d_gs_ms_os: " << d_gs_ms.mDesc << std::endl;
std::cout << "kgrad_gs_ns_ks: " << kgrad_gs_ns_ks.mDesc << std::endl;
std::cout << "vgrad_gs_os_ns: " << vgrad_gs_os_ns.mDesc << std::endl;
} }
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0}); z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
switch(init_method) switch(init_method)
...@@ -493,14 +519,16 @@ int run(int argc, char* argv[]) ...@@ -493,14 +519,16 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_Sequential<2>{}); // dy[g0, g1q, m, o]
// dO dot O = [0; 1; 2; ...] // dO dot O = [0; 1; 2; ...]
break; break;
case 6: case 6:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_Sequential<3>{}); // dy[g0, g1q, m, o]
// assume mnko = 256 // assume mnko = 256
// P = softmax(QK) = 0.0039 * ones // P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones // O = P V = 0.0039 * ones
...@@ -514,7 +542,7 @@ int run(int argc, char* argv[]) ...@@ -514,7 +542,7 @@ int run(int argc, char* argv[])
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue( ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o] GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1q, m, o]
// assume mnko = 256 // assume mnko = 256
// P = softmax(QK) = 0.0039 * ones // P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones // O = P V = 0.0039 * ones
...@@ -536,13 +564,21 @@ int run(int argc, char* argv[]) ...@@ -536,13 +564,21 @@ int run(int argc, char* argv[])
Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N}); Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
q_gs_ms_ks.ForEach([&](auto& self, auto idx) { q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); q_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
k_gs_ns_ks.ForEach([&](auto& self, auto idx) { k_g_n_k.ForEach([&](auto& self, auto idx) {
k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / h_ratio;
self(idx) = k_gs_ns_ks(g0, g1kv, idx[1], idx[2]);
}); });
v_gs_os_ns.ForEach([&](auto& self, auto idx) { v_g_n_o.ForEach([&](auto& self, auto idx) {
v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / h_ratio;
self(idx) = v_gs_os_ns(g0, g1kv, idx[2], idx[1]);
}); });
q_g_m_ks.push_back(q_g_m_k); q_g_m_ks.push_back(q_g_m_k);
...@@ -562,6 +598,8 @@ int run(int argc, char* argv[]) ...@@ -562,6 +598,8 @@ int run(int argc, char* argv[])
z_tensors.push_back(z_gs_ms_ns); z_tensors.push_back(z_gs_ms_ns);
lse_tensors.push_back(lse_gs_ms); lse_tensors.push_back(lse_gs_ms);
ygrad_tensors.push_back(ygrad_gs_ms_os); ygrad_tensors.push_back(ygrad_gs_ms_os);
kgrad_tensors.push_back(kgrad_gs_ns_ks);
vgrad_tensors.push_back(vgrad_gs_os_ns);
q_tensors_device.emplace_back( q_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(InputDataType) * q_gs_ms_ks.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(InputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
k_tensors_device.emplace_back( k_tensors_device.emplace_back(
...@@ -578,10 +616,10 @@ int run(int argc, char* argv[]) ...@@ -578,10 +616,10 @@ int run(int argc, char* argv[])
std::make_unique<DeviceMem>(sizeof(DDataType) * d_gs_ms.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(DDataType) * d_gs_ms.GetElementSpaceSize()));
qgrad_tensors_device.emplace_back( qgrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(OutputDataType) * q_gs_ms_ks.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(OutputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
kgrad_tensors_device.emplace_back( kgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(
std::make_unique<DeviceMem>(sizeof(OutputDataType) * k_gs_ns_ks.GetElementSpaceSize())); sizeof(OutputDataType) * kgrad_gs_ns_ks.GetElementSpaceSize()));
vgrad_tensors_device.emplace_back( vgrad_tensors_device.emplace_back(std::make_unique<DeviceMem>(
std::make_unique<DeviceMem>(sizeof(OutputDataType) * v_gs_os_ns.GetElementSpaceSize())); sizeof(OutputDataType) * vgrad_gs_os_ns.GetElementSpaceSize()));
ygrad_tensors_device.emplace_back( ygrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(InputDataType) * y_gs_ms_os.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(InputDataType) * y_gs_ms_os.GetElementSpaceSize()));
q_tensors_device.back()->ToDevice(q_gs_ms_ks.data()); q_tensors_device.back()->ToDevice(q_gs_ms_ks.data());
...@@ -687,11 +725,11 @@ int run(int argc, char* argv[]) ...@@ -687,11 +725,11 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
int G1 = v_tensors[i].GetLengths()[1]; int G1Q = q_tensors[i].GetLengths()[1];
// copy z matirx data form device // copy z matirx data form device
z_tensors_device[i]->FromDevice(z_tensors[i].mData.data()); z_tensors_device[i]->FromDevice(z_tensors[i].mData.data());
z_tensors[i].ForEach([&](auto& self, auto idx) { z_tensors[i].ForEach([&](auto& self, auto idx) {
z_g_m_ns[i](idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); z_g_m_ns[i](idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
run_attention_fwd_host(q_g_m_ks[i], run_attention_fwd_host(q_g_m_ks[i],
k_g_n_ks[i], k_g_n_ks[i],
...@@ -707,11 +745,11 @@ int run(int argc, char* argv[]) ...@@ -707,11 +745,11 @@ int run(int argc, char* argv[])
rp_dropout); rp_dropout);
y_tensors[i].ForEach([&](auto& self, auto idx) { y_tensors[i].ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_os[i](idx[0] * G1 + idx[1], idx[2], idx[3]); self(idx) = y_g_m_os[i](idx[0] * G1Q + idx[1], idx[2], idx[3]);
}); });
y_tensors_device[i]->ToDevice(y_tensors[i].data()); y_tensors_device[i]->ToDevice(y_tensors[i].data());
lse_tensors[i].ForEach([&](auto& self, auto idx) { lse_tensors[i].ForEach([&](auto& self, auto idx) {
self(idx) = lse_g_ms[i](idx[0] * G1 + idx[1], idx[2]); self(idx) = lse_g_ms[i](idx[0] * G1Q + idx[1], idx[2]);
}); });
lse_tensors_device[i]->ToDevice(lse_tensors[i].data()); lse_tensors_device[i]->ToDevice(lse_tensors[i].data());
qgrad_tensors_device[i]->SetZero(); qgrad_tensors_device[i]->SetZero();
...@@ -724,13 +762,13 @@ int run(int argc, char* argv[]) ...@@ -724,13 +762,13 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
int G0 = v_tensors[i].GetLengths()[0]; int G0 = q_tensors[i].GetLengths()[0];
int G1 = v_tensors[i].GetLengths()[1]; int G1Q = q_tensors[i].GetLengths()[1];
int O = v_tensors[i].GetLengths()[2]; int O = v_tensors[i].GetLengths()[2];
int N = v_tensors[i].GetLengths()[3]; int N = v_tensors[i].GetLengths()[3];
int M = q_tensors[i].GetLengths()[2]; int M = q_tensors[i].GetLengths()[2];
int K = q_tensors[i].GetLengths()[3]; int K = q_tensors[i].GetLengths()[3];
int BatchCount = G0 * G1; int BatchCount = G0 * G1Q;
Tensor<OutputDataType> qgrad_g_m_k({BatchCount, M, K}); Tensor<OutputDataType> qgrad_g_m_k({BatchCount, M, K});
Tensor<OutputDataType> kgrad_g_n_k({BatchCount, N, K}); Tensor<OutputDataType> kgrad_g_n_k({BatchCount, N, K});
Tensor<OutputDataType> vgrad_g_n_o({BatchCount, N, O}); Tensor<OutputDataType> vgrad_g_n_o({BatchCount, N, O});
...@@ -740,7 +778,7 @@ int run(int argc, char* argv[]) ...@@ -740,7 +778,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> ygrad_g_m_o({BatchCount, M, O}); Tensor<InputDataType> ygrad_g_m_o({BatchCount, M, O});
ygrad_tensors[i].ForEach([&](auto& self, auto idx) { ygrad_tensors[i].ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); ygrad_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
auto ref_gemm0_grad = ReferenceGemm0GradInstance{}; auto ref_gemm0_grad = ReferenceGemm0GradInstance{};
auto ref_gemm0_grad_invoker = ref_gemm0_grad.MakeInvoker(); auto ref_gemm0_grad_invoker = ref_gemm0_grad.MakeInvoker();
...@@ -783,43 +821,43 @@ int run(int argc, char* argv[]) ...@@ -783,43 +821,43 @@ int run(int argc, char* argv[])
Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(), Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_tensors[i].GetLengths(),
q_tensors[i].GetStrides()); q_tensors[i].GetStrides());
Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(), Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(kgrad_tensors[i].GetLengths(),
k_tensors[i].GetStrides()); kgrad_tensors[i].GetStrides());
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(), Tensor<OutputDataType> vgrad_gs_os_ns_host_result(vgrad_tensors[i].GetLengths(),
v_tensors[i].GetStrides()); vgrad_tensors[i].GetStrides());
Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(), Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_tensors[i].GetLengths(),
q_tensors[i].GetStrides()); q_tensors[i].GetStrides());
Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(), Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(kgrad_tensors[i].GetLengths(),
k_tensors[i].GetStrides()); kgrad_tensors[i].GetStrides());
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(), Tensor<OutputDataType> vgrad_gs_os_ns_device_result(vgrad_tensors[i].GetLengths(),
v_tensors[i].GetStrides()); vgrad_tensors[i].GetStrides());
qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data()); qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data());
kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data()); kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data());
vgrad_tensors_device[i]->FromDevice(vgrad_gs_os_ns_device_result.data()); vgrad_tensors_device[i]->FromDevice(vgrad_gs_os_ns_device_result.data());
// permute // permute
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) { qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = qgrad_g_m_k(g, idx[2], idx[3]); self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
}); });
kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) { kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = kgrad_g_n_k(g, idx[2], idx[3]); self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
}); });
vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) { vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]); self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
}); });
......
...@@ -14,11 +14,12 @@ int run(int argc, char* argv[]) ...@@ -14,11 +14,12 @@ int run(int argc, char* argv[])
ck::index_t K = DIM; ck::index_t K = DIM;
ck::index_t O = DIM; ck::index_t O = DIM;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape // Output shape C[G0, M, G1Q, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) // C_g0_g1q_m_o = reshape(C_g_m_o, [g0, g1q, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3]) // C_g0_m_g1q_o = permute(C_g0_g1q_m_o, [0, 2, 1, 3])
ck::index_t G0 = 7; ck::index_t G0 = 7;
ck::index_t G1 = 13; ck::index_t G1Q = 12; // h_q
ck::index_t G1KV = 12; // h_kv
bool input_permute = false; bool input_permute = false;
bool output_permute = true; bool output_permute = true;
...@@ -37,32 +38,33 @@ int run(int argc, char* argv[]) ...@@ -37,32 +38,33 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 13) else if(argc == 14)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
K = std::stoi(argv[6]); K = std::stoi(argv[6]);
O = std::stoi(argv[7]); O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]); G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]); G1Q = std::stoi(argv[9]);
G1KV = std::stoi(argv[10]);
p_drop = std::stof(argv[10]); p_drop = std::stof(argv[11]);
input_permute = std::stoi(argv[11]); input_permute = std::stoi(argv[12]);
output_permute = std::stoi(argv[12]); output_permute = std::stoi(argv[13]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n"); printf("arg4 to 10: M, N, K, O, G0, G1Q, G1KV\n");
printf("arg10: scale (alpha)\n"); printf("arg11: p_drop\n");
printf("arg11 to 12: input / output permute\n"); printf("arg12 to 13: input / output permute\n");
exit(0); exit(0);
} }
...@@ -71,39 +73,39 @@ int run(int argc, char* argv[]) ...@@ -71,39 +73,39 @@ int run(int argc, char* argv[])
float rp_dropout = 1.0 / p_dropout; float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K); float alpha = 1.f / std::sqrt(K);
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides = std::vector<ck::index_t> a_gs_ms_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // A layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // A layout [G0, G1Q, M, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K}; std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides = std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K] ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // B0 layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] : std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // B0 layout [G0, G1KV, N, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N}; std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides = std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O] ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // B1 layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] : std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // B1 layout [G0, G1KV, N, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O}; std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides = std::vector<ck::index_t> c_gs_ms_os_strides =
output_permute output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // C layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // C layout [G0, G1Q, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N}; std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides = std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N] : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M}; std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides = std::vector<ck::index_t> lse_gs_ms_strides =
std::vector<ck::index_t>{G1 * M, M, 1}; // LSE layout [G0, G1, M] std::vector<ck::index_t>{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
...@@ -211,7 +213,7 @@ int run(int argc, char* argv[]) ...@@ -211,7 +213,7 @@ int run(int argc, char* argv[])
return 0; return 0;
} }
ck::index_t BatchCount = G0 * G1; ck::index_t BatchCount = G0 * G1Q;
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
...@@ -276,24 +278,32 @@ int run(int argc, char* argv[]) ...@@ -276,24 +278,32 @@ int run(int argc, char* argv[])
Tensor<B1DataType> b1_g_n_o({BatchCount, N, O}); Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
Tensor<AccDataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0 Tensor<AccDataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax
Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N}); Tensor<ADataType> a1_g_m_n_drop({BatchCount, M, N});
Tensor<LSEDataType> lse_g_m_host_result( Tensor<LSEDataType> lse_g_m_host_result(
{BatchCount, M}); // scratch object after max + ln(sum) {BatchCount, M}); // scratch object after max + ln(sum)
Tensor<ZDataType> z_g_m_n({G0 * G1, M, N}); Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1 Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
// permute // permute
a_gs_ms_ks.ForEach([&](auto& self, auto idx) { a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); a_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { b0_g_k_n.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = b0_gs_ns_ks(g0, g1kv, idx[2], idx[1]);
}); });
b1_gs_os_ns.ForEach([&](auto& self, auto idx) { b1_g_n_o.ForEach([&](auto& self, auto idx) {
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = b1_gs_os_ns(g0, g1kv, idx[2], idx[1]);
}); });
z_gs_ms_ns.ForEach([&](auto& self, auto idx) { z_gs_ms_ns.ForEach([&](auto& self, auto idx) {
z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); z_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
// gemm 0 // gemm 0
...@@ -340,18 +350,18 @@ int run(int argc, char* argv[]) ...@@ -340,18 +350,18 @@ int run(int argc, char* argv[])
// permute // permute
c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) { c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]); self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
}); });
lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) { lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = lse_g_m_host_result(g, idx[2]); self(idx) = lse_g_m_host_result(g, idx[2]);
}); });
......
...@@ -11,6 +11,7 @@ int run(int argc, char* argv[]) ...@@ -11,6 +11,7 @@ int run(int argc, char* argv[])
bool output_permute = true; bool output_permute = true;
float p_drop = 0.2; float p_drop = 0.2;
int h_ratio = 1; // G1Q / G1KV
const unsigned long long seed = 1; const unsigned long long seed = 1;
const unsigned long long offset = 0; const unsigned long long offset = 0;
...@@ -24,22 +25,25 @@ int run(int argc, char* argv[]) ...@@ -24,22 +25,25 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 7) else if(argc == 8)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
p_drop = std::stoi(argv[4]); p_drop = std::stoi(argv[4]);
input_permute = std::stoi(argv[5]); h_ratio = std::stof(argv[5]);
output_permute = std::stoi(argv[6]); input_permute = std::stoi(argv[6]);
output_permute = std::stoi(argv[7]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 5: input / output permute\n"); printf("arg4: p_drop\n");
printf("arg5: h_ratio\n");
printf("arg6 to 7: input / output permute\n");
exit(0); exit(0);
} }
...@@ -60,7 +64,7 @@ int run(int argc, char* argv[]) ...@@ -60,7 +64,7 @@ int run(int argc, char* argv[])
std::vector<void*> p_z; // for result verification std::vector<void*> p_z; // for result verification
std::vector<void*> p_z_nullptr; // for time test std::vector<void*> p_z_nullptr; // for time test
std::vector<void*> p_lse; std::vector<void*> p_lse;
std::vector<std::vector<int>> g0_g1_m_n_k_o; std::vector<std::vector<int>> g0_g1q_m_n_k_o;
std::vector<Tensor<ADataType>> a_tensors; std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<B0DataType>> b0_tensors; std::vector<Tensor<B0DataType>> b0_tensors;
...@@ -83,48 +87,51 @@ int run(int argc, char* argv[]) ...@@ -83,48 +87,51 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
int M = 128 * (rand() % 8) + (rand() % 128); int M = 128 * (rand() % 8) + (rand() % 128);
int N = 128 * (rand() % 8) + (rand() % 128); int N = 128 * (rand() % 8) + (rand() % 128);
int K = DIM; int K = DIM;
int O = DIM; int O = DIM;
int G0 = rand() % 3 + 1; int G0 = rand() % 3 + 1;
int G1 = rand() % 5 + 1; int G1KV = rand() % 5 + 1;
int G1Q = G1KV * h_ratio;
g0_g1_m_n_k_o.push_back({G0, G1, M, N, K, O}); g0_g1q_m_n_k_o.push_back({G0, G1Q, M, N, K, O});
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides = std::vector<ck::index_t> a_gs_ms_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // A layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // A layout [G0, G1Q, M, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K}; std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides = std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K] ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1}
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] // B0 layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // B0 layout [G0, G1KV, N, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N}; std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides = std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O] ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O}
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] // B1 layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // B1 layout [G0, G1KV, N, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O}; std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides = std::vector<ck::index_t> c_gs_ms_os_strides =
output_permute output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // C layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // C layout [G0, G1Q, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N}; std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides = std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N] : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M}; std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides = std::vector<ck::index_t> lse_gs_ms_strides =
std::vector<ck::index_t>{G1 * M, M, 1}; // LSE layout [G0, G1, M] std::vector<ck::index_t>{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
problem_descs.push_back({a_gs_ms_ks_lengths, problem_descs.push_back({a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
...@@ -151,7 +158,7 @@ int run(int argc, char* argv[]) ...@@ -151,7 +158,7 @@ int run(int argc, char* argv[])
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
int Batch = G0 * G1; int Batch = G0 * G1Q;
flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch; flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
...@@ -308,12 +315,12 @@ int run(int argc, char* argv[]) ...@@ -308,12 +315,12 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
const int& G0 = g0_g1_m_n_k_o[i][0]; const int& G0 = g0_g1q_m_n_k_o[i][0];
const int& G1 = g0_g1_m_n_k_o[i][1]; const int& G1Q = g0_g1q_m_n_k_o[i][1];
const int& M = g0_g1_m_n_k_o[i][2]; const int& M = g0_g1q_m_n_k_o[i][2];
const int& N = g0_g1_m_n_k_o[i][3]; const int& N = g0_g1q_m_n_k_o[i][3];
const int& K = g0_g1_m_n_k_o[i][4]; const int& K = g0_g1q_m_n_k_o[i][4];
const int& O = g0_g1_m_n_k_o[i][5]; const int& O = g0_g1q_m_n_k_o[i][5];
const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths; const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths;
const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides; const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides;
...@@ -334,31 +341,39 @@ int run(int argc, char* argv[]) ...@@ -334,31 +341,39 @@ int run(int argc, char* argv[])
z_gs_ms_ns_device_buf.FromDevice(z_gs_ms_ns_device_result.mData.data()); z_gs_ms_ns_device_buf.FromDevice(z_gs_ms_ns_device_result.mData.data());
lse_gs_ms_device_buf.FromDevice(lse_gs_ms_device_result.mData.data()); lse_gs_ms_device_buf.FromDevice(lse_gs_ms_device_result.mData.data());
Tensor<ADataType> a_g_m_k({G0 * G1, M, K}); Tensor<ADataType> a_g_m_k({G0 * G1Q, M, K});
Tensor<B0DataType> b0_g_k_n({G0 * G1, K, N}); Tensor<B0DataType> b0_g_k_n({G0 * G1Q, K, N});
Tensor<B1DataType> b1_g_n_o({G0 * G1, N, O}); Tensor<B1DataType> b1_g_n_o({G0 * G1Q, N, O});
Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0 Tensor<AccDataType> acc0_g_m_n({G0 * G1Q, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({G0 * G1, M, N}); // scratch object after softmax Tensor<ADataType> a1_g_m_n({G0 * G1Q, M, N}); // scratch object after softmax
Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N}); // scratch object after softmax Tensor<ADataType> a1_g_m_n_drop({G0 * G1Q, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1 Tensor<CDataType> c_g_m_o_host_result({G0 * G1Q, M, O}); // scratch object after gemm1
Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<ZDataType> z_g_m_n({G0 * G1, M, N}); Tensor<ZDataType> z_g_m_n({G0 * G1Q, M, N});
Tensor<LSEDataType> lse_g_m_host_result({G0 * G1, M}); // scratch object after gemm1 Tensor<LSEDataType> lse_g_m_host_result({G0 * G1Q, M}); // scratch object after gemm1
Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
// permute // permute
a_gs_ms_ks.ForEach([&](auto& self, auto idx) { a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); a_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { b0_g_k_n.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / h_ratio;
self(idx) = b0_gs_ns_ks(g0, g1kv, idx[2], idx[1]);
}); });
b1_gs_os_ns.ForEach([&](auto& self, auto idx) { b1_g_n_o.ForEach([&](auto& self, auto idx) {
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / h_ratio;
self(idx) = b1_gs_os_ns(g0, g1kv, idx[2], idx[1]);
}); });
z_gs_ms_ns_device_result.ForEach([&](auto& self, auto idx) { z_gs_ms_ns_device_result.ForEach([&](auto& self, auto idx) {
z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); z_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
// gemm 0 // gemm 0
...@@ -408,18 +423,18 @@ int run(int argc, char* argv[]) ...@@ -408,18 +423,18 @@ int run(int argc, char* argv[])
// permute // permute
c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) { c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]); self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
}); });
lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) { lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = lse_g_m_host_result(g, idx[2]); self(idx) = lse_g_m_host_result(g, idx[2]);
}); });
......
...@@ -273,14 +273,15 @@ int run(int argc, char* argv[]) ...@@ -273,14 +273,15 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape // Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O]) // y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3]) // y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
ck::index_t M = 512; ck::index_t M = 512;
ck::index_t N = 512; ck::index_t N = 512;
ck::index_t K = DIM; ck::index_t K = DIM;
ck::index_t O = DIM; ck::index_t O = DIM;
ck::index_t G0 = 4; ck::index_t G0 = 4;
ck::index_t G1 = 6; ck::index_t G1Q = 6; // h_q
ck::index_t G1KV = 6; // h_kv
bool input_permute = false; bool input_permute = false;
bool output_permute = false; bool output_permute = false;
...@@ -299,32 +300,33 @@ int run(int argc, char* argv[]) ...@@ -299,32 +300,33 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 13) else if(argc == 14)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
K = std::stoi(argv[6]); K = std::stoi(argv[6]);
O = std::stoi(argv[7]); O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]); G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]); G1Q = std::stoi(argv[9]);
G1KV = std::stoi(argv[10]);
p_drop = std::stof(argv[10]); p_drop = std::stof(argv[11]);
input_permute = std::stoi(argv[11]); input_permute = std::stoi(argv[12]);
output_permute = std::stoi(argv[12]); output_permute = std::stoi(argv[13]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n"); printf("arg4 to 10: M, N, K, O, G0, G1Q, G1KV\n");
printf("arg10: scale (alpha)\n"); printf("arg11: p_drop\n");
printf("arg11 to 12: input / output permute\n"); printf("arg12 to 13: input / output permute\n");
exit(0); exit(0);
} }
...@@ -341,7 +343,8 @@ int run(int argc, char* argv[]) ...@@ -341,7 +343,8 @@ int run(int argc, char* argv[])
std::cout << "K: " << K << std::endl; std::cout << "K: " << K << std::endl;
std::cout << "O: " << O << std::endl; std::cout << "O: " << O << std::endl;
std::cout << "G0: " << G0 << std::endl; std::cout << "G0: " << G0 << std::endl;
std::cout << "G1: " << G1 << std::endl; std::cout << "G1Q: " << G1Q << std::endl;
std::cout << "G1KV: " << G1KV << std::endl;
std::cout << "alpha: " << alpha << std::endl; std::cout << "alpha: " << alpha << std::endl;
std::cout << "input_permute: " << input_permute << std::endl; std::cout << "input_permute: " << input_permute << std::endl;
std::cout << "output_permute: " << output_permute << std::endl; std::cout << "output_permute: " << output_permute << std::endl;
...@@ -349,51 +352,63 @@ int run(int argc, char* argv[]) ...@@ -349,51 +352,63 @@ int run(int argc, char* argv[])
std::cout << "seed: " << seed << std::endl; std::cout << "seed: " << seed << std::endl;
std::cout << "offset: " << offset << std::endl; std::cout << "offset: " << offset << std::endl;
const ck::index_t BatchCount = G0 * G1; const ck::index_t BatchCount = G0 * G1Q;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides = std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K] ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // Q layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K] : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // Q layout [G0, G1Q, M, K]
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1, N, K}; std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides = std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // K layout [G0, N, G1, K] ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // K layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // K layout [G0, G1, N, K] : std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // K layout [G0, G1KV, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1, O, N}; std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides = std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // V layout [G0, N, G1, O] ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // V layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // V layout [G0, G1, N, O] : std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // V layout [G0, G1KV, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O}; std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides = std::vector<ck::index_t> y_gs_ms_os_strides =
output_permute output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O] ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // Y layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O] : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // Y layout [G0, G1Q, M, O]
std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N}; std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> d0_gs_ms_ns_strides = std::vector<ck::index_t> d0_gs_ms_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // D layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // D layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // D layout [G0, G1, M, N] : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // D layout [G0, G1Q, M, N]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N}; std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides = std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N] : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1Q, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1Q * K, K, G1Q * K, 1} // KGrad layout [G0, N, G1Q, K]
: std::vector<ck::index_t>{G1Q * N * K, N * K, K, 1}; // KGrad layout [G0, G1Q, N, K]
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1Q, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1Q * O, O, 1, G1Q * O} // VGrad layout [G0, N, G1Q, O]
: std::vector<ck::index_t>{G1Q * N * O, N * O, 1, O}; // VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass // The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...) // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
// = exp(Si) / exp(log(sum(exp() + ...))) // = exp(Si) / exp(log(sum(exp() + ...)))
// = exp(Si - log(sum(exp() + ...))) // = exp(Si - log(sum(exp() + ...)))
// ^^^^^^^^^^^^^^^^^^^^^ // ^^^^^^^^^^^^^^^^^^^^^
// LSE // LSE
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M}; std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M] std::vector<ck::index_t> lse_gs_ms_strides{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
...@@ -403,6 +418,8 @@ int run(int argc, char* argv[]) ...@@ -403,6 +418,8 @@ int run(int argc, char* argv[])
Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> y_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides); Tensor<InputDataType> ygrad_gs_ms_os(y_gs_ms_os_lengths, y_gs_ms_os_strides);
Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms(lse_gs_ms_lengths, lse_gs_ms_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks(kgrad_gs_ns_ks_lengths, kgrad_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns(vgrad_gs_os_ns_lengths, vgrad_gs_os_ns_strides);
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl; std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl; std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
...@@ -411,6 +428,8 @@ int run(int argc, char* argv[]) ...@@ -411,6 +428,8 @@ int run(int argc, char* argv[])
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl; std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl; std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl; std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
std::cout << "kgrad_gs_ns_ks: " << kgrad_gs_ns_ks.mDesc << std::endl;
std::cout << "vgrad_gs_os_ns: " << vgrad_gs_os_ns.mDesc << std::endl;
z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0}); z_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<InputDataType>{0});
switch(init_method) switch(init_method)
...@@ -448,7 +467,7 @@ int run(int argc, char* argv[]) ...@@ -448,7 +467,7 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1q, m, o]
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1}); d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
// dO dot O = [0; 1; 2; ...] // dO dot O = [0; 1; 2; ...]
break; break;
...@@ -456,7 +475,7 @@ int run(int argc, char* argv[]) ...@@ -456,7 +475,7 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1q, m, o]
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1}); d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
// assume mnko = 256 // assume mnko = 256
// P = softmax(QK) = 0.0039 * ones // P = softmax(QK) = 0.0039 * ones
...@@ -470,7 +489,8 @@ int run(int argc, char* argv[]) ...@@ -470,7 +489,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{}); v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0,g1, m, o] ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1q, m, o]
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1}); d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
// assume mnko = 256 // assume mnko = 256
// P = softmax(QK) = 0.0039 * ones // P = softmax(QK) = 0.0039 * ones
...@@ -491,8 +511,8 @@ int run(int argc, char* argv[]) ...@@ -491,8 +511,8 @@ int run(int argc, char* argv[])
DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize()); DeviceMem y_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize()); DeviceMem lse_device_buf(sizeof(LSEDataType) * lse_gs_ms.mDesc.GetElementSpaceSize());
DeviceMem qgrad_device_buf(sizeof(OutputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize()); DeviceMem qgrad_device_buf(sizeof(OutputDataType) * q_gs_ms_ks.mDesc.GetElementSpaceSize());
DeviceMem kgrad_device_buf(sizeof(OutputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize()); DeviceMem kgrad_device_buf(sizeof(OutputDataType) * kgrad_gs_ns_ks.mDesc.GetElementSpaceSize());
DeviceMem vgrad_device_buf(sizeof(OutputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize()); DeviceMem vgrad_device_buf(sizeof(OutputDataType) * vgrad_gs_os_ns.mDesc.GetElementSpaceSize());
DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize()); DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
DeviceMem d0grad_device_buf(sizeof(Acc0BiasDataType) * d0_gs_ms_ns.mDesc.GetElementSpaceSize()); DeviceMem d0grad_device_buf(sizeof(Acc0BiasDataType) * d0_gs_ms_ns.mDesc.GetElementSpaceSize());
...@@ -533,6 +553,10 @@ int run(int argc, char* argv[]) ...@@ -533,6 +553,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
kgrad_gs_ns_ks_lengths,
kgrad_gs_ns_ks_strides,
vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides,
d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
{}, // acc1_bias_gs_ms_os_lengths, {}, // acc1_bias_gs_ms_os_lengths,
...@@ -580,6 +604,10 @@ int run(int argc, char* argv[]) ...@@ -580,6 +604,10 @@ int run(int argc, char* argv[])
y_gs_ms_os_lengths, y_gs_ms_os_lengths,
y_gs_ms_os_strides, y_gs_ms_os_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
kgrad_gs_ns_ks_lengths,
kgrad_gs_ns_ks_strides,
vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides,
d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths d0_gs_ms_ns_lengths, // acc0_bias_gs_ms_ns_lengths
d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides d0_gs_ms_ns_strides, // acc0_bias_gs_ms_ns_strides
{}, // acc1_bias_gs_ms_os_lengths, {}, // acc1_bias_gs_ms_os_lengths,
...@@ -624,7 +652,7 @@ int run(int argc, char* argv[]) ...@@ -624,7 +652,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> q_g_m_k({BatchCount, M, K}); Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
Tensor<InputDataType> k_g_n_k({BatchCount, N, K}); Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
Tensor<Acc0BiasDataType> d0_g_m_n({G0 * G1, M, N}); Tensor<Acc0BiasDataType> d0_g_m_n({BatchCount, M, N});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N}); Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
Tensor<InputDataType> v_g_n_o({BatchCount, N, O}); Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N}); Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
...@@ -635,19 +663,27 @@ int run(int argc, char* argv[]) ...@@ -635,19 +663,27 @@ int run(int argc, char* argv[])
z_device_buf.FromDevice(z_gs_ms_ns.mData.data()); z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
z_gs_ms_ns.ForEach([&](auto& self, auto idx) { z_gs_ms_ns.ForEach([&](auto& self, auto idx) {
z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); z_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
q_gs_ms_ks.ForEach([&](auto& self, auto idx) { q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); q_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
k_gs_ns_ks.ForEach([&](auto& self, auto idx) { k_g_n_k.ForEach([&](auto& self, auto idx) {
k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = k_gs_ns_ks(g0, g1kv, idx[1], idx[2]);
}); });
v_gs_os_ns.ForEach([&](auto& self, auto idx) { v_g_n_o.ForEach([&](auto& self, auto idx) {
v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = v_gs_os_ns(g0, g1kv, idx[2], idx[1]);
}); });
d0_gs_ms_ns.ForEach([&](auto& self, auto idx) { d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); d0_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
// run fwd again for y, cause z_g_m_n update // run fwd again for y, cause z_g_m_n update
run_attention_fwd_host(q_g_m_k, run_attention_fwd_host(q_g_m_k,
...@@ -664,10 +700,10 @@ int run(int argc, char* argv[]) ...@@ -664,10 +700,10 @@ int run(int argc, char* argv[])
p_dropout_in_uint8_t, p_dropout_in_uint8_t,
rp_dropout); rp_dropout);
y_gs_ms_os.ForEach([&](auto& self, auto idx) { y_gs_ms_os.ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]); self(idx) = y_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]);
}); });
lse_gs_ms.ForEach( lse_gs_ms.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); }); [&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1Q + idx[1], idx[2]); });
y_device_buf.ToDevice(y_gs_ms_os.mData.data()); y_device_buf.ToDevice(y_gs_ms_os.mData.data());
lse_device_buf.ToDevice(lse_gs_ms.mData.data()); lse_device_buf.ToDevice(lse_gs_ms.mData.data());
...@@ -685,7 +721,7 @@ int run(int argc, char* argv[]) ...@@ -685,7 +721,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> ygrad_dot_y_g_m({BatchCount, M}); Tensor<InputDataType> ygrad_dot_y_g_m({BatchCount, M});
ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) { ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); ygrad_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
#if PRINT_HOST #if PRINT_HOST
...@@ -787,14 +823,18 @@ int run(int argc, char* argv[]) ...@@ -787,14 +823,18 @@ int run(int argc, char* argv[])
#endif #endif
Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<OutputDataType> qgrad_gs_ms_ks_host_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(kgrad_gs_ns_ks_lengths,
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_gs_os_ns_lengths, v_gs_os_ns_strides); kgrad_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides);
Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_host_result(d0_gs_ms_ns_lengths, Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_host_result(d0_gs_ms_ns_lengths,
d0_gs_ms_ns_strides); d0_gs_ms_ns_strides);
Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<OutputDataType> qgrad_gs_ms_ks_device_result(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(kgrad_gs_ns_ks_lengths,
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_gs_os_ns_lengths, v_gs_os_ns_strides); kgrad_gs_ns_ks_strides);
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(vgrad_gs_os_ns_lengths,
vgrad_gs_os_ns_strides);
Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_device_result(d0_gs_ms_ns_lengths, Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_device_result(d0_gs_ms_ns_lengths,
d0_gs_ms_ns_strides); d0_gs_ms_ns_strides);
...@@ -805,35 +845,35 @@ int run(int argc, char* argv[]) ...@@ -805,35 +845,35 @@ int run(int argc, char* argv[])
// permute // permute
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) { qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = qgrad_g_m_k(g, idx[2], idx[3]); self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
}); });
kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) { kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = kgrad_g_n_k(g, idx[2], idx[3]); self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
}); });
vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) { vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]); self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
}); });
d0grad_gs_ms_ns_host_result.ForEach([&](auto& self, auto idx) { d0grad_gs_ms_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = sgrad_g_m_n(g, idx[2], idx[3]); self(idx) = sgrad_g_m_n(g, idx[2], idx[3]);
}); });
......
...@@ -14,11 +14,12 @@ int run(int argc, char* argv[]) ...@@ -14,11 +14,12 @@ int run(int argc, char* argv[])
ck::index_t K = DIM; ck::index_t K = DIM;
ck::index_t O = DIM; ck::index_t O = DIM;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape // Output shape C[G0, M, G1Q, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) // C_g0_g1q_m_o = reshape(C_g_m_o, [g0, g1q, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3]) // C_g0_m_g1q_o = permute(C_g0_g1q_m_o, [0, 2, 1, 3])
ck::index_t G0 = 7; ck::index_t G0 = 7;
ck::index_t G1 = 13; ck::index_t G1Q = 12; // h_q
ck::index_t G1KV = 12; // h_kv
bool input_permute = false; bool input_permute = false;
bool output_permute = true; bool output_permute = true;
...@@ -37,32 +38,33 @@ int run(int argc, char* argv[]) ...@@ -37,32 +38,33 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 13) else if(argc == 14)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); M = std::stoi(argv[4]);
N = std::stoi(argv[5]); N = std::stoi(argv[5]);
K = std::stoi(argv[6]); K = std::stoi(argv[6]);
O = std::stoi(argv[7]); O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]); G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]); G1Q = std::stoi(argv[9]);
G1KV = std::stoi(argv[10]);
p_drop = std::stof(argv[10]); p_drop = std::stof(argv[11]);
input_permute = std::stoi(argv[11]); input_permute = std::stoi(argv[12]);
output_permute = std::stoi(argv[12]); output_permute = std::stoi(argv[13]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n"); printf("arg4 to 10: M, N, K, O, G0, G1Q, G1KV\n");
printf("arg10: scale (alpha)\n"); printf("arg11: p_drop\n");
printf("arg11 to 12: input / output permute\n"); printf("arg12 to 13: input / output permute\n");
exit(0); exit(0);
} }
...@@ -71,45 +73,45 @@ int run(int argc, char* argv[]) ...@@ -71,45 +73,45 @@ int run(int argc, char* argv[])
float rp_dropout = 1.0 / p_dropout; float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K); float alpha = 1.f / std::sqrt(K);
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides = std::vector<ck::index_t> a_gs_ms_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // A layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // A layout [G0, G1Q, M, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K}; std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides = std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K] ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // B0 layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] : std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // B0 layout [G0, G1KV, N, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N}; std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides = std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O] ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // B1 layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] : std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // B1 layout [G0, G1KV, N, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O}; std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides = std::vector<ck::index_t> c_gs_ms_os_strides =
output_permute output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // C layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // C layout [G0, G1Q, M, O]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M, N}; std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> d_gs_ms_ns_strides = std::vector<ck::index_t> d_gs_ms_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // D layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // D layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // D layout [G0, G1, M, N] : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // D layout [G0, G1Q, M, N]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N}; std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides = std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N] : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M}; std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides = std::vector<ck::index_t> lse_gs_ms_strides =
std::vector<ck::index_t>{G1 * M, M, 1}; // LSE layout [G0, G1, M] std::vector<ck::index_t>{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
...@@ -224,7 +226,7 @@ int run(int argc, char* argv[]) ...@@ -224,7 +226,7 @@ int run(int argc, char* argv[])
return 0; return 0;
} }
ck::index_t BatchCount = G0 * G1; ck::index_t BatchCount = G0 * G1Q;
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
...@@ -312,29 +314,37 @@ int run(int argc, char* argv[]) ...@@ -312,29 +314,37 @@ int run(int argc, char* argv[])
Tensor<B1DataType> b1_g_n_o({BatchCount, N, O}); Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
Tensor<AccDataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0 Tensor<AccDataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax
Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N}); Tensor<ADataType> a1_g_m_n_drop({BatchCount, M, N});
Tensor<LSEDataType> lse_g_m_host_result( Tensor<LSEDataType> lse_g_m_host_result(
{BatchCount, M}); // scratch object after max + ln(sum) {BatchCount, M}); // scratch object after max + ln(sum)
Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N}); Tensor<Acc0BiasDataType> d_g_m_n({BatchCount, M, N});
Tensor<ZDataType> z_g_m_n({G0 * G1, M, N}); Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1 Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
// permute // permute
a_gs_ms_ks.ForEach([&](auto& self, auto idx) { a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); a_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { b0_g_k_n.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = b0_gs_ns_ks(g0, g1kv, idx[2], idx[1]);
}); });
b1_gs_os_ns.ForEach([&](auto& self, auto idx) { b1_g_n_o.ForEach([&](auto& self, auto idx) {
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = b1_gs_os_ns(g0, g1kv, idx[2], idx[1]);
}); });
d_gs_ms_ns.ForEach([&](auto& self, auto idx) { d_gs_ms_ns.ForEach([&](auto& self, auto idx) {
d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); d_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
z_gs_ms_ns.ForEach([&](auto& self, auto idx) { z_gs_ms_ns.ForEach([&](auto& self, auto idx) {
z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); z_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
// gemm 0 // gemm 0
...@@ -384,18 +394,18 @@ int run(int argc, char* argv[]) ...@@ -384,18 +394,18 @@ int run(int argc, char* argv[])
// permute // permute
c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) { c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]); self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
}); });
lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) { lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = lse_g_m_host_result(g, idx[2]); self(idx) = lse_g_m_host_result(g, idx[2]);
}); });
......
...@@ -11,6 +11,7 @@ int run(int argc, char* argv[]) ...@@ -11,6 +11,7 @@ int run(int argc, char* argv[])
bool output_permute = true; bool output_permute = true;
float p_drop = 0.2; float p_drop = 0.2;
int h_ratio = 1; // G1Q / G1KV
const unsigned long long seed = 1; const unsigned long long seed = 1;
const unsigned long long offset = 0; const unsigned long long offset = 0;
...@@ -24,22 +25,25 @@ int run(int argc, char* argv[]) ...@@ -24,22 +25,25 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 7) else if(argc == 8)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
p_drop = std::stoi(argv[4]); p_drop = std::stoi(argv[4]);
input_permute = std::stoi(argv[5]); h_ratio = std::stof(argv[5]);
output_permute = std::stoi(argv[6]); input_permute = std::stoi(argv[6]);
output_permute = std::stoi(argv[7]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 5: input / output permute\n"); printf("arg4: p_drop\n");
printf("arg5: h_ratio\n");
printf("arg6 to 7: input / output permute\n");
exit(0); exit(0);
} }
...@@ -61,7 +65,7 @@ int run(int argc, char* argv[]) ...@@ -61,7 +65,7 @@ int run(int argc, char* argv[])
std::vector<void*> p_z; // for result verification std::vector<void*> p_z; // for result verification
std::vector<void*> p_z_nullptr; // for time test std::vector<void*> p_z_nullptr; // for time test
std::vector<void*> p_lse; std::vector<void*> p_lse;
std::vector<std::vector<int>> g0_g1_m_n_k_o; std::vector<std::vector<int>> g0_g1q_m_n_k_o;
std::vector<Tensor<ADataType>> a_tensors; std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<B0DataType>> b0_tensors; std::vector<Tensor<B0DataType>> b0_tensors;
...@@ -86,54 +90,57 @@ int run(int argc, char* argv[]) ...@@ -86,54 +90,57 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
int M = 128 * (rand() % 8) + (rand() % 128); int M = 128 * (rand() % 8) + (rand() % 128);
int N = 128 * (rand() % 8) + (rand() % 128); int N = 128 * (rand() % 8) + (rand() % 128);
int K = DIM; int K = DIM;
int O = DIM; int O = DIM;
int G0 = rand() % 3 + 1; int G0 = rand() % 3 + 1;
int G1 = rand() % 5 + 1; int G1KV = rand() % 5 + 1;
int G1Q = G1KV * h_ratio;
g0_g1_m_n_k_o.push_back({G0, G1, M, N, K, O}); g0_g1q_m_n_k_o.push_back({G0, G1Q, M, N, K, O});
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides = std::vector<ck::index_t> a_gs_ms_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // A layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // A layout [G0, G1Q, M, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K}; std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides = std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K] ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1}
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] // B0 layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // B0 layout [G0, G1KV, N, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N}; std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides = std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O] ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O}
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] // B1 layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // B1 layout [G0, G1KV, N, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O}; std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides = std::vector<ck::index_t> c_gs_ms_os_strides =
output_permute output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // C layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // C layout [G0, G1Q, M, O]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M, N}; std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> d_gs_ms_ns_strides = std::vector<ck::index_t> d_gs_ms_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // D layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // D layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // D layout [G0, G1, M, N] : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // D layout [G0, G1Q, M, N]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N}; std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides = std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N] ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N] : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M}; std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides = std::vector<ck::index_t> lse_gs_ms_strides =
std::vector<ck::index_t>{G1 * M, M, 1}; // LSE layout [G0, G1, M] std::vector<ck::index_t>{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
problem_descs.push_back({a_gs_ms_ks_lengths, problem_descs.push_back({a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
...@@ -161,7 +168,7 @@ int run(int argc, char* argv[]) ...@@ -161,7 +168,7 @@ int run(int argc, char* argv[])
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides); Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
int Batch = G0 * G1; int Batch = G0 * G1Q;
flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch; flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
num_byte += num_byte +=
(sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O + (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
...@@ -351,12 +358,12 @@ int run(int argc, char* argv[]) ...@@ -351,12 +358,12 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++) for(std::size_t i = 0; i < group_count; i++)
{ {
const int& G0 = g0_g1_m_n_k_o[i][0]; const int& G0 = g0_g1q_m_n_k_o[i][0];
const int& G1 = g0_g1_m_n_k_o[i][1]; const int& G1Q = g0_g1q_m_n_k_o[i][1];
const int& M = g0_g1_m_n_k_o[i][2]; const int& M = g0_g1q_m_n_k_o[i][2];
const int& N = g0_g1_m_n_k_o[i][3]; const int& N = g0_g1q_m_n_k_o[i][3];
const int& K = g0_g1_m_n_k_o[i][4]; const int& K = g0_g1q_m_n_k_o[i][4];
const int& O = g0_g1_m_n_k_o[i][5]; const int& O = g0_g1q_m_n_k_o[i][5];
const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths; const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths;
const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides; const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides;
...@@ -378,36 +385,43 @@ int run(int argc, char* argv[]) ...@@ -378,36 +385,43 @@ int run(int argc, char* argv[])
z_gs_ms_ns_device_buf.FromDevice(z_gs_ms_ns_device_result.mData.data()); z_gs_ms_ns_device_buf.FromDevice(z_gs_ms_ns_device_result.mData.data());
lse_gs_ms_device_buf.FromDevice(lse_gs_ms_device_result.mData.data()); lse_gs_ms_device_buf.FromDevice(lse_gs_ms_device_result.mData.data());
Tensor<ADataType> a_g_m_k({G0 * G1, M, K}); Tensor<ADataType> a_g_m_k({G0 * G1Q, M, K});
Tensor<B0DataType> b0_g_k_n({G0 * G1, K, N}); Tensor<B0DataType> b0_g_k_n({G0 * G1Q, K, N});
Tensor<B1DataType> b1_g_n_o({G0 * G1, N, O}); Tensor<B1DataType> b1_g_n_o({G0 * G1Q, N, O});
Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0 Tensor<AccDataType> acc0_g_m_n({G0 * G1Q, M, N}); // scratch object after gemm0
Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N}); Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1Q, M, N});
Tensor<ADataType> a1_g_m_n({G0 * G1, M, N}); // scratch object after softmax Tensor<ADataType> a1_g_m_n({G0 * G1Q, M, N}); // scratch object after softmax
Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N}); // scratch object after softmax Tensor<ADataType> a1_g_m_n_drop({G0 * G1Q, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1 Tensor<CDataType> c_g_m_o_host_result({G0 * G1Q, M, O}); // scratch object after gemm1
Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<ZDataType> z_g_m_n({G0 * G1, M, N}); Tensor<ZDataType> z_g_m_n({G0 * G1Q, M, N});
Tensor<LSEDataType> lse_g_m_host_result({G0 * G1, M}); // scratch object after gemm1 Tensor<LSEDataType> lse_g_m_host_result({G0 * G1Q, M}); // scratch object after gemm1
Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides); Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
// permute // permute
a_gs_ms_ks.ForEach([&](auto& self, auto idx) { a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); a_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { b0_g_k_n.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); const size_t& g0 = idx[0] / G1Q;
}); const size_t& g1q = idx[0] % G1Q;
b1_gs_os_ns.ForEach([&](auto& self, auto idx) { const size_t& g1kv = g1q / h_ratio;
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
self(idx) = b0_gs_ns_ks(g0, g1kv, idx[2], idx[1]);
}); });
b1_g_n_o.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / h_ratio;
self(idx) = b1_gs_os_ns(g0, g1kv, idx[2], idx[1]);
});
d_gs_ms_ns.ForEach([&](auto& self, auto idx) { d_gs_ms_ns.ForEach([&](auto& self, auto idx) {
d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); d_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
z_gs_ms_ns_device_result.ForEach([&](auto& self, auto idx) { z_gs_ms_ns_device_result.ForEach([&](auto& self, auto idx) {
z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); z_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
}); });
// gemm 0 // gemm 0
...@@ -461,18 +475,18 @@ int run(int argc, char* argv[]) ...@@ -461,18 +475,18 @@ int run(int argc, char* argv[])
// permute // permute
c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) { c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]); self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
}); });
lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) { lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * G1Q + g1q;
self(idx) = lse_g_m_host_result(g, idx[2]); self(idx) = lse_g_m_host_result(g, idx[2]);
}); });
......
...@@ -132,14 +132,17 @@ __global__ void ...@@ -132,14 +132,17 @@ __global__ void
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const BGridDesc_BK0_N_BK1 bgrad_grid_desc_bk0_n_bk1,
const D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3, const D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1 b1grad_grid_desc_bk0_n_bk1,
const LSEGridDescriptor_M lse_grid_desc_m, const LSEGridDescriptor_M lse_grid_desc_m,
const YGradGridDesc_O0_M_O1 ygrad_grid_desc_o0_m_o1, const YGradGridDesc_O0_M_O1 ygrad_grid_desc_o0_m_o1,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const index_t batch_count, const index_t batch_count,
const index_t h_ratio,
const index_t nblock, const index_t nblock,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask, const C0MatrixMask c0_matrix_mask,
...@@ -155,21 +158,26 @@ __global__ void ...@@ -155,21 +158,26 @@ __global__ void
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const index_t gkv_idx = __builtin_amdgcn_readfirstlane(g_idx / h_ratio);
// NOTE: assumes QKVY has the same layout as dQ/dK/dV/dY therefore being able to reuse batch // NOTE: assumes QKVY has the same layout as dQ/dK/dV/dY therefore being able to reuse batch
// offsets // offsets
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(gkv_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(gkv_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
const long_index_t bgrad_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBGradBasePtr(g_idx)));
const long_index_t b1grad_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1GradBasePtr(g_idx)));
ck::philox ph(seed, 0, offset); ck::philox ph(seed, 0, offset);
ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset); ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);
...@@ -205,9 +213,9 @@ __global__ void ...@@ -205,9 +213,9 @@ __global__ void
p_d_grid + lse_batch_offset, p_d_grid + lse_batch_offset,
p_ygrad_grid + c_batch_offset, p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset, p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset, p_kgrad_grid + bgrad_batch_offset,
tmp_p_d0grad_grid, tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset, p_vgrad_grid + b1grad_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -216,9 +224,11 @@ __global__ void ...@@ -216,9 +224,11 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
bgrad_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
b1grad_grid_desc_bk0_n_bk1,
lse_grid_desc_m, lse_grid_desc_m,
ygrad_grid_desc_o0_m_o1, ygrad_grid_desc_o0_m_o1,
block_2_ctile_map, block_2_ctile_map,
...@@ -242,9 +252,9 @@ __global__ void ...@@ -242,9 +252,9 @@ __global__ void
p_d_grid + lse_batch_offset, p_d_grid + lse_batch_offset,
p_ygrad_grid + c_batch_offset, p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset, p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset, p_kgrad_grid + bgrad_batch_offset,
tmp_p_d0grad_grid, tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset, p_vgrad_grid + b1grad_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -253,9 +263,11 @@ __global__ void ...@@ -253,9 +263,11 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
bgrad_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
b1grad_grid_desc_bk0_n_bk1,
lse_grid_desc_m, lse_grid_desc_m,
ygrad_grid_desc_o0_m_o1, ygrad_grid_desc_o0_m_o1,
block_2_ctile_map, block_2_ctile_map,
...@@ -286,13 +298,16 @@ __global__ void ...@@ -286,13 +298,16 @@ __global__ void
ignore = c_element_op; ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = bgrad_grid_desc_bk0_n_bk1;
ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3; ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3;
ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3; ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
ignore = b1_grid_desc_bk0_n_bk1; ignore = b1_grid_desc_bk0_n_bk1;
ignore = b1grad_grid_desc_bk0_n_bk1;
ignore = lse_grid_desc_m; ignore = lse_grid_desc_m;
ignore = ygrad_grid_desc_o0_m_o1; ignore = ygrad_grid_desc_o0_m_o1;
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
ignore = batch_count; ignore = batch_count;
ignore = h_ratio;
ignore = nblock; ignore = nblock;
ignore = compute_base_ptr_of_batch; ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask; ignore = c0_matrix_mask;
...@@ -695,6 +710,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -695,6 +710,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const ZGridDesc_G_M_N& z_grid_desc_g_m_n, const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n, const CGridDesc_G_M_N& c_grid_desc_g_m_n,
const BGridDesc_G_N_K& bgrad_grid_desc_g_n_k,
const B1GridDesc_G_N_K& b1grad_grid_desc_g_n_k,
index_t BatchStrideLSE) index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k), b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
...@@ -702,6 +719,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -702,6 +719,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
z_grid_desc_g_m_n_(z_grid_desc_g_m_n), z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n), c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
bgrad_grid_desc_g_n_k_(bgrad_grid_desc_g_n_k),
b1grad_grid_desc_g_n_k_(b1grad_grid_desc_g_n_k),
BatchStrideLSE_(BatchStrideLSE) BatchStrideLSE_(BatchStrideLSE)
{ {
} }
...@@ -741,6 +760,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -741,6 +760,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
return g_idx * static_cast<long_index_t>(BatchStrideLSE_); return g_idx * static_cast<long_index_t>(BatchStrideLSE_);
} }
__host__ __device__ constexpr long_index_t GetBGradBasePtr(index_t g_idx) const
{
return bgrad_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetB1GradBasePtr(index_t g_idx) const
{
return b1grad_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
private: private:
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
...@@ -748,6 +777,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -748,6 +777,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
BGridDesc_G_N_K bgrad_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1grad_grid_desc_g_n_k_;
index_t BatchStrideLSE_; index_t BatchStrideLSE_;
}; };
...@@ -858,6 +889,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -858,6 +889,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_strides,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_lengths,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_strides,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
...@@ -888,9 +923,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -888,9 +923,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
bgrad_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
bgrad_gs_ns_ks_lengths, bgrad_gs_ns_ks_strides)},
z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeVGridDescriptor_O0_N_O1( b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeVGridDescriptor_O0_N_O1(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
b1grad_grid_desc_bk0_n_bk1_{DeviceOp::MakeVGridDescriptor_O0_N_O1(
b1grad_gs_gemm1ns_gemm1ks_lengths, b1grad_gs_gemm1ns_gemm1ks_strides)},
y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
d_y_grid_desc_m_o_{DTransform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, d_y_grid_desc_m_o_{DTransform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
...@@ -912,6 +951,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -912,6 +951,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{ z_grid_desc_g_m_n_{
Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
bgrad_grid_desc_g_n_k_{Transform::MakeB0GridDescriptor_G_N_K(bgrad_gs_ns_ks_lengths,
bgrad_gs_ns_ks_strides)},
b1grad_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
b1grad_gs_gemm1ns_gemm1ks_lengths, b1grad_gs_gemm1ns_gemm1ks_strides)},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
d_block_2_ctile_map_{ d_block_2_ctile_map_{
GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(d_y_grid_desc_m_o_)}, GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(d_y_grid_desc_m_o_)},
...@@ -935,6 +978,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -935,6 +978,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1], c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]}, c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
h_ratio_{c_grid_desc_g_m_n_.GetLength(I0) / b_grid_desc_g_n_k_.GetLength(I0)},
p_drop_{p_drop} p_drop_{p_drop}
{ {
// TODO: implement bias addition // TODO: implement bias addition
...@@ -964,6 +1008,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -964,6 +1008,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
z_grid_desc_g_m_n_, z_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_, b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_, c_grid_desc_g_m_n_,
bgrad_grid_desc_g_n_k_,
b1grad_grid_desc_g_n_k_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())); type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize()));
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
...@@ -990,7 +1036,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -990,7 +1036,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
<< b_grid_desc_g_n_k_.GetLength(I1) << ", " << b_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b_grid_desc_g_n_k_.GetLength(I2) << '\n'; << b_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b_grid_desc_g_n_k_.Print(); // b_grid_desc_g_n_k_.Print();
std::cout << "b1_grid_desc_g_o_n_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", " std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I1) << ", " << b1_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I2) << '\n'; << b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b1_grid_desc_g_n_k_.Print(); // b1_grid_desc_g_n_k_.Print();
...@@ -1003,6 +1049,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1003,6 +1049,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
std::cout << "ygrad_grid_desc_o0_m_o1_: " << ygrad_grid_desc_o0_m_o1_.GetLength(I0) std::cout << "ygrad_grid_desc_o0_m_o1_: " << ygrad_grid_desc_o0_m_o1_.GetLength(I0)
<< ", " << ygrad_grid_desc_o0_m_o1_.GetLength(I1) << ", " << ", " << ygrad_grid_desc_o0_m_o1_.GetLength(I1) << ", "
<< ygrad_grid_desc_o0_m_o1_.GetLength(I2) << '\n'; << ygrad_grid_desc_o0_m_o1_.GetLength(I2) << '\n';
std::cout << "d0_grid_desc_g_m_n_: " << d0_grid_desc_g_m_n_.GetLength(I0) << ", "
<< d0_grid_desc_g_m_n_.GetLength(I1) << ", "
<< d0_grid_desc_g_m_n_.GetLength(I2) << '\n';
std::cout << "bgrad_grid_desc_g_n_k_: " << bgrad_grid_desc_g_n_k_.GetLength(I0) << ", "
<< bgrad_grid_desc_g_n_k_.GetLength(I1) << ", "
<< bgrad_grid_desc_g_n_k_.GetLength(I2) << '\n';
// bgrad_grid_desc_g_n_k_.Print();
std::cout << "b1grad_grid_desc_g_n_k_: " << b1grad_grid_desc_g_n_k_.GetLength(I0)
<< ", " << b1grad_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b1grad_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b1grad_grid_desc_g_n_k_.Print();
} }
// pointers // pointers
...@@ -1023,9 +1080,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1023,9 +1080,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// tensor descriptor // tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
BGridDesc_BK0_N_BK1 bgrad_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_; typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1grad_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
DYGridDesc_M_O d_y_grid_desc_m_o_; DYGridDesc_M_O d_y_grid_desc_m_o_;
LSEGridDesc_M lse_grid_desc_m_; LSEGridDesc_M lse_grid_desc_m_;
...@@ -1040,6 +1099,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1040,6 +1099,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
BGridDesc_G_N_K bgrad_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1grad_grid_desc_g_n_k_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_; c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
...@@ -1068,6 +1129,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1068,6 +1129,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
std::vector<index_t> c_mz_gemm1nz_strides_; std::vector<index_t> c_mz_gemm1nz_strides_;
index_t batch_count_; index_t batch_count_;
index_t h_ratio_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_drop_; float p_drop_;
...@@ -1189,13 +1251,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1189,13 +1251,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
arg.c_element_op_, arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.bgrad_grid_desc_bk0_n_bk1_,
arg.d0_grid_desc_m0_n0_m1_m2_n1_m3_, arg.d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.b1grad_grid_desc_bk0_n_bk1_,
arg.lse_grid_desc_m_, arg.lse_grid_desc_m_,
arg.ygrad_grid_desc_o0_m_o1_, arg.ygrad_grid_desc_o0_m_o1_,
arg.block_2_ctile_map_, arg.block_2_ctile_map_,
arg.batch_count_, arg.batch_count_,
arg.h_ratio_,
arg.block_2_ctile_map_.CalculateGridSize(arg.k_grid_desc_n_k_), arg.block_2_ctile_map_.CalculateGridSize(arg.k_grid_desc_n_k_),
arg.compute_base_ptr_of_batch_, arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_, arg.c0_matrix_mask_,
...@@ -1249,13 +1314,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1249,13 +1314,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// Check if C permute dimension matches GEMM + GEMM shape // Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t b_g = arg.b_grid_desc_g_n_k_.GetLength(I0);
const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0); const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0);
const index_t c_gemm1n = arg.y_grid_desc_m_o_.GetLength(I1); const index_t c_gemm1n = arg.y_grid_desc_m_o_.GetLength(I1);
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = const index_t b1_gemm1n =
arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2); arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n)) if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0))
{ {
return false; return false;
} }
...@@ -1355,6 +1421,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1355,6 +1421,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_strides,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_lengths,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_strides,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
...@@ -1395,6 +1465,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1395,6 +1465,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths, lse_gs_ms_lengths,
bgrad_gs_ns_ks_lengths,
bgrad_gs_ns_ks_strides,
b1grad_gs_gemm1ns_gemm1ks_lengths,
b1grad_gs_gemm1ns_gemm1ks_strides,
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides, acc0_bias_gs_ms_ns_strides,
acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
...@@ -1439,6 +1513,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1439,6 +1513,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_strides,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_lengths,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_strides,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
...@@ -1467,8 +1545,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1467,8 +1545,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static_cast<OutputDataType*>(p_vgrad_grid), static_cast<OutputDataType*>(p_vgrad_grid),
static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
static_cast<D0DataType*>(p_d0grad_grid), static_cast<const D0DataType*>(p_d0grad_grid),
static_cast<D1DataType*>(p_d1grad_grid), static_cast<const D1DataType*>(p_d1grad_grid),
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b_gs_ns_ks_lengths, b_gs_ns_ks_lengths,
...@@ -1480,6 +1558,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -1480,6 +1558,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths, lse_gs_ms_lengths,
bgrad_gs_ns_ks_lengths,
bgrad_gs_ns_ks_strides,
b1grad_gs_gemm1ns_gemm1ks_lengths,
b1grad_gs_gemm1ns_gemm1ks_strides,
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides, acc0_bias_gs_ms_ns_strides,
acc1_bias_gs_ms_gemm1ns_lengths, acc1_bias_gs_ms_gemm1ns_lengths,
......
...@@ -132,14 +132,17 @@ __global__ void ...@@ -132,14 +132,17 @@ __global__ void
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const BGridDesc_BK0_N_BK1 bgrad_grid_desc_bk0_n_bk1,
const D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3, const D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1 b1grad_grid_desc_bk0_n_bk1,
const LSEGridDescriptor_M lse_grid_desc_m, const LSEGridDescriptor_M lse_grid_desc_m,
const YGradGridDesc_M0_O_M1 ygrad_grid_desc_m0_o_m1, const YGradGridDesc_M0_O_M1 ygrad_grid_desc_m0_o_m1,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const index_t batch_count, const index_t batch_count,
const index_t h_ratio,
const index_t nblock, const index_t nblock,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask, const C0MatrixMask c0_matrix_mask,
...@@ -155,21 +158,26 @@ __global__ void ...@@ -155,21 +158,26 @@ __global__ void
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const index_t gkv_idx = __builtin_amdgcn_readfirstlane(g_idx / h_ratio);
// NOTE: assumes QKVY has the same layout as dQ/dK/dV/dY therefore being able to reuse batch // NOTE: assumes QKVY has the same layout as dQ/dK/dV/dY therefore being able to reuse batch
// offsets // offsets
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(gkv_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(gkv_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
const long_index_t bgrad_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBGradBasePtr(g_idx)));
const long_index_t b1grad_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1GradBasePtr(g_idx)));
ck::philox ph(seed, 0, offset); ck::philox ph(seed, 0, offset);
ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset); ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);
...@@ -206,9 +214,9 @@ __global__ void ...@@ -206,9 +214,9 @@ __global__ void
p_d_grid + lse_batch_offset, p_d_grid + lse_batch_offset,
p_ygrad_grid + c_batch_offset, p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset, p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset, p_kgrad_grid + bgrad_batch_offset,
tmp_p_d0grad_grid, tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset, p_vgrad_grid + b1grad_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -217,9 +225,11 @@ __global__ void ...@@ -217,9 +225,11 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
bgrad_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
b1grad_grid_desc_bk0_n_bk1,
lse_grid_desc_m, lse_grid_desc_m,
ygrad_grid_desc_m0_o_m1, ygrad_grid_desc_m0_o_m1,
block_2_ctile_map, block_2_ctile_map,
...@@ -243,9 +253,9 @@ __global__ void ...@@ -243,9 +253,9 @@ __global__ void
p_d_grid + lse_batch_offset, p_d_grid + lse_batch_offset,
p_ygrad_grid + c_batch_offset, p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset, p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset, p_kgrad_grid + bgrad_batch_offset,
tmp_p_d0grad_grid, tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset, p_vgrad_grid + b1grad_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -254,9 +264,11 @@ __global__ void ...@@ -254,9 +264,11 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
bgrad_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
b1grad_grid_desc_bk0_n_bk1,
lse_grid_desc_m, lse_grid_desc_m,
ygrad_grid_desc_m0_o_m1, ygrad_grid_desc_m0_o_m1,
block_2_ctile_map, block_2_ctile_map,
...@@ -287,13 +299,16 @@ __global__ void ...@@ -287,13 +299,16 @@ __global__ void
ignore = c_element_op; ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = bgrad_grid_desc_bk0_n_bk1;
ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3; ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3;
ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3; ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
ignore = b1_grid_desc_bk0_n_bk1; ignore = b1_grid_desc_bk0_n_bk1;
ignore = b1grad_grid_desc_bk0_n_bk1;
ignore = lse_grid_desc_m; ignore = lse_grid_desc_m;
ignore = ygrad_grid_desc_m0_o_m1; ignore = ygrad_grid_desc_m0_o_m1;
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
ignore = batch_count; ignore = batch_count;
ignore = h_ratio;
ignore = nblock; ignore = nblock;
ignore = compute_base_ptr_of_batch; ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask; ignore = c0_matrix_mask;
...@@ -704,6 +719,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -704,6 +719,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const ZGridDesc_G_M_N& z_grid_desc_g_m_n, const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n, const CGridDesc_G_M_N& c_grid_desc_g_m_n,
const BGridDesc_G_N_K& bgrad_grid_desc_g_n_k,
const B1GridDesc_G_N_K& b1grad_grid_desc_g_n_k,
index_t BatchStrideLSE) index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k), b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
...@@ -711,6 +728,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -711,6 +728,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
z_grid_desc_g_m_n_(z_grid_desc_g_m_n), z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n), c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
bgrad_grid_desc_g_n_k_(bgrad_grid_desc_g_n_k),
b1grad_grid_desc_g_n_k_(b1grad_grid_desc_g_n_k),
BatchStrideLSE_(BatchStrideLSE) BatchStrideLSE_(BatchStrideLSE)
{ {
} }
...@@ -729,6 +748,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -729,6 +748,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{ {
return d0_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return d0_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
{ {
return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
...@@ -749,6 +769,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -749,6 +769,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
return g_idx * static_cast<long_index_t>(BatchStrideLSE_); return g_idx * static_cast<long_index_t>(BatchStrideLSE_);
} }
__host__ __device__ constexpr long_index_t GetBGradBasePtr(index_t g_idx) const
{
return bgrad_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetB1GradBasePtr(index_t g_idx) const
{
return b1grad_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
private: private:
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
...@@ -756,6 +786,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -756,6 +786,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
BGridDesc_G_N_K bgrad_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1grad_grid_desc_g_n_k_;
index_t BatchStrideLSE_; index_t BatchStrideLSE_;
}; };
...@@ -874,6 +906,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -874,6 +906,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_strides,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_lengths,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_strides,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
...@@ -904,9 +940,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -904,9 +940,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
bgrad_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
bgrad_gs_ns_ks_lengths, bgrad_gs_ns_ks_strides)},
z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeVGridDescriptor_O0_N_O1( b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeVGridDescriptor_O0_N_O1(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
b1grad_grid_desc_bk0_n_bk1_{DeviceOp::MakeVGridDescriptor_O0_N_O1(
b1grad_gs_gemm1ns_gemm1ks_lengths, b1grad_gs_gemm1ns_gemm1ks_strides)},
y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
d_y_grid_desc_m_o_{DTransform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, d_y_grid_desc_m_o_{DTransform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
...@@ -927,6 +967,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -927,6 +967,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{ z_grid_desc_g_m_n_{
Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
bgrad_grid_desc_g_n_k_{Transform::MakeB0GridDescriptor_G_N_K(bgrad_gs_ns_ks_lengths,
bgrad_gs_ns_ks_strides)},
b1grad_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
b1grad_gs_gemm1ns_gemm1ks_lengths, b1grad_gs_gemm1ns_gemm1ks_strides)},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
d_block_2_ctile_map_{ d_block_2_ctile_map_{
GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(d_y_grid_desc_m_o_)}, GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(d_y_grid_desc_m_o_)},
...@@ -950,6 +994,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -950,6 +994,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1], c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]}, c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
h_ratio_{c_grid_desc_g_m_n_.GetLength(I0) / b_grid_desc_g_n_k_.GetLength(I0)},
p_drop_{p_drop} p_drop_{p_drop}
{ {
// TODO: implement bias addition // TODO: implement bias addition
...@@ -979,6 +1024,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -979,6 +1024,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
z_grid_desc_g_m_n_, z_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_, b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_, c_grid_desc_g_m_n_,
bgrad_grid_desc_g_n_k_,
b1grad_grid_desc_g_n_k_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())); type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize()));
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
...@@ -1005,7 +1052,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1005,7 +1052,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
<< b_grid_desc_g_n_k_.GetLength(I1) << ", " << b_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b_grid_desc_g_n_k_.GetLength(I2) << '\n'; << b_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b_grid_desc_g_n_k_.Print(); // b_grid_desc_g_n_k_.Print();
std::cout << "b1_grid_desc_g_o_n_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", " std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I1) << ", " << b1_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I2) << '\n'; << b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b1_grid_desc_g_n_k_.Print(); // b1_grid_desc_g_n_k_.Print();
...@@ -1018,6 +1065,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1018,6 +1065,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0) std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0)
<< ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", " << ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", "
<< ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n'; << ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n';
std::cout << "d0_grid_desc_g_m_n_: " << d0_grid_desc_g_m_n_.GetLength(I0) << ", "
<< d0_grid_desc_g_m_n_.GetLength(I1) << ", "
<< d0_grid_desc_g_m_n_.GetLength(I2) << '\n';
std::cout << "bgrad_grid_desc_g_n_k_: " << bgrad_grid_desc_g_n_k_.GetLength(I0) << ", "
<< bgrad_grid_desc_g_n_k_.GetLength(I1) << ", "
<< bgrad_grid_desc_g_n_k_.GetLength(I2) << '\n';
// bgrad_grid_desc_g_n_k_.Print();
std::cout << "b1grad_grid_desc_g_n_k_: " << b1grad_grid_desc_g_n_k_.GetLength(I0)
<< ", " << b1grad_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b1grad_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b1grad_grid_desc_g_n_k_.Print();
} }
// pointers // pointers
...@@ -1038,9 +1096,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1038,9 +1096,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// tensor descriptor // tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
BGridDesc_BK0_N_BK1 bgrad_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_; typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1grad_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
DYGridDesc_M_O d_y_grid_desc_m_o_; DYGridDesc_M_O d_y_grid_desc_m_o_;
LSEGridDesc_M lse_grid_desc_m_; LSEGridDesc_M lse_grid_desc_m_;
...@@ -1055,6 +1115,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1055,6 +1115,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
BGridDesc_G_N_K bgrad_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1grad_grid_desc_g_n_k_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_; c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
...@@ -1083,6 +1145,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1083,6 +1145,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
std::vector<index_t> c_mz_gemm1nz_strides_; std::vector<index_t> c_mz_gemm1nz_strides_;
index_t batch_count_; index_t batch_count_;
index_t h_ratio_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_drop_; float p_drop_;
...@@ -1208,13 +1271,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1208,13 +1271,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
arg.c_element_op_, arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.bgrad_grid_desc_bk0_n_bk1_,
arg.d0_grid_desc_m0_n0_m1_m2_n1_m3_, arg.d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.b1grad_grid_desc_bk0_n_bk1_,
arg.lse_grid_desc_m_, arg.lse_grid_desc_m_,
arg.ygrad_grid_desc_m0_o_m1_, arg.ygrad_grid_desc_m0_o_m1_,
arg.block_2_ctile_map_, arg.block_2_ctile_map_,
arg.batch_count_, arg.batch_count_,
arg.h_ratio_,
arg.block_2_ctile_map_.CalculateGridSize(arg.k_grid_desc_n_k_), arg.block_2_ctile_map_.CalculateGridSize(arg.k_grid_desc_n_k_),
arg.compute_base_ptr_of_batch_, arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_, arg.c0_matrix_mask_,
...@@ -1280,13 +1346,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1280,13 +1346,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// Check if C permute dimension matches GEMM + GEMM shape // Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t b_g = arg.b_grid_desc_g_n_k_.GetLength(I0);
const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0); const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0);
const index_t c_gemm1n = arg.y_grid_desc_m_o_.GetLength(I1); const index_t c_gemm1n = arg.y_grid_desc_m_o_.GetLength(I1);
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = const index_t b1_gemm1n =
arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2); arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n)) if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0))
{ {
return false; return false;
} }
...@@ -1390,6 +1457,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1390,6 +1457,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_strides,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_lengths,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_strides,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
...@@ -1430,6 +1501,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1430,6 +1501,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths, lse_gs_ms_lengths,
bgrad_gs_ns_ks_lengths,
bgrad_gs_ns_ks_strides,
b1grad_gs_gemm1ns_gemm1ks_lengths,
b1grad_gs_gemm1ns_gemm1ks_strides,
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides, acc0_bias_gs_ms_ns_strides,
acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
...@@ -1474,6 +1549,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1474,6 +1549,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_strides,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_lengths,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_strides,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
...@@ -1515,6 +1594,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -1515,6 +1594,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths, lse_gs_ms_lengths,
bgrad_gs_ns_ks_lengths,
bgrad_gs_ns_ks_strides,
b1grad_gs_gemm1ns_gemm1ks_lengths,
b1grad_gs_gemm1ns_gemm1ks_strides,
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides, acc0_bias_gs_ms_ns_strides,
acc1_bias_gs_ms_gemm1ns_lengths, acc1_bias_gs_ms_gemm1ns_lengths,
......
...@@ -74,16 +74,19 @@ __global__ void ...@@ -74,16 +74,19 @@ __global__ void
const CElementwiseOperation c_element_op, const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const BGridDesc_BK0_N_BK1 bgrad_grid_desc_bk0_n_bk1,
const D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3, const D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1 b1grad_grid_desc_bk0_n_bk1,
const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock const YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const LSEGridDescriptor_M lse_grid_desc_m, const LSEGridDescriptor_M lse_grid_desc_m,
const YGradGridDesc_O0_M_O1 ygrad_grid_desc_o0_m_o1, const YGradGridDesc_O0_M_O1 ygrad_grid_desc_o0_m_o1,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const index_t batch_count, const index_t batch_count,
const index_t h_ratio,
const index_t nblock, const index_t nblock,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask, const C0MatrixMask c0_matrix_mask,
...@@ -99,21 +102,26 @@ __global__ void ...@@ -99,21 +102,26 @@ __global__ void
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const index_t gkv_idx = __builtin_amdgcn_readfirstlane(g_idx / h_ratio);
// NOTE: assumes QKVY has the same layout as dQ/dK/dV/dY therefore being able to reuse batch // NOTE: assumes QKVY has the same layout as dQ/dK/dV/dY therefore being able to reuse batch
// offsets // offsets
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(gkv_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(gkv_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx))); static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
const long_index_t bgrad_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBGradBasePtr(g_idx)));
const long_index_t b1grad_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1GradBasePtr(g_idx)));
ck::philox ph(seed, 0, offset); ck::philox ph(seed, 0, offset);
ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset); ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);
...@@ -149,9 +157,9 @@ __global__ void ...@@ -149,9 +157,9 @@ __global__ void
p_lse_grid + lse_batch_offset, p_lse_grid + lse_batch_offset,
p_ygrad_grid + c_batch_offset, p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset, p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset, p_kgrad_grid + bgrad_batch_offset,
tmp_p_d0grad_grid, tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset, p_vgrad_grid + b1grad_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -160,9 +168,11 @@ __global__ void ...@@ -160,9 +168,11 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
bgrad_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
b1grad_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m, lse_grid_desc_m,
ygrad_grid_desc_o0_m_o1, ygrad_grid_desc_o0_m_o1,
...@@ -187,9 +197,9 @@ __global__ void ...@@ -187,9 +197,9 @@ __global__ void
p_lse_grid + lse_batch_offset, p_lse_grid + lse_batch_offset,
p_ygrad_grid + c_batch_offset, p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset, p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset, p_kgrad_grid + bgrad_batch_offset,
tmp_p_d0grad_grid, tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset, p_vgrad_grid + b1grad_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -198,9 +208,11 @@ __global__ void ...@@ -198,9 +208,11 @@ __global__ void
c_element_op, c_element_op,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
bgrad_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3, d0_grid_desc_m0_n0_m1_m2_n1_m3,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3, c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
b1grad_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
lse_grid_desc_m, lse_grid_desc_m,
ygrad_grid_desc_o0_m_o1, ygrad_grid_desc_o0_m_o1,
...@@ -232,14 +244,17 @@ __global__ void ...@@ -232,14 +244,17 @@ __global__ void
ignore = c_element_op; ignore = c_element_op;
ignore = a_grid_desc_ak0_m_ak1; ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1; ignore = b_grid_desc_bk0_n_bk1;
ignore = bgrad_grid_desc_bk0_n_bk1;
ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3; ignore = d0_grid_desc_m0_n0_m1_m2_n1_m3;
ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3; ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
ignore = b1_grid_desc_bk0_n_bk1; ignore = b1_grid_desc_bk0_n_bk1;
ignore = b1grad_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = lse_grid_desc_m; ignore = lse_grid_desc_m;
ignore = ygrad_grid_desc_o0_m_o1; ignore = ygrad_grid_desc_o0_m_o1;
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
ignore = batch_count; ignore = batch_count;
ignore = h_ratio;
ignore = nblock; ignore = nblock;
ignore = compute_base_ptr_of_batch; ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask; ignore = c0_matrix_mask;
...@@ -603,6 +618,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -603,6 +618,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const ZGridDesc_G_M_N& z_grid_desc_g_m_n, const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n, const CGridDesc_G_M_N& c_grid_desc_g_m_n,
const BGridDesc_G_N_K& bgrad_grid_desc_g_n_k,
const B1GridDesc_G_N_K& b1grad_grid_desc_g_n_k,
index_t BatchStrideLSE) index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k), b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
...@@ -610,6 +627,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -610,6 +627,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
z_grid_desc_g_m_n_(z_grid_desc_g_m_n), z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n), c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
bgrad_grid_desc_g_n_k_(bgrad_grid_desc_g_n_k),
b1grad_grid_desc_g_n_k_(b1grad_grid_desc_g_n_k),
BatchStrideLSE_(BatchStrideLSE) BatchStrideLSE_(BatchStrideLSE)
{ {
} }
...@@ -649,6 +668,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -649,6 +668,16 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return g_idx * static_cast<long_index_t>(BatchStrideLSE_); return g_idx * static_cast<long_index_t>(BatchStrideLSE_);
} }
__host__ __device__ constexpr long_index_t GetBGradBasePtr(index_t g_idx) const
{
return bgrad_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetB1GradBasePtr(index_t g_idx) const
{
return b1grad_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
private: private:
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
...@@ -656,6 +685,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -656,6 +685,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
BGridDesc_G_N_K bgrad_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1grad_grid_desc_g_n_k_;
index_t BatchStrideLSE_; index_t BatchStrideLSE_;
}; };
...@@ -755,6 +786,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -755,6 +786,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_strides,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_lengths,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_strides,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
...@@ -784,9 +819,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -784,9 +819,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
bgrad_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
bgrad_gs_ns_ks_lengths, bgrad_gs_ns_ks_strides)},
z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeVGridDescriptor_O0_N_O1( b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeVGridDescriptor_O0_N_O1(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
b1grad_grid_desc_bk0_n_bk1_{DeviceOp::MakeVGridDescriptor_O0_N_O1(
b1grad_gs_gemm1ns_gemm1ks_lengths, b1grad_gs_gemm1ns_gemm1ks_strides)},
y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])}, lse_grid_desc_m_{DeviceOp::MakeLSEGridDescriptor_M(lse_gs_ms_lengths[NumDimG])},
...@@ -805,6 +844,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -805,6 +844,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{ z_grid_desc_g_m_n_{
Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)}, Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
bgrad_grid_desc_g_n_k_{Transform::MakeB0GridDescriptor_G_N_K(bgrad_gs_ns_ks_lengths,
bgrad_gs_ns_ks_strides)},
b1grad_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K(
b1grad_gs_gemm1ns_gemm1ks_lengths, b1grad_gs_gemm1ns_gemm1ks_strides)},
y_grid_desc_mblock_mperblock_oblock_operblock_{}, y_grid_desc_mblock_mperblock_oblock_operblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)}, block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
...@@ -826,6 +869,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -826,6 +869,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1], c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]}, c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
h_ratio_{c_grid_desc_g_m_n_.GetLength(I0) / b_grid_desc_g_n_k_.GetLength(I0)},
p_drop_{p_drop} p_drop_{p_drop}
{ {
// TODO: implement bias addition // TODO: implement bias addition
...@@ -864,6 +908,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -864,6 +908,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
z_grid_desc_g_m_n_, z_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_, b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_, c_grid_desc_g_m_n_,
bgrad_grid_desc_g_n_k_,
b1grad_grid_desc_g_n_k_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize())); type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize()));
seed_ = std::get<0>(seeds); seed_ = std::get<0>(seeds);
...@@ -887,7 +933,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -887,7 +933,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<< b_grid_desc_g_n_k_.GetLength(I1) << ", " << b_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b_grid_desc_g_n_k_.GetLength(I2) << '\n'; << b_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b_grid_desc_g_n_k_.Print(); // b_grid_desc_g_n_k_.Print();
std::cout << "b1_grid_desc_g_o_n_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", " std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I1) << ", " << b1_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b1_grid_desc_g_n_k_.GetLength(I2) << '\n'; << b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b1_grid_desc_g_n_k_.Print(); // b1_grid_desc_g_n_k_.Print();
...@@ -900,6 +946,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -900,6 +946,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::cout << "ygrad_grid_desc_o0_m_o1_: " << ygrad_grid_desc_o0_m_o1_.GetLength(I0) std::cout << "ygrad_grid_desc_o0_m_o1_: " << ygrad_grid_desc_o0_m_o1_.GetLength(I0)
<< ", " << ygrad_grid_desc_o0_m_o1_.GetLength(I1) << ", " << ", " << ygrad_grid_desc_o0_m_o1_.GetLength(I1) << ", "
<< ygrad_grid_desc_o0_m_o1_.GetLength(I2) << '\n'; << ygrad_grid_desc_o0_m_o1_.GetLength(I2) << '\n';
std::cout << "d0_grid_desc_g_m_n_: " << d0_grid_desc_g_m_n_.GetLength(I0) << ", "
<< d0_grid_desc_g_m_n_.GetLength(I1) << ", "
<< d0_grid_desc_g_m_n_.GetLength(I2) << '\n';
std::cout << "bgrad_grid_desc_g_n_k_: " << bgrad_grid_desc_g_n_k_.GetLength(I0) << ", "
<< bgrad_grid_desc_g_n_k_.GetLength(I1) << ", "
<< bgrad_grid_desc_g_n_k_.GetLength(I2) << '\n';
// bgrad_grid_desc_g_n_k_.Print();
std::cout << "b1grad_grid_desc_g_n_k_: " << b1grad_grid_desc_g_n_k_.GetLength(I0)
<< ", " << b1grad_grid_desc_g_n_k_.GetLength(I1) << ", "
<< b1grad_grid_desc_g_n_k_.GetLength(I2) << '\n';
// b1grad_grid_desc_g_n_k_.Print();
} }
// pointers // pointers
...@@ -919,9 +976,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -919,9 +976,11 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// tensor descriptor // tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
BGridDesc_BK0_N_BK1 bgrad_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_; typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1grad_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
LSEGridDesc_M lse_grid_desc_m_; LSEGridDesc_M lse_grid_desc_m_;
KGridDesc_N_K k_grid_desc_n_k_; KGridDesc_N_K k_grid_desc_n_k_;
...@@ -934,6 +993,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -934,6 +993,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
BGridDesc_G_N_K bgrad_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1grad_grid_desc_g_n_k_;
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_; y_grid_desc_mblock_mperblock_oblock_operblock_;
...@@ -961,6 +1022,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -961,6 +1022,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<index_t> c_mz_gemm1nz_strides_; std::vector<index_t> c_mz_gemm1nz_strides_;
index_t batch_count_; index_t batch_count_;
index_t h_ratio_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
float p_drop_; float p_drop_;
...@@ -1047,14 +1109,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1047,14 +1109,17 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
arg.c_element_op_, arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.bgrad_grid_desc_bk0_n_bk1_,
arg.d0_grid_desc_m0_n0_m1_m2_n1_m3_, arg.d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg.b1_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.b1grad_grid_desc_bk0_n_bk1_,
arg.y_grid_desc_mblock_mperblock_oblock_operblock_, arg.y_grid_desc_mblock_mperblock_oblock_operblock_,
arg.lse_grid_desc_m_, arg.lse_grid_desc_m_,
arg.ygrad_grid_desc_o0_m_o1_, arg.ygrad_grid_desc_o0_m_o1_,
arg.block_2_ctile_map_, arg.block_2_ctile_map_,
arg.batch_count_, arg.batch_count_,
arg.h_ratio_,
arg.block_2_ctile_map_.CalculateGridSize(arg.k_grid_desc_n_k_), arg.block_2_ctile_map_.CalculateGridSize(arg.k_grid_desc_n_k_),
arg.compute_base_ptr_of_batch_, arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_, arg.c0_matrix_mask_,
...@@ -1108,13 +1173,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1108,13 +1173,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// Check if C permute dimension matches GEMM + GEMM shape // Check if C permute dimension matches GEMM + GEMM shape
const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
const index_t b_g = arg.b_grid_desc_g_n_k_.GetLength(I0);
const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0); const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0);
const index_t c_gemm1n = arg.y_grid_desc_m_o_.GetLength(I1); const index_t c_gemm1n = arg.y_grid_desc_m_o_.GetLength(I1);
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = const index_t b1_gemm1n =
arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2); arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n)) if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0))
{ {
return false; return false;
} }
...@@ -1214,6 +1280,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1214,6 +1280,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_strides,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_lengths,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_strides,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
...@@ -1253,6 +1323,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1253,6 +1323,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths, lse_gs_ms_lengths,
bgrad_gs_ns_ks_lengths,
bgrad_gs_ns_ks_strides,
b1grad_gs_gemm1ns_gemm1ks_lengths,
b1grad_gs_gemm1ns_gemm1ks_strides,
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides, acc0_bias_gs_ms_ns_strides,
acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths acc1_bias_gs_ms_gemm1ns_lengths, // acc1_bias_gs_ms_os_lengths
...@@ -1296,6 +1370,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1296,6 +1370,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector<index_t>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides const std::vector<index_t>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_lengths,
const std::vector<index_t>& bgrad_gs_ns_ks_strides,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_lengths,
const std::vector<index_t>& b1grad_gs_gemm1ns_gemm1ks_strides,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides, const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides,
const std::vector<ck::index_t>& const std::vector<ck::index_t>&
...@@ -1336,6 +1414,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1336,6 +1414,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
lse_gs_ms_lengths, lse_gs_ms_lengths,
bgrad_gs_ns_ks_lengths,
bgrad_gs_ns_ks_strides,
b1grad_gs_gemm1ns_gemm1ks_lengths,
b1grad_gs_gemm1ns_gemm1ks_strides,
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides, acc0_bias_gs_ms_ns_strides,
acc1_bias_gs_ms_gemm1ns_lengths, acc1_bias_gs_ms_gemm1ns_lengths,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment