Commit 10836d41 authored by danyao12

G1/G2 -> G1Q/G1KV

parent 9574b34d
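The rename tracks grouped-query attention semantics: G1Q is the number of query heads (h_q) and G1KV the number of key/value heads (h_kv), with G1Q assumed to be an integer multiple of G1KV. A minimal sketch of the head mapping the hunks below rely on (illustrative only, not part of the commit):

// Hypothetical helper: which KV head a given query head reads from.
// Assumes G1Q % G1KV == 0, as the examples below require.
inline int kv_head_for_query_head(int g1q, int G1Q, int G1KV)
{
    const int h_ratio = G1Q / G1KV; // query heads per KV head
    return g1q / h_ratio;           // e.g. G1Q = 8, G1KV = 2: heads 0-3 -> 0, heads 4-7 -> 1
}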
......@@ -269,15 +269,15 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 512;
ck::index_t N = 512;
ck::index_t K = DIM;
ck::index_t O = DIM;
ck::index_t G0 = 4;
ck::index_t G1 = 6; // h_q
ck::index_t G2 = 6; // h_kv
// y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
// y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
ck::index_t M = 512;
ck::index_t N = 512;
ck::index_t K = DIM;
ck::index_t O = DIM;
ck::index_t G0 = 4;
ck::index_t G1Q = 6; // h_q
ck::index_t G1KV = 6; // h_kv
bool input_permute = false;
bool output_permute = false;
......@@ -302,13 +302,13 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
G2 = std::stoi(argv[10]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1Q = std::stoi(argv[9]);
G1KV = std::stoi(argv[10]);
p_drop = std::stof(argv[11]);
......@@ -320,7 +320,7 @@ int run(int argc, char* argv[])
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
printf("arg4 to 10: M, N, K, O, G0, G1Q, G1KV\n");
printf("arg11: p_drop\n");
printf("arg12 to 13: input / output permute\n");
exit(0);
......@@ -339,8 +339,8 @@ int run(int argc, char* argv[])
std::cout << "K: " << K << std::endl;
std::cout << "O: " << O << std::endl;
std::cout << "G0: " << G0 << std::endl;
std::cout << "G1: " << G1 << std::endl;
std::cout << "G2: " << G2 << std::endl;
std::cout << "G1Q: " << G1Q << std::endl;
std::cout << "G1KV: " << G1KV << std::endl;
std::cout << "alpha: " << alpha << std::endl;
std::cout << "input_permute: " << input_permute << std::endl;
std::cout << "output_permute: " << output_permute << std::endl;
......@@ -348,57 +348,57 @@ int run(int argc, char* argv[])
std::cout << "seed: " << seed << std::endl;
std::cout << "offset: " << offset << std::endl;
const ck::index_t BatchCount = G0 * G1;
const ck::index_t BatchCount = G0 * G1Q;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // Q layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // Q layout [G0, G1Q, M, K]
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
: std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // K layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // K layout [G0, G1KV, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
: std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // V layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // V layout [G0, G1KV, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // Y layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // Y layout [G0, G1Q, M, O]
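// Aside (illustrative, not part of this commit): the 4-D stride vectors above
// linearize an index tuple into a flat offset; input_permute only changes
// which dimension is contiguous in memory. A minimal sketch for Q:
inline ck::index_t linear_offset(const std::vector<ck::index_t>& strides,
                                 ck::index_t g0,
                                 ck::index_t g1q,
                                 ck::index_t m,
                                 ck::index_t k)
{
    return g0 * strides[0] + g1q * strides[1] + m * strides[2] + k * strides[3];
}
// With input_permute == true the Q strides are {M*G1Q*K, K, G1Q*K, 1}, i.e.
// memory order [G0, M, G1Q, K]; with false they are {G1Q*M*K, M*K, K, 1},
// i.e. memory order [G0, G1Q, M, K].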
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1Q, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
? std::vector<ck::index_t>{N * G1Q * K, K, G1Q * K, 1} // KGrad layout [G0, N, G1Q, K]
: std::vector<ck::index_t>{G1Q * N * K, N * K, K, 1}; // KGrad layout [G0, G1Q, N, K]
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1Q, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
? std::vector<ck::index_t>{N * G1Q * O, O, 1, G1Q * O} // VGrad layout [G0, N, G1Q, O]
: std::vector<ck::index_t>{G1Q * N * O, N * O, 1, O}; // VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / (exp(S0) + exp(S1) + ...)
//    = exp(Si) / exp(log(exp(S0) + exp(S1) + ...))
//    = exp(Si - log(exp(S0) + exp(S1) + ...))
//               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//               LSE
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
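// Aside (illustrative, not part of this commit): a numerically stable host
// softmax that produces the same LSE stat; a sketch, not the kernel's code.
// Needs <algorithm>, <cmath>, <vector>.
inline float softmax_row_with_lse(std::vector<float>& row)
{
    const float mx = *std::max_element(row.begin(), row.end());
    float sum      = 0.f;
    for(float s : row)
        sum += std::exp(s - mx);
    const float lse = mx + std::log(sum); // log-sum-exp, shifted by the row max for stability
    for(float& s : row)
        s = std::exp(s - lse); // Pi = exp(Si - LSE)
    return lse;
}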
Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
......@@ -451,14 +451,14 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1q, m, o]
// dO dot O = [0; 1; 2; ...]
break;
case 6:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1q, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
......@@ -471,7 +471,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1q, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
......@@ -493,20 +494,20 @@ int run(int argc, char* argv[])
Tensor<LSEDataType> lse_g_m({BatchCount, M});
q_gs_ms_ks.ForEach(
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx); });
k_g_n_k.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / (G1 / G2);
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
self(idx) = k_gs_ns_ks(g0, g1kv, idx[1], idx[2]);
});
v_g_n_o.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / (G1 / G2);
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = v_gs_os_ns(g0, g2, idx[2], idx[1]);
self(idx) = v_gs_os_ns(g0, g1kv, idx[2], idx[1]);
});
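// Aside (illustrative): g1kv = g1q / (G1Q / G1KV) above broadcasts each KV
// head across G1Q / G1KV consecutive query heads. E.g. with G1Q = 6 and
// G1KV = 2, query heads {0,1,2} read KV head 0 and {3,4,5} read KV head 1;
// with G1Q == G1KV (the defaults here) the mapping is the identity.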
// qkv gradients use the same descriptors as qkv
......@@ -651,7 +652,7 @@ int run(int argc, char* argv[])
// copy z matrix data from device
z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
z_gs_ms_ns.ForEach(
[&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
[&](auto& self, auto idx) { z_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx); });
// std::cout << "z_g_m_n ref:\n" << z_g_m_n;
bool pass = true;
......@@ -671,10 +672,10 @@ int run(int argc, char* argv[])
p_dropout_in_uint8_t,
rp_dropout);
y_gs_ms_os.ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
self(idx) = y_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]);
});
lse_gs_ms.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
[&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1Q + idx[1], idx[2]); });
y_device_buf.ToDevice(y_gs_ms_os.mData.data());
lse_device_buf.ToDevice(lse_gs_ms.mData.data());
......@@ -692,7 +693,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> ygrad_dot_y_g_m({BatchCount, M});
ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
ygrad_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
#if PRINT_HOST
......@@ -811,26 +812,26 @@ int run(int argc, char* argv[])
// permute
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
});
kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
});
vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
});
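// Aside (illustrative): g = g0 * G1Q + g1q is the inverse of the
// g0 = g / G1Q, g1q = g % G1Q decomposition used when flattening the inputs,
// so the host reference slices line up with the device result batch by batch.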
......@@ -270,15 +270,15 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 512;
ck::index_t N = 512;
ck::index_t K = DIM;
ck::index_t O = DIM;
ck::index_t G0 = 4;
ck::index_t G1 = 6; // h_q
ck::index_t G2 = 6; // h_kv
// y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
// y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
ck::index_t M = 512;
ck::index_t N = 512;
ck::index_t K = DIM;
ck::index_t O = DIM;
ck::index_t G0 = 4;
ck::index_t G1Q = 6; // h_q
ck::index_t G1KV = 6; // h_kv
bool input_permute = false;
bool output_permute = false;
......@@ -303,13 +303,13 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
G2 = std::stoi(argv[10]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1Q = std::stoi(argv[9]);
G1KV = std::stoi(argv[10]);
p_drop = std::stof(argv[11]);
......@@ -321,7 +321,7 @@ int run(int argc, char* argv[])
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
printf("arg4 to 10: M, N, K, O, G0, G1Q, G1KV\n");
printf("arg11: p_drop\n");
printf("arg12 to 13: input / output permute\n");
exit(0);
......@@ -340,8 +340,8 @@ int run(int argc, char* argv[])
std::cout << "K: " << K << std::endl;
std::cout << "O: " << O << std::endl;
std::cout << "G0: " << G0 << std::endl;
std::cout << "G1: " << G1 << std::endl;
std::cout << "G2: " << G2 << std::endl;
std::cout << "G1Q: " << G1Q << std::endl;
std::cout << "G1KV: " << G1KV << std::endl;
std::cout << "alpha: " << alpha << std::endl;
std::cout << "input_permute: " << input_permute << std::endl;
std::cout << "output_permute: " << output_permute << std::endl;
......@@ -349,57 +349,57 @@ int run(int argc, char* argv[])
std::cout << "seed: " << seed << std::endl;
std::cout << "offset: " << offset << std::endl;
const ck::index_t BatchCount = G0 * G1;
const ck::index_t BatchCount = G0 * G1Q;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // Q layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // Q layout [G0, G1Q, M, K]
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
: std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // K layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // K layout [G0, G1KV, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
: std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // V layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // V layout [G0, G1KV, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // Y layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // Y layout [G0, G1Q, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1Q, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
? std::vector<ck::index_t>{N * G1Q * K, K, G1Q * K, 1} // KGrad layout [G0, N, G1Q, K]
: std::vector<ck::index_t>{G1Q * N * K, N * K, K, 1}; // KGrad layout [G0, G1Q, N, K]
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1Q, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
? std::vector<ck::index_t>{N * G1Q * O, O, 1, G1Q * O} // VGrad layout [G0, N, G1Q, O]
: std::vector<ck::index_t>{G1Q * N * O, N * O, 1, O}; // VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / (exp(S0) + exp(S1) + ...)
//    = exp(Si) / exp(log(exp(S0) + exp(S1) + ...))
//    = exp(Si - log(exp(S0) + exp(S1) + ...))
//               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//               LSE
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
......@@ -454,14 +454,14 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1q, m, o]
// dO dot O = [0; 1; 2; ...]
break;
case 6:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1q, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
......@@ -474,7 +474,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1q, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
......@@ -496,20 +497,20 @@ int run(int argc, char* argv[])
Tensor<LSEDataType> lse_g_m({BatchCount, M});
q_gs_ms_ks.ForEach(
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
[&](auto& self, auto idx) { q_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx); });
k_g_n_k.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / (G1 / G2);
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
self(idx) = k_gs_ns_ks(g0, g1kv, idx[1], idx[2]);
});
v_g_n_o.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / (G1 / G2);
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = v_gs_os_ns(g0, g2, idx[2], idx[1]);
self(idx) = v_gs_os_ns(g0, g1kv, idx[2], idx[1]);
});
// qkv gradients use the same descriptors as qkv
......@@ -657,7 +658,7 @@ int run(int argc, char* argv[])
// copy z matrix data from device
z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
z_gs_ms_ns.ForEach(
[&](auto& self, auto idx) { z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); });
[&](auto& self, auto idx) { z_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx); });
// std::cout << "z_g_m_n ref:\n" << z_g_m_n;
bool pass = true;
......@@ -677,10 +678,10 @@ int run(int argc, char* argv[])
p_dropout_in_uint8_t,
rp_dropout);
y_gs_ms_os.ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
self(idx) = y_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]);
});
lse_gs_ms.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
[&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1Q + idx[1], idx[2]); });
y_device_buf.ToDevice(y_gs_ms_os.mData.data());
lse_device_buf.ToDevice(lse_gs_ms.mData.data());
......@@ -698,7 +699,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> ygrad_dot_y_g_m({BatchCount, M});
ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
ygrad_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
#if PRINT_HOST
......@@ -817,26 +818,26 @@ int run(int argc, char* argv[])
// permute
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
});
kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
});
vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
});
......@@ -299,15 +299,15 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t N = 500; // 512
ck::index_t M = 500; // 512
ck::index_t K = DIM;
ck::index_t O = DIM;
ck::index_t G0 = 4;
ck::index_t G1 = 6; // h_q
ck::index_t G2 = 6; // h_kv
// y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
// y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
ck::index_t N = 500; // 512
ck::index_t M = 500; // 512
ck::index_t K = DIM;
ck::index_t O = DIM;
ck::index_t G0 = 4;
ck::index_t G1Q = 6; // h_q
ck::index_t G1KV = 6; // h_kv
bool input_permute = false;
bool output_permute = false;
......@@ -332,13 +332,13 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
G2 = std::stoi(argv[10]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1Q = std::stoi(argv[9]);
G1KV = std::stoi(argv[10]);
p_drop = std::stof(argv[11]);
......@@ -350,7 +350,7 @@ int run(int argc, char* argv[])
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
printf("arg4 to 10: M, N, K, O, G0, G1Q, G1KV\n");
printf("arg11: p_drop\n");
printf("arg12 to 13: input / output permute\n");
exit(0);
......@@ -369,8 +369,8 @@ int run(int argc, char* argv[])
std::cout << "K: " << K << std::endl;
std::cout << "O: " << O << std::endl;
std::cout << "G0: " << G0 << std::endl;
std::cout << "G1: " << G1 << std::endl;
std::cout << "G2: " << G2 << std::endl;
std::cout << "G1Q: " << G1Q << std::endl;
std::cout << "G1KV: " << G1KV << std::endl;
std::cout << "alpha: " << alpha << std::endl;
std::cout << "input_permute: " << input_permute << std::endl;
std::cout << "output_permute: " << output_permute << std::endl;
......@@ -378,57 +378,57 @@ int run(int argc, char* argv[])
std::cout << "seed: " << seed << std::endl;
std::cout << "offset: " << offset << std::endl;
const ck::index_t BatchCount = G0 * G1;
const ck::index_t BatchCount = G0 * G1Q;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // Q layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // Q layout [G0, G1Q, M, K]
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
: std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // K layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // K layout [G0, G1KV, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
: std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // V layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // V layout [G0, G1KV, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // Y layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // Y layout [G0, G1Q, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1Q, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
? std::vector<ck::index_t>{N * G1Q * K, K, G1Q * K, 1} // KGrad layout [G0, N, G1Q, K]
: std::vector<ck::index_t>{G1Q * N * K, N * K, K, 1}; // KGrad layout [G0, G1Q, N, K]
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1Q, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
? std::vector<ck::index_t>{N * G1Q * O, O, 1, G1Q * O} // VGrad layout [G0, N, G1Q, O]
: std::vector<ck::index_t>{G1Q * N * O, N * O, 1, O}; // VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward pass
// Pi = exp(Si) / (exp(S0) + exp(S1) + ...)
//    = exp(Si) / exp(log(exp(S0) + exp(S1) + ...))
//    = exp(Si - log(exp(S0) + exp(S1) + ...))
//               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//               LSE
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
......@@ -484,14 +484,14 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1q, m, o]
// dO dot O = [0; 1; 2; ...]
break;
case 6:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1q, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
......@@ -504,7 +504,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1q, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
......@@ -820,24 +821,24 @@ int run(int argc, char* argv[])
}
q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
q_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
k_g_n_k.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / (G1 / G2);
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
self(idx) = k_gs_ns_ks(g0, g1kv, idx[1], idx[2]);
});
v_g_n_o.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / (G1 / G2);
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = v_gs_os_ns(g0, g2, idx[2], idx[1]);
self(idx) = v_gs_os_ns(g0, g1kv, idx[2], idx[1]);
});
z_fwd_gs_ms_ns.ForEach([&](auto& self, auto idx) {
z_fwd_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
z_fwd_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
run_attention_fwd_host(q_g_m_k,
......@@ -854,10 +855,10 @@ int run(int argc, char* argv[])
rp_dropout);
ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
ygrad_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
z_bwd_gs_ms_ns.ForEach([&](auto& self, auto idx) {
z_bwd_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
z_bwd_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
#if PRINT_HOST
......@@ -960,42 +961,42 @@ int run(int argc, char* argv[])
// permute
y_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = y_g_m_o(g, idx[2], idx[3]);
});
lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = lse_g_m(g, idx[2]);
});
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
});
kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
});
vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
});
......@@ -268,11 +268,11 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
// y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
// y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
float alpha = 1.f / std::sqrt(DIM);
float p_drop = 0.0;
int h_ratio = 1; // G1 / G2
int h_ratio = 1; // G1Q / G1KV
bool input_permute = true;
bool output_permute = true;
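// Aside (illustrative): h_ratio is assumed to equal G1Q / G1KV and must
// divide G1Q exactly; the loop below draws G1KV at random and sets
// G1Q = G1KV * h_ratio, so the constraint holds by construction.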
......@@ -369,61 +369,65 @@ int run(int argc, char* argv[])
std::size_t flop = 0, num_byte = 0;
for(std::size_t i = 0; i < group_count; i++)
{
int M = 128 * (rand() % 8) + (rand() % 128);
int N = 128 * (rand() % 8) + (rand() % 128);
int K = DIM;
int O = DIM;
int G0 = rand() % 4 + 1;
int G2 = rand() % 4 + 1;
int G1 = G2 * h_ratio;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
int M = 128 * (rand() % 8) + (rand() % 128);
int N = 128 * (rand() % 8) + (rand() % 128);
int K = DIM;
int O = DIM;
int G0 = rand() % 4 + 1;
int G1KV = rand() % 4 + 1;
int G1Q = G1KV * h_ratio;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // Q layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // Q layout [G0, G1Q, M, K]
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
: std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1}
// K layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // K layout [G0, G1KV, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
: std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O}
// V layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // V layout [G0, G1KV, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // Y layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // Y layout [G0, G1Q, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1Q, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
input_permute ? std::vector<ck::index_t>{N * G1Q * K, K, G1Q * K, 1}
// KGrad layout [G0, N, G1Q, K]
: std::vector<ck::index_t>{
G1Q * N * K, N * K, K, 1}; // KGrad layout [G0, G1Q, N, K]
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1Q, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
input_permute ? std::vector<ck::index_t>{N * G1Q * O, O, 1, G1Q * O}
// VGrad layout [G0, N, G1Q, O]
: std::vector<ck::index_t>{
G1Q * N * O, N * O, 1, O}; // VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
// pass:
// Pi = exp(Si) / (exp(S0) + exp(S1) + ...)
//    = exp(Si) / exp(log(exp(S0) + exp(S1) + ...))
//    = exp(Si - log(exp(S0) + exp(S1) + ...))
//               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//               LSE
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
problem_descs.push_back({
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
......@@ -447,7 +451,7 @@ int run(int argc, char* argv[])
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
});
int BatchCount = G0 * G1;
int BatchCount = G0 * G1Q;
flop += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
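// Aside (one plausible accounting, not stated in the commit): the three
// M*N*K GEMMs would be recomputing S = Q*K^T plus dQ = dS*K and dK = dS^T*Q,
// and the two M*N*O GEMMs dV = P^T*dY and dP = dY*V^T, each costed at
// 2 FLOPs per multiply-accumulate.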
// Q/K/V/Y, dQ/dK/dV/dY, LSE
num_byte += (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
......@@ -509,14 +513,16 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_Sequential<2>{}); // dy[g0, g1q, m, o]
// dO dot O = [0; 1; 2; ...]
break;
case 6:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_Sequential<3>{}); // dy[g0, g1q, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
......@@ -530,7 +536,7 @@ int run(int argc, char* argv[])
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1q, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
......@@ -551,21 +557,21 @@ int run(int argc, char* argv[])
Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
q_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
k_g_n_k.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / h_ratio;
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / h_ratio;
self(idx) = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
self(idx) = k_gs_ns_ks(g0, g1kv, idx[1], idx[2]);
});
v_g_n_o.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / h_ratio;
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / h_ratio;
self(idx) = v_gs_os_ns(g0, g2, idx[2], idx[1]);
self(idx) = v_gs_os_ns(g0, g1kv, idx[2], idx[1]);
});
q_g_m_ks.push_back(q_g_m_k);
......@@ -706,11 +712,11 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++)
{
int G1 = q_tensors[i].GetLengths()[1];
int G1Q = q_tensors[i].GetLengths()[1];
// copy z matrix data from device
z_tensors_device[i]->FromDevice(z_tensors[i].mData.data());
z_tensors[i].ForEach([&](auto& self, auto idx) {
z_g_m_ns[i](idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
z_g_m_ns[i](idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
run_attention_fwd_host(q_g_m_ks[i],
k_g_n_ks[i],
......@@ -726,11 +732,11 @@ int run(int argc, char* argv[])
rp_dropout);
y_tensors[i].ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_os[i](idx[0] * G1 + idx[1], idx[2], idx[3]);
self(idx) = y_g_m_os[i](idx[0] * G1Q + idx[1], idx[2], idx[3]);
});
y_tensors_device[i]->ToDevice(y_tensors[i].data());
lse_tensors[i].ForEach([&](auto& self, auto idx) {
self(idx) = lse_g_ms[i](idx[0] * G1 + idx[1], idx[2]);
self(idx) = lse_g_ms[i](idx[0] * G1Q + idx[1], idx[2]);
});
lse_tensors_device[i]->ToDevice(lse_tensors[i].data());
qgrad_tensors_device[i]->SetZero();
......@@ -744,12 +750,12 @@ int run(int argc, char* argv[])
{
int G0 = q_tensors[i].GetLengths()[0];
int G1 = q_tensors[i].GetLengths()[1];
int G1Q = q_tensors[i].GetLengths()[1];
int O = v_tensors[i].GetLengths()[2];
int N = v_tensors[i].GetLengths()[3];
int M = q_tensors[i].GetLengths()[2];
int K = q_tensors[i].GetLengths()[3];
int BatchCount = G0 * G1;
int BatchCount = G0 * G1Q;
Tensor<OutputDataType> qgrad_g_m_k({BatchCount, M, K});
Tensor<OutputDataType> kgrad_g_n_k({BatchCount, N, K});
Tensor<OutputDataType> vgrad_g_n_o({BatchCount, N, O});
......@@ -759,7 +765,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> ygrad_g_m_o({BatchCount, M, O});
ygrad_tensors[i].ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
ygrad_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
auto ref_gemm0_grad = ReferenceGemm0GradInstance{};
auto ref_gemm0_grad_invoker = ref_gemm0_grad.MakeInvoker();
......@@ -819,26 +825,26 @@ int run(int argc, char* argv[])
vgrad_tensors_device[i]->FromDevice(vgrad_gs_os_ns_device_result.data());
// permute
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
});
kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
});
vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
});
......@@ -269,11 +269,11 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
// y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
// y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
float alpha = 1.f / std::sqrt(DIM);
float p_drop = 0.0;
int h_ratio = 1; // G1 / G2
int h_ratio = 1; // G1Q / G1KV
bool input_permute = true;
bool output_permute = true;
......@@ -373,61 +373,65 @@ int run(int argc, char* argv[])
std::size_t flop = 0, num_byte = 0;
for(std::size_t i = 0; i < group_count; i++)
{
int M = 128 * (rand() % 8) + (rand() % 128);
int N = 128 * (rand() % 8) + (rand() % 128);
int K = DIM;
int O = DIM;
int G0 = rand() % 4 + 1;
int G2 = rand() % 4 + 1;
int G1 = G2 * h_ratio;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
int M = 128 * (rand() % 8) + (rand() % 128);
int N = 128 * (rand() % 8) + (rand() % 128);
int K = DIM;
int O = DIM;
int G0 = rand() % 4 + 1;
int G1KV = rand() % 4 + 1;
int G1Q = G1KV * h_ratio;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // Q layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // Q layout [G0, G1Q, M, K]
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
: std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1}
// K layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // K layout [G0, G1KV, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
: std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O}
// V layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // V layout [G0, G1KV, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // Y layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // Y layout [G0, G1Q, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1Q, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
input_permute ? std::vector<ck::index_t>{N * G1Q * K, K, G1Q * K, 1}
// KGrad layout [G0, N, G1Q, K]
: std::vector<ck::index_t>{
G1Q * N * K, N * K, K, 1}; // KGrad layout [G0, G1Q, N, K]
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1Q, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
input_permute ? std::vector<ck::index_t>{N * G1Q * O, O, 1, G1Q * O}
// VGrad layout [G0, N, G1Q, O]
: std::vector<ck::index_t>{
G1Q * N * O, N * O, 1, O}; // VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up softmax calculation in backward
// pass:
// Pi = exp(Si) / (exp(S0) + exp(S1) + ...)
//    = exp(Si) / exp(log(exp(S0) + exp(S1) + ...))
//    = exp(Si - log(exp(S0) + exp(S1) + ...))
//               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//               LSE
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
problem_descs.push_back({
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
......@@ -451,7 +455,7 @@ int run(int argc, char* argv[])
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
});
int BatchCount = G0 * G1;
int BatchCount = G0 * G1Q;
flop += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
// Q/K/V/Y, dQ/dK/dV/dY, LSE
num_byte += (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
......@@ -515,14 +519,16 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_Sequential<2>{}); // dy[g0, g1q, m, o]
// dO dot O = [0; 1; 2; ...]
break;
case 6:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_Sequential<3>{}); // dy[g0, g1q, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
......@@ -536,7 +542,7 @@ int run(int argc, char* argv[])
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1q, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
......@@ -558,21 +564,21 @@ int run(int argc, char* argv[])
Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
q_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
k_g_n_k.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / h_ratio;
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / h_ratio;
self(idx) = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
self(idx) = k_gs_ns_ks(g0, g1kv, idx[1], idx[2]);
});
v_g_n_o.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / h_ratio;
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / h_ratio;
self(idx) = v_gs_os_ns(g0, g2, idx[2], idx[1]);
self(idx) = v_gs_os_ns(g0, g1kv, idx[2], idx[1]);
});
q_g_m_ks.push_back(q_g_m_k);
......@@ -719,11 +725,11 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++)
{
int G1 = q_tensors[i].GetLengths()[1];
int G1Q = q_tensors[i].GetLengths()[1];
// copy z matrix data from device
z_tensors_device[i]->FromDevice(z_tensors[i].mData.data());
z_tensors[i].ForEach([&](auto& self, auto idx) {
z_g_m_ns[i](idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
z_g_m_ns[i](idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
run_attention_fwd_host(q_g_m_ks[i],
k_g_n_ks[i],
......@@ -739,11 +745,11 @@ int run(int argc, char* argv[])
rp_dropout);
y_tensors[i].ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_os[i](idx[0] * G1 + idx[1], idx[2], idx[3]);
self(idx) = y_g_m_os[i](idx[0] * G1Q + idx[1], idx[2], idx[3]);
});
y_tensors_device[i]->ToDevice(y_tensors[i].data());
lse_tensors[i].ForEach([&](auto& self, auto idx) {
self(idx) = lse_g_ms[i](idx[0] * G1 + idx[1], idx[2]);
self(idx) = lse_g_ms[i](idx[0] * G1Q + idx[1], idx[2]);
});
lse_tensors_device[i]->ToDevice(lse_tensors[i].data());
qgrad_tensors_device[i]->SetZero();
......@@ -757,12 +763,12 @@ int run(int argc, char* argv[])
{
int G0 = q_tensors[i].GetLengths()[0];
int G1 = q_tensors[i].GetLengths()[1];
int G1Q = q_tensors[i].GetLengths()[1];
int O = v_tensors[i].GetLengths()[2];
int N = v_tensors[i].GetLengths()[3];
int M = q_tensors[i].GetLengths()[2];
int K = q_tensors[i].GetLengths()[3];
int BatchCount = G0 * G1;
int BatchCount = G0 * G1Q;
Tensor<OutputDataType> qgrad_g_m_k({BatchCount, M, K});
Tensor<OutputDataType> kgrad_g_n_k({BatchCount, N, K});
Tensor<OutputDataType> vgrad_g_n_o({BatchCount, N, O});
......@@ -772,7 +778,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> ygrad_g_m_o({BatchCount, M, O});
ygrad_tensors[i].ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
ygrad_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
auto ref_gemm0_grad = ReferenceGemm0GradInstance{};
auto ref_gemm0_grad_invoker = ref_gemm0_grad.MakeInvoker();
......@@ -832,26 +838,26 @@ int run(int argc, char* argv[])
vgrad_tensors_device[i]->FromDevice(vgrad_gs_os_ns_device_result.data());
// permute
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
});
kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
});
vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
});
......@@ -298,11 +298,11 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
// y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
// y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
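// Hedged shape note (not part of this commit): with hypothetical G0 = 2, G1Q = 3,
// M = 4, O = 5, y_g_m_o is [6, 4, 5]; the reshape gives [2, 3, 4, 5] and the
// [0, 2, 1, 3] permute yields [2, 4, 3, 5], i.e. the [G0, M, G1Q, O] output
// layout with heads interleaved per row rather than stored as contiguous blocks.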
float alpha = 1.f / std::sqrt(DIM);
float p_drop = 0.2;
int h_ratio = 1; // G1 / G2
int h_ratio = 1; // G1Q / G1KV
bool input_permute = true;
bool output_permute = true;
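// --- Illustrative sketch (not part of this commit): the G1Q/G1KV head mapping ---
// Hypothetical sizes; shows how groups of h_ratio query heads share one KV head
// (grouped-query attention), matching the g1kv = g1q / h_ratio indexing below.
#include <cstdio>
inline void head_mapping_demo()
{
    const int G1Q = 8, G1KV = 2;    // assumed values for illustration only
    const int h_ratio = G1Q / G1KV; // 4 query heads per KV head
    for(int g1q = 0; g1q < G1Q; ++g1q)
        std::printf("query head %d reads kv head %d\n", g1q, g1q / h_ratio);
}
// --- end sketch ---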
......@@ -409,61 +409,65 @@ int run(int argc, char* argv[])
std::size_t flop_bwd = 0, num_byte_bwd = 0;
for(std::size_t i = 0; i < group_count; i++)
{
int M = 128 * (rand() % 8) + (rand() % 128);
int N = 128 * (rand() % 8) + (rand() % 128);
int K = DIM;
int O = DIM;
int G0 = rand() % 4 + 1;
int G2 = rand() % 4 + 1;
int G1 = G2 * h_ratio;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
int M = 128 * (rand() % 8) + (rand() % 128);
int N = 128 * (rand() % 8) + (rand() % 128);
int K = DIM;
int O = DIM;
int G0 = rand() % 4 + 1;
int G1KV = rand() % 4 + 1;
int G1Q = G1KV * h_ratio;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // Q layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // Q layout [G0, G1Q, M, K]
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
: std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1}
// K layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // K layout [G0, G1KV, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
: std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O}
// V layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // V layout [G0, G1KV, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // Y layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // Y layout [G0, G1Q, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1Q, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
input_permute ? std::vector<ck::index_t>{N * G1Q * K, K, G1Q * K, 1}
// KGrad layout [G0, N, G1Q, K]
: std::vector<ck::index_t>{
G1Q * N * K, N * K, K, 1}; // KGrad layout [G0, G1Q, N, K]
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1Q, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
input_permute ? std::vector<ck::index_t>{N * G1Q * O, O, 1, G1Q * O}
// VGrad layout [G0, N, G1Q, O]
: std::vector<ck::index_t>{
G1Q * N * O, N * O, 1, O}; // VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up the softmax calculation in the
// backward pass: Pi = exp(Si) / (exp(S0) + exp(S1) + ...)
//                   = exp(Si) / exp(log(exp(S0) + exp(S1) + ...))
//                   = exp(Si - log(exp(S0) + exp(S1) + ...))
//                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//                              LSE
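// --- Illustrative sketch (not part of this commit): numerically stable LSE ---
// All names here are hypothetical; this just shows why storing LSE lets the
// backward pass rebuild Pi = exp(Si - LSE) without a separate normalization pass.
#include <algorithm>
#include <cmath>
#include <vector>
static float log_sum_exp(const std::vector<float>& s)
{
    const float m = *std::max_element(s.begin(), s.end()); // shift by the max for stability
    float sum = 0.f;
    for(const float si : s)
        sum += std::exp(si - m);
    return m + std::log(sum); // LSE = log(exp(s0) + exp(s1) + ...)
}
// Usage: Pi = std::exp(s[i] - log_sum_exp(s)) equals exp(s[i]) / sum_j exp(s[j]).
// --- end sketch ---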
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
problem_descs_fwd.push_back({
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
......@@ -505,7 +509,7 @@ int run(int argc, char* argv[])
{}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
});
int BatchCount = G0 * G1;
int BatchCount = G0 * G1Q;
flop_fwd += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
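// Hedged accounting note: 2 * M * N * K flops for S = Q * K^T plus 2 * M * N * O
// for Y = P * V, per (g0, g1q) batch; e.g. with M = N = 512, K = O = 64
// (assuming DIM = 64) and BatchCount = 24 this is ~1.61 GFLOP.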
num_byte_fwd += (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O) *
......@@ -574,14 +578,16 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_Sequential<2>{}); // dy[g0, g1q, m, o]
// dO dot O = [0; 1; 2; ...]
break;
case 6:
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_Sequential<3>{}); // dy[g0, g1q, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
......@@ -595,7 +601,7 @@ int run(int argc, char* argv[])
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1q, m, o]
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
// O = P V = 0.0039 * ones
......@@ -618,21 +624,21 @@ int run(int argc, char* argv[])
Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
q_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
k_g_n_k.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / h_ratio;
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / h_ratio;
self(idx) = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
self(idx) = k_gs_ns_ks(g0, g1kv, idx[1], idx[2]);
});
v_g_n_o.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / h_ratio;
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / h_ratio;
self(idx) = v_gs_os_ns(g0, g2, idx[2], idx[1]);
self(idx) = v_gs_os_ns(g0, g1kv, idx[2], idx[1]);
});
q_g_m_ks.push_back(q_g_m_k);
......@@ -872,15 +878,15 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++)
{
int G1 = q_tensors[i].GetLengths()[1];
int G1Q = q_tensors[i].GetLengths()[1];
// copy z matrix data from device
z_fwd_tensors_device[i]->FromDevice(z_fwd_tensors[i].mData.data());
z_fwd_tensors[i].ForEach([&](auto& self, auto idx) {
z_fwd_g_m_ns[i](idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
z_fwd_g_m_ns[i](idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
z_bwd_tensors_device[i]->FromDevice(z_bwd_tensors[i].mData.data());
z_bwd_tensors[i].ForEach([&](auto& self, auto idx) {
z_bwd_g_m_ns[i](idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
z_bwd_g_m_ns[i](idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
run_attention_fwd_host(q_g_m_ks[i],
k_g_n_ks[i],
......@@ -900,7 +906,7 @@ int run(int argc, char* argv[])
int N = v_tensors[i].GetLengths()[3];
int M = q_tensors[i].GetLengths()[2];
int K = q_tensors[i].GetLengths()[3];
int BatchCount = G0 * G1;
int BatchCount = G0 * G1Q;
Tensor<OutputDataType> qgrad_g_m_k({BatchCount, M, K});
Tensor<OutputDataType> kgrad_g_n_k({BatchCount, N, K});
Tensor<OutputDataType> vgrad_g_n_o({BatchCount, N, O});
......@@ -910,7 +916,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> ygrad_g_m_o({BatchCount, M, O});
ygrad_tensors[i].ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
ygrad_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
auto ref_gemm0_grad = ReferenceGemm0GradInstance{};
auto ref_gemm0_grad_invoker = ref_gemm0_grad.MakeInvoker();
......@@ -981,42 +987,42 @@ int run(int argc, char* argv[])
// permute
y_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = y_g_m_os[i](g, idx[2], idx[3]);
});
lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = lse_g_ms[i](g, idx[2]);
});
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
});
kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
});
vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
});
......
......@@ -14,12 +14,12 @@ int run(int argc, char* argv[])
ck::index_t K = DIM;
ck::index_t O = DIM;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t G0 = 7;
ck::index_t G1 = 12; // h_q
ck::index_t G2 = 12; // h_kv
// Output shape C[G0, M, G1Q, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1q_m_o = reshape(C_g_m_o, [g0, g1q, m, o])
// C_g0_m_g1q_o = permute(C_g0_g1q_m_o, [0, 2, 1, 3])
ck::index_t G0 = 7;
ck::index_t G1Q = 12; // h_q
ck::index_t G1KV = 12; // h_kv
bool input_permute = false;
bool output_permute = true;
......@@ -44,13 +44,13 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
G2 = std::stoi(argv[10]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1Q = std::stoi(argv[9]);
G1KV = std::stoi(argv[10]);
p_drop = std::stof(argv[11]);
......@@ -62,7 +62,7 @@ int run(int argc, char* argv[])
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
printf("arg4 to 10: M, N, K, O, G0, G1Q, G1KV\n");
printf("arg11: p_drop\n");
printf("arg12 to 13: input / output permute\n");
exit(0);
......@@ -73,39 +73,39 @@ int run(int argc, char* argv[])
float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K);
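// Hedged note: alpha is the standard attention scale 1/sqrt(d_k); rp_dropout is
// the inverted-dropout rescale applied to surviving elements so the expected
// value is unchanged (assuming p_dropout is the keep probability 1 - p_drop).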
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // A layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // A layout [G0, G1Q, M, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G2, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // B0 layout [G0, N, G2, K]
: std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // B0 layout [G0, G2, N, K]
? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // B0 layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // B0 layout [G0, G1KV, N, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G2, O, N};
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // B1 layout [G0, N, G2, O]
: std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // B1 layout [G0, G2, N, O]
? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // B1 layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // B1 layout [G0, G1KV, N, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // C layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // C layout [G0, G1Q, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides =
std::vector<ck::index_t>{G1 * M, M, 1}; // LSE layout [G0, G1, M]
std::vector<ck::index_t>{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
......@@ -213,7 +213,7 @@ int run(int argc, char* argv[])
return 0;
}
ck::index_t BatchCount = G0 * G1;
ck::index_t BatchCount = G0 * G1Q;
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
......@@ -278,32 +278,32 @@ int run(int argc, char* argv[])
Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
Tensor<AccDataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax
Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N});
Tensor<ADataType> a1_g_m_n_drop({BatchCount, M, N});
Tensor<LSEDataType> lse_g_m_host_result(
{BatchCount, M}); // scratch object after max + ln(sum)
Tensor<ZDataType> z_g_m_n({G0 * G1, M, N});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
// permute
a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
a_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
b0_g_k_n.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / (G1 / G2);
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = b0_gs_ns_ks(g0, g2, idx[2], idx[1]);
self(idx) = b0_gs_ns_ks(g0, g1kv, idx[2], idx[1]);
});
b1_g_n_o.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / (G1 / G2);
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = b1_gs_os_ns(g0, g2, idx[2], idx[1]);
self(idx) = b1_gs_os_ns(g0, g1kv, idx[2], idx[1]);
});
z_gs_ms_ns.ForEach([&](auto& self, auto idx) {
z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
z_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
// gemm 0
......@@ -350,18 +350,18 @@ int run(int argc, char* argv[])
// permute
c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
});
lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = lse_g_m_host_result(g, idx[2]);
});
......
......@@ -11,7 +11,7 @@ int run(int argc, char* argv[])
bool output_permute = true;
float p_drop = 0.2;
int h_ratio = 1; // G1 / G2
int h_ratio = 1; // G1Q / G1KV
const unsigned long long seed = 1;
const unsigned long long offset = 0;
......@@ -64,7 +64,7 @@ int run(int argc, char* argv[])
std::vector<void*> p_z; // for result verification
std::vector<void*> p_z_nullptr; // for time test
std::vector<void*> p_lse;
std::vector<std::vector<int>> g0_g1_m_n_k_o;
std::vector<std::vector<int>> g0_g1q_m_n_k_o;
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<B0DataType>> b0_tensors;
......@@ -87,49 +87,51 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++)
{
int M = 128 * (rand() % 8) + (rand() % 128);
int N = 128 * (rand() % 8) + (rand() % 128);
int K = DIM;
int O = DIM;
int G0 = rand() % 3 + 1;
int G2 = rand() % 5 + 1;
int G1 = G2 * h_ratio;
int M = 128 * (rand() % 8) + (rand() % 128);
int N = 128 * (rand() % 8) + (rand() % 128);
int K = DIM;
int O = DIM;
int G0 = rand() % 3 + 1;
int G1KV = rand() % 5 + 1;
int G1Q = G1KV * h_ratio;
g0_g1_m_n_k_o.push_back({G0, G1, M, N, K, O});
g0_g1q_m_n_k_o.push_back({G0, G1Q, M, N, K, O});
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // A layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // A layout [G0, G1Q, M, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G2, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // B0 layout [G0, N, G2, K]
: std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // B0 layout [G0, G2, N, K]
? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1}
// B0 layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // B0 layout [G0, G1KV, N, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G2, O, N};
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // B1 layout [G0, N, G2, O]
: std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // B1 layout [G0, G2, N, O]
? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O}
// B1 layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // B1 layout [G0, G1KV, N, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // C layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // C layout [G0, G1Q, M, O]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides =
std::vector<ck::index_t>{G1 * M, M, 1}; // LSE layout [G0, G1, M]
std::vector<ck::index_t>{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
problem_descs.push_back({a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
......@@ -156,7 +158,7 @@ int run(int argc, char* argv[])
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
int Batch = G0 * G1;
int Batch = G0 * G1Q;
flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
......@@ -313,12 +315,12 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++)
{
const int& G0 = g0_g1_m_n_k_o[i][0];
const int& G1 = g0_g1_m_n_k_o[i][1];
const int& M = g0_g1_m_n_k_o[i][2];
const int& N = g0_g1_m_n_k_o[i][3];
const int& K = g0_g1_m_n_k_o[i][4];
const int& O = g0_g1_m_n_k_o[i][5];
const int& G0 = g0_g1q_m_n_k_o[i][0];
const int& G1Q = g0_g1q_m_n_k_o[i][1];
const int& M = g0_g1q_m_n_k_o[i][2];
const int& N = g0_g1q_m_n_k_o[i][3];
const int& K = g0_g1q_m_n_k_o[i][4];
const int& O = g0_g1q_m_n_k_o[i][5];
const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths;
const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides;
......@@ -339,39 +341,39 @@ int run(int argc, char* argv[])
z_gs_ms_ns_device_buf.FromDevice(z_gs_ms_ns_device_result.mData.data());
lse_gs_ms_device_buf.FromDevice(lse_gs_ms_device_result.mData.data());
Tensor<ADataType> a_g_m_k({G0 * G1, M, K});
Tensor<B0DataType> b0_g_k_n({G0 * G1, K, N});
Tensor<B1DataType> b1_g_n_o({G0 * G1, N, O});
Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({G0 * G1, M, N}); // scratch object after softmax
Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1
Tensor<ADataType> a_g_m_k({G0 * G1Q, M, K});
Tensor<B0DataType> b0_g_k_n({G0 * G1Q, K, N});
Tensor<B1DataType> b1_g_n_o({G0 * G1Q, N, O});
Tensor<AccDataType> acc0_g_m_n({G0 * G1Q, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({G0 * G1Q, M, N}); // scratch object after softmax
Tensor<ADataType> a1_g_m_n_drop({G0 * G1Q, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({G0 * G1Q, M, O}); // scratch object after gemm1
Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<ZDataType> z_g_m_n({G0 * G1, M, N});
Tensor<LSEDataType> lse_g_m_host_result({G0 * G1, M}); // scratch object after max + ln(sum)
Tensor<ZDataType> z_g_m_n({G0 * G1Q, M, N});
Tensor<LSEDataType> lse_g_m_host_result({G0 * G1Q, M}); // scratch object after max + ln(sum)
Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
// permute
a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
a_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
b0_g_k_n.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / h_ratio;
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / h_ratio;
self(idx) = b0_gs_ns_ks(g0, g2, idx[2], idx[1]);
self(idx) = b0_gs_ns_ks(g0, g1kv, idx[2], idx[1]);
});
b1_g_n_o.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / h_ratio;
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / h_ratio;
self(idx) = b1_gs_os_ns(g0, g2, idx[2], idx[1]);
self(idx) = b1_gs_os_ns(g0, g1kv, idx[2], idx[1]);
});
z_gs_ms_ns_device_result.ForEach([&](auto& self, auto idx) {
z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
z_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
// gemm 0
......@@ -421,18 +423,18 @@ int run(int argc, char* argv[])
// permute
c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
});
lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = lse_g_m_host_result(g, idx[2]);
});
......
......@@ -273,15 +273,15 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t M = 512;
ck::index_t N = 512;
ck::index_t K = DIM;
ck::index_t O = DIM;
ck::index_t G0 = 4;
ck::index_t G1 = 6; // h_q
ck::index_t G2 = 6; // h_kv
// y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
// y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
ck::index_t M = 512;
ck::index_t N = 512;
ck::index_t K = DIM;
ck::index_t O = DIM;
ck::index_t G0 = 4;
ck::index_t G1Q = 6; // h_q
ck::index_t G1KV = 6; // h_kv
bool input_permute = false;
bool output_permute = false;
......@@ -306,13 +306,13 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
G2 = std::stoi(argv[10]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1Q = std::stoi(argv[9]);
G1KV = std::stoi(argv[10]);
p_drop = std::stof(argv[11]);
......@@ -324,7 +324,7 @@ int run(int argc, char* argv[])
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
printf("arg4 to 10: M, N, K, O, G0, G1Q, G1KV\n");
printf("arg11: p_drop\n");
printf("arg12 to 13: input / output permute\n");
exit(0);
......@@ -343,8 +343,8 @@ int run(int argc, char* argv[])
std::cout << "K: " << K << std::endl;
std::cout << "O: " << O << std::endl;
std::cout << "G0: " << G0 << std::endl;
std::cout << "G1: " << G1 << std::endl;
std::cout << "G2: " << G2 << std::endl;
std::cout << "G1Q: " << G1Q << std::endl;
std::cout << "G1KV: " << G1KV << std::endl;
std::cout << "alpha: " << alpha << std::endl;
std::cout << "input_permute: " << input_permute << std::endl;
std::cout << "output_permute: " << output_permute << std::endl;
......@@ -352,63 +352,63 @@ int run(int argc, char* argv[])
std::cout << "seed: " << seed << std::endl;
std::cout << "offset: " << offset << std::endl;
const ck::index_t BatchCount = G0 * G1;
const ck::index_t BatchCount = G0 * G1Q;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // Q layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // Q layout [G0, G1Q, M, K]
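// Hedged layout note: lengths are always the logical [G0, G1Q, M, K] view and the
// strides select the memory order. With input_permute (memory [G0, M, G1Q, K]),
// element (g0, g1q, m, k) sits at g0*(M*G1Q*K) + m*(G1Q*K) + g1q*K + k, which is
// exactly the stride tuple {M*G1Q*K, K, G1Q*K, 1} read in logical index order.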
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
: std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // K layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // K layout [G0, G1KV, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
: std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // V layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // V layout [G0, G1KV, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // Y layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // Y layout [G0, G1Q, M, O]
std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> d0_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // D layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // D layout [G0, G1, M, N]
? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // D layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // D layout [G0, G1Q, M, N]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1Q, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
? std::vector<ck::index_t>{N * G1Q * K, K, G1Q * K, 1} // KGrad layout [G0, N, G1Q, K]
: std::vector<ck::index_t>{G1Q * N * K, N * K, K, 1}; // KGrad layout [G0, G1Q, N, K]
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1Q, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
? std::vector<ck::index_t>{N * G1Q * O, O, 1, G1Q * O} // VGrad layout [G0, N, G1Q, O]
: std::vector<ck::index_t>{G1Q * N * O, N * O, 1, O}; // VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up the softmax calculation in the backward pass
// Pi = exp(Si) / (exp(S0) + exp(S1) + ...)
//    = exp(Si) / exp(log(exp(S0) + exp(S1) + ...))
//    = exp(Si - log(exp(S0) + exp(S1) + ...))
//               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//               LSE
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
......@@ -467,7 +467,7 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1q, m, o]
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
// dO dot O = [0; 1; 2; ...]
break;
......@@ -475,7 +475,7 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1q, m, o]
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
......@@ -489,7 +489,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1}); // dy[g0,g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1q, m, o]
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
......@@ -651,7 +652,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> q_g_m_k({BatchCount, M, K});
Tensor<InputDataType> k_g_n_k({BatchCount, N, K});
Tensor<Acc0BiasDataType> d0_g_m_n({G0 * G1, M, N});
Tensor<Acc0BiasDataType> d0_g_m_n({BatchCount, M, N});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
Tensor<InputDataType> v_g_n_o({BatchCount, N, O});
Tensor<AccDataType> s_g_m_n({BatchCount, M, N});
......@@ -662,27 +663,27 @@ int run(int argc, char* argv[])
z_device_buf.FromDevice(z_gs_ms_ns.mData.data());
z_gs_ms_ns.ForEach([&](auto& self, auto idx) {
z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
z_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
q_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
k_g_n_k.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / (G1 / G2);
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
self(idx) = k_gs_ns_ks(g0, g1kv, idx[1], idx[2]);
});
v_g_n_o.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / (G1 / G2);
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = v_gs_os_ns(g0, g2, idx[2], idx[1]);
self(idx) = v_gs_os_ns(g0, g1kv, idx[2], idx[1]);
});
d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
d0_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
// run fwd again for y, because z_g_m_n was updated
run_attention_fwd_host(q_g_m_k,
......@@ -699,10 +700,10 @@ int run(int argc, char* argv[])
p_dropout_in_uint8_t,
rp_dropout);
y_gs_ms_os.ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
self(idx) = y_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]);
});
lse_gs_ms.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
[&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1Q + idx[1], idx[2]); });
y_device_buf.ToDevice(y_gs_ms_os.mData.data());
lse_device_buf.ToDevice(lse_gs_ms.mData.data());
......@@ -720,7 +721,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> ygrad_dot_y_g_m({BatchCount, M});
ygrad_gs_ms_os.ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
ygrad_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
#if PRINT_HOST
......@@ -844,35 +845,35 @@ int run(int argc, char* argv[])
// permute
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
});
kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
});
vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
});
d0grad_gs_ms_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = sgrad_g_m_n(g, idx[2], idx[3]);
});
......
......@@ -271,11 +271,11 @@ int run(int argc, char* argv[])
// Overall QKV matrices shape
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
// y_g0_g1q_m_o = reshape(y_g_m_o, [G0, G1Q, M, O])
// y_g0_m_g1q_o = permute(y_g0_g1q_m_o, [0, 2, 1, 3])
float alpha = 1.f / std::sqrt(DIM);
float p_drop = 0.0;
int h_ratio = 1; // G1 / G2
int h_ratio = 1; // G1Q / G1KV
bool input_permute = true;
bool output_permute = true;
......@@ -379,67 +379,71 @@ int run(int argc, char* argv[])
std::size_t flop = 0, num_byte = 0;
for(std::size_t i = 0; i < group_count; i++)
{
int M = 128 * (rand() % 8) + (rand() % 128);
int N = 128 * (rand() % 8) + (rand() % 128);
int K = DIM;
int O = DIM;
int G0 = rand() % 4 + 1;
int G2 = rand() % 4 + 1;
int G1 = G2 * h_ratio;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
int M = 128 * (rand() % 8) + (rand() % 128);
int N = 128 * (rand() % 8) + (rand() % 128);
int K = DIM;
int O = DIM;
int G0 = rand() % 4 + 1;
int G1KV = rand() % 4 + 1;
int G1Q = G1KV * h_ratio;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // Q layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // Q layout [G0, G1, M, K]
? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // Q layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // Q layout [G0, G1Q, M, K]
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G2, N, K};
std::vector<ck::index_t> k_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> k_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // K layout [G0, N, G2, K]
: std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // K layout [G0, G2, N, K]
? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1}
// K layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // K layout [G0, G1KV, N, K]
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G2, O, N};
std::vector<ck::index_t> v_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> v_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // V layout [G0, N, G2, O]
: std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // V layout [G0, G2, N, O]
? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O}
// V layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // V layout [G0, G1KV, N, O]
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> y_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> y_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]
? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // Y layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // Y layout [G0, G1Q, M, O]
std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> d0_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // d0 layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // d0 layout [G0, G1, M, N]
? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // d0 layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // d0 layout [G0, G1Q, M, N]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1, N, K};
? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> kgrad_gs_ns_ks_lengths{G0, G1Q, N, K};
std::vector<ck::index_t> kgrad_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // KGrad layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // KGrad layout [G0, G1, N, K]
input_permute ? std::vector<ck::index_t>{N * G1Q * K, K, G1Q * K, 1}
// KGrad layout [G0, N, G1Q, K]
: std::vector<ck::index_t>{
G1Q * N * K, N * K, K, 1}; // KGrad layout [G0, G1Q, N, K]
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_lengths{G0, G1Q, O, N};
std::vector<ck::index_t> vgrad_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // VGrad layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // VGrad layout [G0, G1, N, O]
input_permute ? std::vector<ck::index_t>{N * G1Q * O, O, 1, G1Q * O}
// VGrad layout [G0, N, G1Q, O]
: std::vector<ck::index_t>{
G1Q * N * O, N * O, 1, O}; // VGrad layout [G0, G1Q, N, O]
// The softmax stat log-sum-exp (LSE) is used to speed up the softmax calculation in the
// backward pass: Pi = exp(Si) / (exp(S0) + exp(S1) + ...)
//                   = exp(Si) / exp(log(exp(S0) + exp(S1) + ...))
//                   = exp(Si - log(exp(S0) + exp(S1) + ...))
//                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//                              LSE
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1 * M, M, 1}; // LSE layout [G0, G1, M]
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
problem_descs.push_back({
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
......@@ -463,7 +467,7 @@ int run(int argc, char* argv[])
{}, // acc1_bias_gs_ms_os_strides,
});
int BatchCount = G0 * G1;
int BatchCount = G0 * G1Q;
flop += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
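// Hedged accounting note: the three M*N*K GEMMs are commonly counted as the
// recompute of S = Q * K^T plus dQ = dS * K and dK = dS^T * Q; the two M*N*O
// GEMMs as dV = P^T * dY and dP = dY * V^T, each at 2 flops per multiply-add.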
// Q/K/V/Y, dQ/dK/dV/dY, LSE
num_byte +=
......@@ -532,7 +536,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); // dy[g0, g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_Sequential<2>{}); // dy[g0, g1q, m, o]
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
// dO dot O = [0; 1; 2; ...]
break;
......@@ -540,7 +545,8 @@ int run(int argc, char* argv[])
q_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<InputDataType>{1});
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(GeneratorTensor_Sequential<3>{}); // dy[g0, g1, m, o]
ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_Sequential<3>{}); // dy[g0, g1q, m, o]
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
......@@ -555,7 +561,7 @@ int run(int argc, char* argv[])
k_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
v_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<InputDataType>{});
ygrad_gs_ms_os.GenerateTensorValue(
GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1, m, o]
GeneratorTensor_1<InputDataType>{1}); // dy[g0, g1q, m, o]
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<Acc0BiasDataType>{1});
// assume mnko = 256
// P = softmax(QK) = 0.0039 * ones
......@@ -578,24 +584,24 @@ int run(int argc, char* argv[])
Tensor<InputDataType> p_drop_g_m_n({BatchCount, M, N});
q_gs_ms_ks.ForEach([&](auto& self, auto idx) {
q_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
q_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
k_g_n_k.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / h_ratio;
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / h_ratio;
self(idx) = k_gs_ns_ks(g0, g2, idx[1], idx[2]);
self(idx) = k_gs_ns_ks(g0, g1kv, idx[1], idx[2]);
});
d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
d0_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
v_g_n_o.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / h_ratio;
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / h_ratio;
self(idx) = v_gs_os_ns(g0, g2, idx[2], idx[1]);
self(idx) = v_gs_os_ns(g0, g1kv, idx[2], idx[1]);
});
q_g_m_ks.push_back(q_g_m_k);
......@@ -745,11 +751,11 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++)
{
int G1 = q_tensors[i].GetLengths()[1];
int G1Q = q_tensors[i].GetLengths()[1];
// copy z matrix data from device
z_tensors_device[i]->FromDevice(z_tensors[i].mData.data());
z_tensors[i].ForEach([&](auto& self, auto idx) {
z_g_m_ns[i](idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
z_g_m_ns[i](idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
run_attention_fwd_host(q_g_m_ks[i],
k_g_n_ks[i],
......@@ -766,11 +772,11 @@ int run(int argc, char* argv[])
rp_dropout);
y_tensors[i].ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_os[i](idx[0] * G1 + idx[1], idx[2], idx[3]);
self(idx) = y_g_m_os[i](idx[0] * G1Q + idx[1], idx[2], idx[3]);
});
y_tensors_device[i]->ToDevice(y_tensors[i].data());
lse_tensors[i].ForEach([&](auto& self, auto idx) {
self(idx) = lse_g_ms[i](idx[0] * G1 + idx[1], idx[2]);
self(idx) = lse_g_ms[i](idx[0] * G1Q + idx[1], idx[2]);
});
lse_tensors_device[i]->ToDevice(lse_tensors[i].data());
qgrad_tensors_device[i]->SetZero();
......@@ -785,12 +791,12 @@ int run(int argc, char* argv[])
{
int G0 = q_tensors[i].GetLengths()[0];
int G1 = q_tensors[i].GetLengths()[1];
int G1Q = q_tensors[i].GetLengths()[1];
int O = v_tensors[i].GetLengths()[2];
int N = v_tensors[i].GetLengths()[3];
int M = q_tensors[i].GetLengths()[2];
int K = q_tensors[i].GetLengths()[3];
int BatchCount = G0 * G1;
int BatchCount = G0 * G1Q;
Tensor<OutputDataType> qgrad_g_m_k({BatchCount, M, K});
Tensor<OutputDataType> kgrad_g_n_k({BatchCount, N, K});
Tensor<OutputDataType> vgrad_g_n_o({BatchCount, N, O});
......@@ -800,7 +806,7 @@ int run(int argc, char* argv[])
Tensor<InputDataType> ygrad_g_m_o({BatchCount, M, O});
ygrad_tensors[i].ForEach([&](auto& self, auto idx) {
ygrad_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
ygrad_g_m_o(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
auto ref_gemm0_grad = ReferenceGemm0GradInstance{};
auto ref_gemm0_grad_invoker = ref_gemm0_grad.MakeInvoker();
......@@ -868,34 +874,34 @@ int run(int argc, char* argv[])
vgrad_tensors_device[i]->FromDevice(vgrad_gs_os_ns_device_result.data());
// permute
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = qgrad_g_m_k(g, idx[2], idx[3]);
});
kgrad_gs_ns_ks_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
});
d0grad_gs_ms_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = sgrad_g_m_n(g, idx[2], idx[3]);
});
vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = vgrad_g_n_o(g, idx[3], idx[2]);
});
......
......@@ -14,12 +14,12 @@ int run(int argc, char* argv[])
ck::index_t K = DIM;
ck::index_t O = DIM;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t G0 = 7;
ck::index_t G1 = 12; // h_q
ck::index_t G2 = 12; // h_kv
// Output shape C[G0, M, G1Q, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1q_m_o = reshape(C_g_m_o, [g0, g1q, m, o])
// C_g0_m_g1q_o = permute(C_g0_g1q_m_o, [0, 2, 1, 3])
ck::index_t G0 = 7;
ck::index_t G1Q = 12; // h_q
ck::index_t G1KV = 12; // h_kv
bool input_permute = false;
bool output_permute = true;
......@@ -44,13 +44,13 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
G2 = std::stoi(argv[10]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1Q = std::stoi(argv[9]);
G1KV = std::stoi(argv[10]);
p_drop = std::stof(argv[11]);
......@@ -62,7 +62,7 @@ int run(int argc, char* argv[])
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
printf("arg4 to 10: M, N, K, O, G0, G1Q, G1KV\n");
printf("arg11: p_drop\n");
printf("arg12 to 13: input / output permute\n");
exit(0);
......@@ -73,45 +73,45 @@ int run(int argc, char* argv[])
float rp_dropout = 1.0 / p_dropout;
float alpha = 1.f / std::sqrt(K);
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // A layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // A layout [G0, G1Q, M, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G2, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // B0 layout [G0, N, G2, K]
: std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // B0 layout [G0, G2, N, K]
? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // B0 layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // B0 layout [G0, G1KV, N, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G2, O, N};
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // B1 layout [G0, N, G2, O]
: std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // B1 layout [G0, G2, N, O]
? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // B1 layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // B1 layout [G0, G1KV, N, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // C layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // C layout [G0, G1Q, M, O]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> d_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // D layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // D layout [G0, G1, M, N]
? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // D layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // D layout [G0, G1Q, M, N]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides =
std::vector<ck::index_t>{G1 * M, M, 1}; // LSE layout [G0, G1, M]
std::vector<ck::index_t>{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
@@ -226,7 +226,7 @@ int run(int argc, char* argv[])
return 0;
}
ck::index_t BatchCount = G0 * G1;
ck::index_t BatchCount = G0 * G1Q;
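// Editor's note: the flattened batch is sized by query heads: every query
// head gets its own GEMM batch, while K/V heads are shared through the
// g1q -> g1kv mapping (e.g. G0 = 2 groups with G1Q = 8 query heads give
// 16 independent batches).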
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
@@ -314,37 +314,37 @@ int run(int argc, char* argv[])
Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
Tensor<AccDataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax
Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N});
Tensor<ADataType> a1_g_m_n_drop({BatchCount, M, N});
Tensor<LSEDataType> lse_g_m_host_result(
{BatchCount, M}); // scratch object after max + ln(sum)
Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N});
Tensor<ZDataType> z_g_m_n({G0 * G1, M, N});
Tensor<Acc0BiasDataType> d_g_m_n({BatchCount, M, N});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
// permute
a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
a_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
b0_g_k_n.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / (G1 / G2);
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = b0_gs_ns_ks(g0, g2, idx[2], idx[1]);
self(idx) = b0_gs_ns_ks(g0, g1kv, idx[2], idx[1]);
});
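// Editor's note: the line above is the grouped-query-attention head mapping,
// also used in the b1_g_n_o loop below. Each of the G1Q query heads reads
// K/V head g1q / (G1Q / G1KV), which assumes G1Q is a multiple of G1KV;
// e.g. G1Q = 8, G1KV = 2 maps query heads 0..3 to kv head 0 and 4..7 to
// kv head 1. A defensive check (not in the original) could be:
// assert(G1Q % G1KV == 0 && "each K/V head must serve a whole group of Q heads");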
b1_g_n_o.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / (G1 / G2);
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / (G1Q / G1KV);
self(idx) = b1_gs_os_ns(g0, g2, idx[2], idx[1]);
self(idx) = b1_gs_os_ns(g0, g1kv, idx[2], idx[1]);
});
d_gs_ms_ns.ForEach([&](auto& self, auto idx) {
d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
d_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
z_gs_ms_ns.ForEach([&](auto& self, auto idx) {
z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
z_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
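// Editor's sketch: the permute loops above flatten (g0, g1q) into a single
// batch index g = g0 * G1Q + g1q; the verification below inverts it with
// g / G1Q and g % G1Q. A tiny round-trip pair (hypothetical; needs <utility>):
auto flatten   = [&](size_t g0, size_t g1q) { return g0 * G1Q + g1q; };
auto unflatten = [&](size_t g) { return std::make_pair(g / G1Q, g % G1Q); };
// unflatten(flatten(g0, g1q)) == {g0, g1q} for any 0 <= g1q < G1Q.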
// gemm 0
@@ -394,18 +394,18 @@ int run(int argc, char* argv[])
// permute
c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
});
lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = lse_g_m_host_result(g, idx[2]);
});
@@ -11,7 +11,7 @@ int run(int argc, char* argv[])
bool output_permute = true;
float p_drop = 0.2;
int h_ratio = 1; // G1 / G2
int h_ratio = 1; // G1Q / G1KV
const unsigned long long seed = 1;
const unsigned long long offset = 0;
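// Editor's note: the fixed seed/offset pair makes the device-side dropout
// RNG deterministic; judging from the verification path (p_z below), the
// sampled mask is written to the Z tensor and replayed by the host
// reference so both sides drop the same elements.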
@@ -65,7 +65,7 @@ int run(int argc, char* argv[])
std::vector<void*> p_z; // for result verification
std::vector<void*> p_z_nullptr; // for time test
std::vector<void*> p_lse;
std::vector<std::vector<int>> g0_g1_m_n_k_o;
std::vector<std::vector<int>> g0_g1q_m_n_k_o;
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<B0DataType>> b0_tensors;
@@ -90,55 +90,57 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++)
{
int M = 128 * (rand() % 8) + (rand() % 128);
int N = 128 * (rand() % 8) + (rand() % 128);
int K = DIM;
int O = DIM;
int G0 = rand() % 3 + 1;
int G2 = rand() % 5 + 1;
int G1 = G2 * h_ratio;
int M = 128 * (rand() % 8) + (rand() % 128);
int N = 128 * (rand() % 8) + (rand() % 128);
int K = DIM;
int O = DIM;
int G0 = rand() % 3 + 1;
int G1KV = rand() % 5 + 1;
int G1Q = G1KV * h_ratio;
g0_g1_m_n_k_o.push_back({G0, G1, M, N, K, O});
g0_g1q_m_n_k_o.push_back({G0, G1Q, M, N, K, O});
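// Editor's note: each group draws its own problem shape; the only coupling
// is G1Q = G1KV * h_ratio, so the query-to-KV head ratio is uniform across
// groups. With h_ratio = 1 this reduces to standard multi-head attention,
// where every query head owns its own K/V head.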
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1Q, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // A layout [G0, M, G1Q, K]
: std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1}; // A layout [G0, G1Q, M, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G2, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1KV, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // B0 layout [G0, N, G2, K]
: std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // B0 layout [G0, G2, N, K]
? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // B0 layout [G0, N, G1KV, K]
: std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1}; // B0 layout [G0, G1KV, N, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G2, O, N};
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1KV, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // B1 layout [G0, N, G2, O]
: std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // B1 layout [G0, G2, N, O]
? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // B1 layout [G0, N, G1KV, O]
: std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O}; // B1 layout [G0, G1KV, N, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1Q, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // C layout [G0, M, G1Q, O]
: std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1}; // C layout [G0, G1Q, M, O]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> d_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // D layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // D layout [G0, G1, M, N]
? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // D layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // D layout [G0, G1Q, M, N]
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1Q, M, N};
std::vector<ck::index_t> z_gs_ms_ns_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
: std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]
? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // Z layout [G0, M, G1Q, N]
: std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1}; // Z layout [G0, G1Q, M, N]
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1, M};
std::vector<ck::index_t> lse_gs_ms_lengths{G0, G1Q, M};
std::vector<ck::index_t> lse_gs_ms_strides =
std::vector<ck::index_t>{G1 * M, M, 1}; // LSE layout [G0, G1, M]
std::vector<ck::index_t>{G1Q * M, M, 1}; // LSE layout [G0, G1Q, M]
problem_descs.push_back({a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
@@ -166,7 +168,7 @@ int run(int argc, char* argv[])
Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);
int Batch = G0 * G1;
int Batch = G0 * G1Q;
flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
num_byte +=
(sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
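// Editor's sketch: the FLOP count covers only the two batched GEMMs, i.e.
// 2*M*N*K for gemm0 (Q*K^T) plus 2*M*N*O for gemm1 (P*V) per batch;
// softmax and dropout work is ignored. For example, M = N = 128,
// K = O = 64 and Batch = 16 gives (2097152 + 2097152) * 16 = 67108864 FLOPs.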
@@ -356,12 +358,12 @@ int run(int argc, char* argv[])
for(std::size_t i = 0; i < group_count; i++)
{
const int& G0 = g0_g1_m_n_k_o[i][0];
const int& G1 = g0_g1_m_n_k_o[i][1];
const int& M = g0_g1_m_n_k_o[i][2];
const int& N = g0_g1_m_n_k_o[i][3];
const int& K = g0_g1_m_n_k_o[i][4];
const int& O = g0_g1_m_n_k_o[i][5];
const int& G0 = g0_g1q_m_n_k_o[i][0];
const int& G1Q = g0_g1q_m_n_k_o[i][1];
const int& M = g0_g1q_m_n_k_o[i][2];
const int& N = g0_g1q_m_n_k_o[i][3];
const int& K = g0_g1q_m_n_k_o[i][4];
const int& O = g0_g1q_m_n_k_o[i][5];
const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths;
const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides;
@@ -383,43 +385,43 @@ int run(int argc, char* argv[])
z_gs_ms_ns_device_buf.FromDevice(z_gs_ms_ns_device_result.mData.data());
lse_gs_ms_device_buf.FromDevice(lse_gs_ms_device_result.mData.data());
Tensor<ADataType> a_g_m_k({G0 * G1, M, K});
Tensor<B0DataType> b0_g_k_n({G0 * G1, K, N});
Tensor<B1DataType> b1_g_n_o({G0 * G1, N, O});
Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0
Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1, M, N});
Tensor<ADataType> a1_g_m_n({G0 * G1, M, N}); // scratch object after softmax
Tensor<ADataType> a1_g_m_n_drop({G0 * G1, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1
Tensor<ADataType> a_g_m_k({G0 * G1Q, M, K});
Tensor<B0DataType> b0_g_k_n({G0 * G1Q, K, N});
Tensor<B1DataType> b1_g_n_o({G0 * G1Q, N, O});
Tensor<AccDataType> acc0_g_m_n({G0 * G1Q, M, N}); // scratch object after gemm0
Tensor<Acc0BiasDataType> d_g_m_n({G0 * G1Q, M, N});
Tensor<ADataType> a1_g_m_n({G0 * G1Q, M, N}); // scratch object after softmax
Tensor<ADataType> a1_g_m_n_drop({G0 * G1Q, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({G0 * G1Q, M, O}); // scratch object after gemm1
Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
Tensor<ZDataType> z_g_m_n({G0 * G1, M, N});
Tensor<LSEDataType> lse_g_m_host_result({G0 * G1, M}); // scratch object after gemm1
Tensor<ZDataType> z_g_m_n({G0 * G1Q, M, N});
Tensor<LSEDataType> lse_g_m_host_result({G0 * G1Q, M}); // scratch object after gemm1
Tensor<LSEDataType> lse_gs_ms_host_result(lse_gs_ms_lengths, lse_gs_ms_strides);
// permute
a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
a_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
b0_g_k_n.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / h_ratio;
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / h_ratio;
self(idx) = b0_gs_ns_ks(g0, g2, idx[2], idx[1]);
self(idx) = b0_gs_ns_ks(g0, g1kv, idx[2], idx[1]);
});
b1_g_n_o.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0] / G1;
const size_t& g1 = idx[0] % G1;
const size_t& g2 = g1 / h_ratio;
const size_t& g0 = idx[0] / G1Q;
const size_t& g1q = idx[0] % G1Q;
const size_t& g1kv = g1q / h_ratio;
self(idx) = b1_gs_os_ns(g0, g2, idx[2], idx[1]);
self(idx) = b1_gs_os_ns(g0, g1kv, idx[2], idx[1]);
});
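// Editor's note: g1q / h_ratio is the same head mapping as the
// g1q / (G1Q / G1KV) form used in the non-grouped example, since this
// file constructs G1Q as G1KV * h_ratio.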
d_gs_ms_ns.ForEach([&](auto& self, auto idx) {
d_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
d_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
z_gs_ms_ns_device_result.ForEach([&](auto& self, auto idx) {
z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
z_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
});
// gemm 0
@@ -473,18 +475,18 @@ int run(int argc, char* argv[])
// permute
c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
});
lse_gs_ms_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t& g0 = idx[0];
const size_t& g1q = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * G1Q + g1q;
self(idx) = lse_g_m_host_result(g, idx[2]);
});