Commit c2510944 authored by danyao12

mqa/gqa inference

parent 5ff2d646
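
This change adds multi-query attention (MQA) and grouped-query attention (GQA) support to the batched and grouped multi-head attention inference paths: A (queries) and C (output) keep G1Q heads (h_q), while B0 (keys) and B1 (values) carry only G1KV heads (h_kv), and every group of h_ratio = G1Q / G1KV consecutive query heads shares one KV head. A minimal host-side sketch of that mapping, assuming only what the diff below enforces (h_q divisible by h_kv); the function name is illustrative, not part of the commit:

    #include <cassert>

    // Which KV head serves a given query head (illustration only).
    int kv_head_for_query_head(int q_head, int h_q, int h_kv)
    {
        assert(h_q % h_kv == 0);        // mirrored by the new IsSupportedArgument checks
        const int h_ratio = h_q / h_kv; // h_ratio == 1 reproduces plain MHA
        return q_head / h_ratio;        // h_kv == 1 reproduces MQA
    }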
@@ -14,11 +14,12 @@ int run(int argc, char* argv[])
     ck::index_t K = DIM;
     ck::index_t O = DIM;

-    // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
-    // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
-    // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
+    // Output shape C[G0, M, G1Q, O]. Batch dim, outer dim, inner dim must match GEMM shape
+    // C_g0_g1q_m_o = reshape(C_g_m_o, [g0, g1q, m, o])
+    // C_g0_m_g1q_o = permute(C_g0_g1q_m_o, [0, 2, 1, 3])
     ck::index_t G0 = 7;
-    ck::index_t G1 = 13;
+    ck::index_t G1Q  = 12; // h_q
+    ck::index_t G1KV = 12; // h_kv

     float alpha = 1;
@@ -35,64 +36,65 @@ int run(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
-    else if(argc == 13)
+    else if(argc == 14)
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);

         M = std::stoi(argv[4]);
         N = std::stoi(argv[5]);
         K = std::stoi(argv[6]);
         O = std::stoi(argv[7]);

         G0 = std::stoi(argv[8]);
-        G1 = std::stoi(argv[9]);
+        G1Q  = std::stoi(argv[9]);
+        G1KV = std::stoi(argv[10]);

-        alpha = std::stof(argv[10]);
+        alpha = std::stof(argv[11]);

-        input_permute  = std::stoi(argv[11]);
-        output_permute = std::stoi(argv[12]);
+        input_permute  = std::stoi(argv[12]);
+        output_permute = std::stoi(argv[13]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 11: M, N, K, O, G0, G1\n");
-        printf("arg10: scale (alpha)\n");
-        printf("arg11 to 12: input / output permute\n");
+        printf("arg4 to 10: M, N, K, O, G0, G1Q, G1KV\n");
+        printf("arg11: scale (alpha)\n");
+        printf("arg12 to 13: input / output permute\n");
         exit(0);
     }

-    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
+    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1Q, M, K};
     std::vector<ck::index_t> a_gs_ms_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
-            : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
+            ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // A layout [G0, M, G1Q, K]
+            : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1};  // A layout [G0, G1Q, M, K]

-    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
+    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1KV, N, K};
     std::vector<ck::index_t> b0_gs_ns_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
-            : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
+            ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // B0 layout [G0, N, G1KV, K]
+            : std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1};   // B0 layout [G0, G1KV, N, K]

-    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
+    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1KV, O, N};
     std::vector<ck::index_t> b1_gs_os_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
-            : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
+            ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // B1 layout [G0, N, G1KV, O]
+            : std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O};   // B1 layout [G0, G1KV, N, O]

-    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
+    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1Q, M, O};
     std::vector<ck::index_t> c_gs_ms_os_strides =
         output_permute
-            ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
-            : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
+            ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // C layout [G0, M, G1Q, O]
+            : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1};  // C layout [G0, G1Q, M, O]

-    std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
+    std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1Q, M, N};
     std::vector<ck::index_t> d0_gs_ms_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // D0 layout [G0, M, G1, N]
-            : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // D0 layout [G0, G1, M, N]
+            ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // D0 layout [G0, M, G1Q, N]
+            : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1};  // D0 layout [G0, G1Q, M, N]

     Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
     Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
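
Both stride variants above address the same dense buffer; only the nesting order changes. A quick sketch of how the A strides resolve an element (g0, g1q, m, k), in the order of a_gs_ms_ks_lengths; the `offset` helper is hypothetical, for illustration only:

    // Dot product of index and stride vectors (what a lengths/strides descriptor computes, in effect).
    size_t offset(const std::vector<ck::index_t>& s, int g0, int g1q, int m, int k)
    {
        return size_t(g0) * s[0] + size_t(g1q) * s[1] + size_t(m) * s[2] + size_t(k) * s[3];
    }
    // input_permute == true : offset == ((g0 * M + m) * G1Q + g1q) * K + k  -> [G0, M, G1Q, K]
    // input_permute == false: offset == ((g0 * G1Q + g1q) * M + m) * K + k  -> [G0, G1Q, M, K]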
@@ -188,14 +190,15 @@ int run(int argc, char* argv[])
         return 0;
     }

-    ck::index_t BatchCount = G0 * G1;
+    ck::index_t BatchCount = G0 * G1Q;

     float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

     std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2 + size_t(M) * N) * BatchCount;
-    std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
-                             sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O + sizeof(Acc0BiasDataType) * M * N) *
-                            BatchCount;
+    std::size_t num_btype =
+        (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
+         sizeof(CDataType) * M * O + sizeof(Acc0BiasDataType) * M * N) *
+        BatchCount;

     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -211,23 +214,31 @@ int run(int argc, char* argv[])
         Tensor<ADataType> a_g_m_k({BatchCount, M, K});
         Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
         Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
         Tensor<AccDataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
         Tensor<Acc0BiasDataType> d0_g_m_n({BatchCount, M, N});
         Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax
         Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1

         // permute
         a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
-            a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+            a_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
         });
-        b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
-            b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+        b0_g_k_n.ForEach([&](auto& self, auto idx) {
+            const size_t& g0   = idx[0] / G1Q;
+            const size_t& g1q  = idx[0] % G1Q;
+            const size_t& g1kv = g1q / (G1Q / G1KV);
+            self(idx) = b0_gs_ns_ks(g0, g1kv, idx[2], idx[1]);
         });
-        d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
-            d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+        b1_g_n_o.ForEach([&](auto& self, auto idx) {
+            const size_t& g0   = idx[0] / G1Q;
+            const size_t& g1q  = idx[0] % G1Q;
+            const size_t& g1kv = g1q / (G1Q / G1KV);
+            self(idx) = b1_gs_os_ns(g0, g1kv, idx[2], idx[1]);
         });
-        b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
-            b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+        d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
+            d0_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
         });

         // gemm 0
@@ -267,10 +278,10 @@ int run(int argc, char* argv[])
         // permute
         c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
             const size_t& g0 = idx[0];
-            const size_t& g1 = idx[1];
-            const size_t g = g0 * G1 + g1;
+            const size_t& g1q = idx[1];
+            const size_t g = g0 * G1Q + g1q;
             self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
         });
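
Note the direction flip in the B0/B1 permutes above: the old code scattered each of the G1 source heads into the flat batch, while the new code walks the G0 * G1Q destination batches and gathers from the G1KV source heads, so h_ratio query batches read the same KV slice. With hypothetical sizes G1Q = 12 and G1KV = 4 (so G1Q / G1KV = 3), the mapping g1kv = g1q / 3 works out to:

    // g1q : 0 1 2 | 3 4 5 | 6 7 8 | 9 10 11
    // g1kv: 0 0 0 | 1 1 1 | 2 2 2 | 3  3  3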
...
@@ -10,6 +10,8 @@ int run(int argc, char* argv[])
     bool input_permute = false;
     bool output_permute = true;

+    int h_ratio = 1; // G1Q / G1KV
+
     if(argc == 1)
     {
         // use default case
@@ -20,21 +22,23 @@ int run(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
-    else if(argc == 6)
+    else if(argc == 7)
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
-        input_permute = std::stoi(argv[4]);
-        output_permute = std::stoi(argv[5]);
+        h_ratio = std::stoi(argv[4]);
+        input_permute = std::stoi(argv[5]);
+        output_permute = std::stoi(argv[6]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 5: input / output permute\n");
+        printf("arg4: h_ratio\n");
+        printf("arg5 to 6: input / output permute\n");
         exit(0);
     }
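
A hypothetical invocation of the grouped example with the new argument list (the binary name depends on the build target and is illustrative only):

    ./example_grouped_mha_infer 1 1 1 4 0 1
    # verify=1, integer init, time kernel, h_ratio=4, input_permute=0, output_permute=1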
@@ -49,7 +53,7 @@ int run(int argc, char* argv[])
     std::vector<const void*> p_d0;
     std::vector<const void*> p_b1;
     std::vector<void*> p_c;
-    std::vector<std::vector<int>> g0_g1_m_n_k_o;
+    std::vector<std::vector<int>> g0_g1q_m_n_k_o;

     std::vector<Tensor<ADataType>> a_tensors;
     std::vector<Tensor<B0DataType>> b0_tensors;
@@ -69,44 +73,47 @@ int run(int argc, char* argv[])
     std::cout << "group count " << group_count << ". printing first 4 groups\n";
     for(std::size_t i = 0; i < group_count; i++)
     {
         int M = 128 * (rand() % 8 + 1);
         int N = 128 * (rand() % 8 + 1);
         int K = 40;
         int O = 40 * (rand() % 2 + 1);
         int G0 = rand() % 3 + 1;
-        int G1 = rand() % 5 + 1;
+        int G1KV = rand() % 5 + 1;
+        int G1Q = G1KV * h_ratio;

-        g0_g1_m_n_k_o.push_back({G0, G1, M, N, K, O});
+        g0_g1q_m_n_k_o.push_back({G0, G1Q, M, N, K, O});

-        std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
+        std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1Q, M, K};
         std::vector<ck::index_t> a_gs_ms_ks_strides =
             input_permute
-                ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
-                : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
+                ? std::vector<ck::index_t>{M * G1Q * K, K, G1Q * K, 1} // A layout [G0, M, G1Q, K]
+                : std::vector<ck::index_t>{G1Q * M * K, M * K, K, 1};  // A layout [G0, G1Q, M, K]

-        std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
+        std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1KV, N, K};
         std::vector<ck::index_t> b0_gs_ns_ks_strides =
             input_permute
-                ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
-                : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
+                ? std::vector<ck::index_t>{N * G1KV * K, K, G1KV * K, 1} // B0 layout [G0, N, G1KV, K]
+                : std::vector<ck::index_t>{G1KV * N * K, N * K, K, 1};   // B0 layout [G0, G1KV, N, K]

-        std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
+        std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1KV, O, N};
         std::vector<ck::index_t> b1_gs_os_ns_strides =
             input_permute
-                ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
-                : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
+                ? std::vector<ck::index_t>{N * G1KV * O, O, 1, G1KV * O} // B1 layout [G0, N, G1KV, O]
+                : std::vector<ck::index_t>{G1KV * N * O, N * O, 1, O};   // B1 layout [G0, G1KV, N, O]

-        std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
+        std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1Q, M, O};
         std::vector<ck::index_t> c_gs_ms_os_strides =
             output_permute
-                ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
-                : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
+                ? std::vector<ck::index_t>{M * G1Q * O, O, G1Q * O, 1} // C layout [G0, M, G1Q, O]
+                : std::vector<ck::index_t>{G1Q * M * O, M * O, O, 1};  // C layout [G0, G1Q, M, O]

-        std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
+        std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1Q, M, N};
         std::vector<ck::index_t> d0_gs_ms_ns_strides =
             input_permute
-                ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // d0 layout [G0, M, G1, N]
-                : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // d0 layout [G0, G1, M, N]
+                ? std::vector<ck::index_t>{M * G1Q * N, N, G1Q * N, 1} // d0 layout [G0, M, G1Q, N]
+                : std::vector<ck::index_t>{G1Q * M * N, M * N, N, 1};  // d0 layout [G0, G1Q, M, N]

         problem_descs.push_back({a_gs_ms_ks_lengths,
                                  a_gs_ms_ks_strides,
@@ -128,7 +135,7 @@ int run(int argc, char* argv[])
         Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
         Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);

-        int Batch = G0 * G1;
+        int Batch = G0 * G1Q;
         flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2 + size_t(M) * N) * Batch;
         num_byte +=
             (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O +
@@ -248,12 +255,12 @@ int run(int argc, char* argv[])
    {
        for(std::size_t i = 0; i < group_count; i++)
        {
-            const int& G0 = g0_g1_m_n_k_o[i][0];
-            const int& G1 = g0_g1_m_n_k_o[i][1];
-            const int& M = g0_g1_m_n_k_o[i][2];
-            const int& N = g0_g1_m_n_k_o[i][3];
-            const int& K = g0_g1_m_n_k_o[i][4];
-            const int& O = g0_g1_m_n_k_o[i][5];
+            const int& G0 = g0_g1q_m_n_k_o[i][0];
+            const int& G1Q = g0_g1q_m_n_k_o[i][1];
+            const int& M = g0_g1q_m_n_k_o[i][2];
+            const int& N = g0_g1q_m_n_k_o[i][3];
+            const int& K = g0_g1q_m_n_k_o[i][4];
+            const int& O = g0_g1q_m_n_k_o[i][5];

            const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths;
            const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides;
@@ -267,27 +274,35 @@ int run(int argc, char* argv[])
            c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());

-            Tensor<ADataType> a_g_m_k({G0 * G1, M, K});
-            Tensor<B0DataType> b0_g_k_n({G0 * G1, K, N});
-            Tensor<Acc0BiasDataType> d0_g_m_n({G0 * G1, M, N});
-            Tensor<B1DataType> b1_g_n_o({G0 * G1, N, O});
-            Tensor<AccDataType> acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0
-            Tensor<ADataType> a1_g_m_n({G0 * G1, M, N}); // scratch object after softmax
-            Tensor<CDataType> c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1
+            Tensor<ADataType> a_g_m_k({G0 * G1Q, M, K});
+            Tensor<B0DataType> b0_g_k_n({G0 * G1Q, K, N});
+            Tensor<Acc0BiasDataType> d0_g_m_n({G0 * G1Q, M, N});
+            Tensor<B1DataType> b1_g_n_o({G0 * G1Q, N, O});
+            Tensor<AccDataType> acc0_g_m_n({G0 * G1Q, M, N}); // scratch object after gemm0
+            Tensor<ADataType> a1_g_m_n({G0 * G1Q, M, N}); // scratch object after softmax
+            Tensor<CDataType> c_g_m_o_host_result({G0 * G1Q, M, O}); // scratch object after gemm1
             Tensor<CDataType> c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);

             // permute
             a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
-                a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+                a_g_m_k(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
             });
-            b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
-                b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+            b0_g_k_n.ForEach([&](auto& self, auto idx) {
+                const size_t& g0 = idx[0] / G1Q;
+                const size_t& g1q = idx[0] % G1Q;
+                const size_t& g1kv = g1q / h_ratio;
+                self(idx) = b0_gs_ns_ks(g0, g1kv, idx[2], idx[1]);
             });
-            d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
-                d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+            b1_g_n_o.ForEach([&](auto& self, auto idx) {
+                const size_t& g0 = idx[0] / G1Q;
+                const size_t& g1q = idx[0] % G1Q;
+                const size_t& g1kv = g1q / h_ratio;
+                self(idx) = b1_gs_os_ns(g0, g1kv, idx[2], idx[1]);
             });
-            b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
-                b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+            d0_gs_ms_ns.ForEach([&](auto& self, auto idx) {
+                d0_g_m_n(idx[0] * G1Q + idx[1], idx[2], idx[3]) = self(idx);
             });

             // gemm 0
@@ -331,10 +346,10 @@ int run(int argc, char* argv[])
             // permute
             c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) {
                 const size_t& g0 = idx[0];
-                const size_t& g1 = idx[1];
-                const size_t g = g0 * G1 + g1;
+                const size_t& g1q = idx[1];
+                const size_t g = g0 * G1Q + g1q;
                 self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
             });
...
@@ -64,6 +64,7 @@ __global__ void
            d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
        const Block2CTileMap block_2_ctile_map,
        const index_t batch_count,
+       const index_t h_ratio,
        const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
        const C0MatrixMask c0_matrix_mask)
 {
@@ -73,13 +74,14 @@ __global__ void
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
+    const index_t gkv_idx = __builtin_amdgcn_readfirstlane(g_idx / h_ratio);

     const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
     const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
+        static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(gkv_idx)));
     const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
+        static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(gkv_idx)));
     const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
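
A and C keep one batch slice per query head, while the B0/B1 base pointers are now looked up with the shared KV batch index. Because the flat batch is laid out as g_idx = g0 * G1Q + g1q, one integer division recovers it; a host-side illustration without the GPU intrinsics (the helper name is hypothetical):

    // g_idx / h_ratio == g0 * G1KV + g1q / h_ratio, since G1Q == G1KV * h_ratio
    ck::index_t kv_batch(ck::index_t g_idx, ck::index_t h_ratio) { return g_idx / h_ratio; }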
@@ -130,6 +132,7 @@ __global__ void
    ignore = d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
    ignore = block_2_ctile_map;
    ignore = batch_count;
+   ignore = h_ratio;
    ignore = compute_base_ptr_of_batch;
    ignore = c0_matrix_mask;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
@@ -512,7 +515,8 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
                               b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
          c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
                                c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
-         batch_count_{c1_grid_desc_g_m_n_.GetLength(I0)}
+         batch_count_{c1_grid_desc_g_m_n_.GetLength(I0)},
+         h_ratio_{c1_grid_desc_g_m_n_.GetLength(I0) / b_grid_desc_g_n_k_.GetLength(I0)}
    {
        // TODO ANT: implement bias addition
        ignore = p_acc1_bias;
@@ -613,6 +617,7 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
        std::vector<ck::index_t> d0s_nl_ns_lengths_strides_;

        index_t batch_count_;
+       index_t h_ratio_;
        ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;

        // raw data
@@ -683,6 +688,7 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
                arg.d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
                arg.block_2_ctile_map_,
                arg.batch_count_,
+               arg.h_ratio_,
                arg.compute_base_ptr_of_batch_,
                arg.c0_matrix_mask_);
        };
@@ -730,12 +736,13 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
        // Check if C permute dimension matches GEMM + GEMM shape
        const index_t c_g = arg.c1_grid_desc_g_m_n_.GetLength(I0); // unpadded
+       const index_t b_g = arg.b_grid_desc_g_n_k_.GetLength(I0);
        const index_t c_m = arg.c1_grid_desc_m_n_.GetLength(I0);
        const index_t c_gemm1n = arg.c1_grid_desc_m_n_.GetLength(I1);
        const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
        const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
-       if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
+       if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0))
        {
            return false;
        }
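
The added c_g % b_g == 0 clause is the divisibility requirement of head grouping: c_g counts the A/C batches (G0 * G1Q) and b_g counts the B0/B1 batches (G0 * G1KV). Worked through with hypothetical shapes:

    // G0 = 7, G1Q = 12, G1KV = 4
    // c_g = 7 * 12 = 84, b_g = 7 * 4 = 28
    // 84 % 28 == 0 and 84 / 28 == 3 == h_ratio_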
...
@@ -39,6 +39,7 @@ __global__ void
kernel_grouped_multiple_head_flash_attention_infer(
    const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
    const index_t group_count,
+   const index_t h_ratio,
    const AElementwiseOperation a_element_op,
    const BElementwiseOperation b_element_op,
    const AccElementwiseOperation acc_element_op,
@@ -76,13 +77,14 @@ __global__ void
    const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
    const index_t g_idx = __builtin_amdgcn_readfirstlane(
        (block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch);
+   const index_t gkv_idx = __builtin_amdgcn_readfirstlane(g_idx / h_ratio);

    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
-   const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-       static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(g_idx)));
+   const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
+       arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(gkv_idx)));
    const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
-       arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
+       arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(gkv_idx)));
    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
@@ -118,6 +120,7 @@ __global__ void
#else
    ignore = group_kernel_args;
    ignore = group_count;
+   ignore = h_ratio;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = acc_element_op;
@@ -495,6 +498,8 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
        // for gridwise gemm check
        C1GridDesc_M_N c1_grid_desc_m_n_;
+       BGridDesc_G_N_K b_grid_desc_g_n_k_;
+       C1GridDesc_G_M_N c1_grid_desc_g_m_n_;

        // raw data
        std::vector<ck::index_t> d0_n_length_stride_;
@@ -536,6 +541,9 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
            grid_size_ = 0;

+           h_ratio_ = problem_desc_vec[0].a_gs_ms_ks_lengths[NumDimG - 1] /
+                      problem_desc_vec[0].b0_gs_ns_ks_lengths[NumDimG - 1];
+
            for(std::size_t i = 0; i < group_count_; i++)
            {
                const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]);
@@ -648,6 +656,8 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
                    {problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
                     problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
                    c_grid_desc_m_n,
+                   b_grid_desc_g_n_k,
+                   c1_grid_desc_g_m_n,
                    d0_n_length_stride});
            }
        }
@@ -663,6 +673,8 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
        AccElementwiseOperation acc_element_op_;
        B1ElementwiseOperation b1_element_op_;
        CElementwiseOperation c_element_op_;
+
+       index_t h_ratio_;
    };

    // Invoker
@@ -739,6 +751,7 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
                0,
                cast_pointer_to_constant_address_space(arg.p_workspace_),
                arg.group_count_,
+               arg.h_ratio_,
                arg.a_element_op_,
                arg.b_element_op_,
                arg.acc_element_op_,
@@ -797,11 +810,14 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
            const auto& device_arg = arg.group_device_args_[i];

            // Check if C permute dimension matches GEMM + GEMM shape
+           const index_t c_g = device_arg.c1_grid_desc_g_m_n_.GetLength(I0); // unpadded
+           const index_t b_g = device_arg.b_grid_desc_g_n_k_.GetLength(I0);
            const index_t c_m = device_arg.c1_grid_desc_m_n_.GetLength(I0);
            const index_t c_gemm1n = device_arg.c1_grid_desc_m_n_.GetLength(I1);
            const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
            const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
-           if(!(c_m == a_m && c_gemm1n == b1_gemm1n))
+           if(!(c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 &&
+                c_g / b_g == arg.h_ratio_))
            {
                return false;
            }
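
The grouped variant checks more than the batched one: since a single arg.h_ratio_, derived from the G-dimension lengths of the first problem descriptor, is broadcast to the kernel for all groups, every group must satisfy both c_g % b_g == 0 and c_g / b_g == arg.h_ratio_.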
...