Commit 8a6e65a3 authored by aska-0096

update self-attention and cross-attention

parent b62926dc
@@ -301,6 +301,28 @@ using DeviceMHAFactory =
         S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
         // CShuffleBlockTransfer MN
         1, 1, S<1, 128, 1, 2>, 8,
+        MaskingSpec>,
+    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
+        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
+        ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
+        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
+        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
+        256,
+        // Gemm 0
+        128, 64, 48, 8, 4,
+        // Gemm 1
+        48, 64, 8,
+        16, 16, 16,
+        // Per repeat = wave_m = wave_num, wave_n = 1
+        1, 4, 3,
+        // ABlockTransfer MK -> K0 M K1
+        S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        // B0BlockTransfer LK -> K0 L K1
+        S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true,
+        // B1BlockTransfer NL -> L0 N L1
+        S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
+        // CShuffleBlockTransfer MN
+        1, 1, S<1, 128, 1, 2>, 8,
         MaskingSpec>
 #endif
     >;
...
@@ -9,20 +9,18 @@ int run(int argc, char* argv[])
     // GEMM shape for A/B0/B1/C
     // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
-    ck::index_t M = 256;
-    ck::index_t N = 64;
-    ck::index_t K = 80;
-    ck::index_t O = 80;
+    ck::index_t q_sequence_length = 256;
+    ck::index_t kv_sequence_length = 64;
+    ck::index_t head_dim = 80;

-    // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
-    // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
-    // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
-    ck::index_t G0 = 2;
-    ck::index_t G1 = 8;
+    // Output shape C[batch_size, q_sequence_length, head_num, head_dim]. Batch dim, outer dim,
+    // inner dim must match GEMM shape C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) C_g0_m_g1_o =
+    // permute(C_g0_g1_m_o, [0, 2, 1, 3])
+    ck::index_t batch_size = 2;
+    ck::index_t head_num = 8;

     float alpha = 1;

-    bool input_permute = false;
+    bool input_permute = true;
     bool output_permute = true;

     if(argc == 1)
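For reference, the two GEMMs with the softmax in between amount to scaled-dot-product attention per batched head (masking, when enabled, is applied to the gemm0 scores before the softmax; alpha is the scale read from the command line):

    C_{g,m,o} = \sum_{n} \operatorname{softmax}_{n}\!\left( \alpha \sum_{k} A_{g,m,k}\, B0_{g,k,n} \right) B1_{g,n,o}

with the softmax taken over the n (kv_sequence_length) axis.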
@@ -35,58 +33,85 @@ int run(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
-    else if(argc == 13)
+    else if(argc == 10)
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
-        M = std::stoi(argv[4]);
-        N = std::stoi(argv[5]);
-        K = std::stoi(argv[6]);
-        O = std::stoi(argv[7]);
-        G0 = std::stoi(argv[8]);
-        G1 = std::stoi(argv[9]);
-        alpha = std::stof(argv[10]);
-        input_permute = std::stoi(argv[11]);
-        output_permute = std::stoi(argv[12]);
+        q_sequence_length = std::stoi(argv[4]);
+        kv_sequence_length = std::stoi(argv[5]);
+        head_dim = std::stoi(argv[6]);
+        batch_size = std::stoi(argv[7]);
+        head_num = std::stoi(argv[8]);
+        alpha = std::stof(argv[9]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 11: M, N, K, O, G0, G1\n");
-        printf("arg10: scale (alpha)\n");
-        printf("arg11 to 12: input / output permute\n");
+        printf(
+            "arg4 to 8: q_sequence_length, kv_sequence_length, head_dim, batch_size, head_num\n");
+        printf("arg9: scale (alpha)\n");
         exit(0);
     }

-    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
+    std::vector<ck::index_t> a_gs_ms_ks_lengths{batch_size, head_num, q_sequence_length, head_dim};
     std::vector<ck::index_t> a_gs_ms_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
-            : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
+            ? std::vector<ck::index_t>{q_sequence_length * head_num * head_dim, head_dim, head_num * head_dim, 1} // A layout [batch_size, q_sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * q_sequence_length * head_dim, q_sequence_length * head_dim, head_dim, 1}; // A layout [batch_size, head_num, q_sequence_length, head_dim]

-    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
+    std::vector<ck::index_t> b0_gs_ns_ks_lengths{batch_size, head_num, kv_sequence_length, head_dim};
     std::vector<ck::index_t> b0_gs_ns_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
-            : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
+            ? std::vector<ck::index_t>{kv_sequence_length * head_num * head_dim, head_dim, head_num * head_dim, 1} // B0 layout [batch_size, kv_sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * kv_sequence_length * head_dim, kv_sequence_length * head_dim, head_dim, 1}; // B0 layout [batch_size, head_num, kv_sequence_length, head_dim]

-    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
+    std::vector<ck::index_t> b1_gs_os_ns_lengths{batch_size, head_num, head_dim, kv_sequence_length};
     std::vector<ck::index_t> b1_gs_os_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
-            : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
+            ? std::vector<ck::index_t>{kv_sequence_length * head_num * head_dim, head_dim, 1, head_num * head_dim} // B1 layout [batch_size, kv_sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * kv_sequence_length * head_dim, kv_sequence_length * head_dim, 1, head_dim}; // B1 layout [batch_size, head_num, kv_sequence_length, head_dim]

-    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
+    std::vector<ck::index_t> c_gs_ms_os_lengths{batch_size, head_num, q_sequence_length, head_dim};
     std::vector<ck::index_t> c_gs_ms_os_strides =
         output_permute
-            ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
-            : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
+            ? std::vector<ck::index_t>{q_sequence_length * head_num * head_dim, head_dim, head_num * head_dim, 1} // C layout [batch_size, q_sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * q_sequence_length * head_dim, q_sequence_length * head_dim, head_dim, 1}; // C layout [batch_size, head_num, q_sequence_length, head_dim]

     Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
     Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
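The two branches of each stride vector describe the same logically [batch, head, sequence, dim]-indexed tensor in two physical orders: [batch, sequence, head, dim] when the permute flag is set, [batch, head, sequence, dim] otherwise. A minimal sketch of the offset arithmetic the two stride sets implement; the helper names and the standalone form are illustrative, not part of the example:

#include <cstdint>

// Offset of element (b, h, s, d) when data is stored as [batch, sequence, head, dim],
// i.e. strides {seq * head_num * head_dim, head_dim, head_num * head_dim, 1}
// for the logical index order (b, h, s, d).
inline int64_t offset_permuted(int64_t b, int64_t h, int64_t s, int64_t d,
                               int64_t seq, int64_t head_num, int64_t head_dim)
{
    return b * seq * head_num * head_dim + s * head_num * head_dim + h * head_dim + d;
}

// Offset of the same element when data is stored as [batch, head, sequence, dim],
// i.e. strides {head_num * seq * head_dim, seq * head_dim, head_dim, 1}.
inline int64_t offset_non_permuted(int64_t b, int64_t h, int64_t s, int64_t d,
                                   int64_t seq, int64_t head_num, int64_t head_dim)
{
    return b * head_num * seq * head_dim + h * seq * head_dim + s * head_dim + d;
}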
@@ -158,9 +183,14 @@ int run(int argc, char* argv[])
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }
-    std::vector<ck::index_t> kv_gs_ns_ks_lengths{G0, G1, N, 2, K};
+    std::vector<ck::index_t> kv_gs_ns_ks_lengths{batch_size, head_num, kv_sequence_length, 2, head_dim};
     std::vector<ck::index_t> kv_gs_ns_ks_strides = std::vector<ck::index_t>{
-        N * G1 * 2 * K, 2 * K, G1 * 2 * K, K, 1}; // kv layout [G0, M, G1, 2, K]
+        kv_sequence_length * head_num * 2 * head_dim,
+        2 * head_dim,
+        head_num * 2 * head_dim,
+        head_dim,
+        1}; // kv layout [batch_size, kv_sequence_length, head_num, 2, head_dim]
     Tensor<ADataType> kv_gs_ns_ks(kv_gs_ns_ks_lengths, kv_gs_ns_ks_strides);
     // merge kv into a packed pointer sent to the device
     b0_gs_ns_ks.ForEach(
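The packed kv buffer interleaves one K row and one V row per (batch, kv position, head): the size-2 axis selects between them, presumably slot 0 for K and slot 1 for V, following the merge loops in this hunk. A sketch of the offset arithmetic implied by kv_gs_ns_ks_strides; the helper name is illustrative, not from the example:

#include <cstdint>

// Offset of element (b, h, s, slot, d) in the packed kv buffer; logical index order is
// (batch, head, kv position, {K,V} slot, dim) with strides
// {kv_seq * head_num * 2 * head_dim, 2 * head_dim, head_num * 2 * head_dim, head_dim, 1}.
inline int64_t packed_kv_offset(int64_t b, int64_t h, int64_t s, int64_t slot, int64_t d,
                                int64_t kv_seq, int64_t head_num, int64_t head_dim)
{
    return b * kv_seq * head_num * 2 * head_dim
         + h * 2 * head_dim
         + s * head_num * 2 * head_dim
         + slot * head_dim
         + d;
}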
@@ -189,20 +219,20 @@ int run(int argc, char* argv[])
     printf("Verification: %s\n", do_verification ? "ON" : "OFF");
     // TODO ANT: replace array with vector?
     ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
-        const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{});
-        using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>;
+        const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
+        using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
         auto gemm = DeviceMHAInstance{};
         auto invoker = gemm.MakeCrossAttnInvoker();
         auto argument =
             gemm.MakeCrossAttnArgument(static_cast<ADataType*>(q_device_buf.GetDeviceBuffer()),
                                        static_cast<B0DataType*>(kv_device_buf.GetDeviceBuffer()),
                                        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-                                       G0,
-                                       M,
-                                       N,
-                                       G1,
-                                       K,
+                                       batch_size,
+                                       q_sequence_length,
+                                       kv_sequence_length,
+                                       head_num,
+                                       head_dim,
                                        alpha);

         // if(!gemm.IsSupportedArgument(argument))
@@ -212,13 +242,17 @@ int run(int argc, char* argv[])
         //     return 0;
         // }

-        ck::index_t BatchCount = G0 * G1;
+        ck::index_t BatchCount = batch_size * head_num;

         float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

-        std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
-        std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
-                                 sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
+        std::size_t flop = (size_t(q_sequence_length) * kv_sequence_length * head_dim * 2 +
+                            size_t(q_sequence_length) * kv_sequence_length * head_dim * 2) *
+                           BatchCount;
+        std::size_t num_btype = (sizeof(ADataType) * q_sequence_length * head_dim +
+                                 sizeof(B0DataType) * head_dim * kv_sequence_length +
+                                 sizeof(B1DataType) * kv_sequence_length * head_dim +
+                                 sizeof(CDataType) * q_sequence_length * head_dim) *
                                 BatchCount;

         float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
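The FLOP count charges 2*M*N*K for gemm0 and 2*M*N*O for gemm1; since K and O are both head_dim here, the two terms are identical. A quick sanity check for the default problem size (the 0.05 ms timing below is made up purely for illustration):

#include <cstddef>
#include <cstdio>

int main()
{
    const std::size_t q = 256, kv = 64, d = 80, batch = 2, heads = 8;
    const std::size_t flop = (q * kv * d * 2    // gemm0: scores = Q * K^T
                              + q * kv * d * 2) // gemm1: output = P * V
                             * (batch * heads); // = 83,886,080
    const float ave_time_ms = 0.05f;            // illustrative timing only
    // flop / 1e9 gives GFLOP; GFLOP per millisecond is numerically TFLOP/s,
    // which is why the example divides by 1.E9 and by ave_time (in ms).
    const float tflops = static_cast<float>(flop) / 1.e9f / ave_time_ms;
    std::printf("%.3f TFLOP/s\n", tflops);
    return 0;
}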
@@ -237,22 +271,26 @@ int run(int argc, char* argv[])
     {
         c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());

-        Tensor<ADataType> a_g_m_k({BatchCount, M, K});
-        Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
-        Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
-        Tensor<Acc0DataType> acc0_g_m_n({BatchCount, M, N});       // scratch object after gemm0
-        Tensor<ADataType> a1_g_m_n({BatchCount, M, N});            // scratch object after softmax
-        Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
+        Tensor<ADataType> a_g_m_k({BatchCount, q_sequence_length, head_dim});
+        Tensor<B0DataType> b0_g_k_n({BatchCount, head_dim, kv_sequence_length});
+        Tensor<B1DataType> b1_g_n_o({BatchCount, kv_sequence_length, head_dim});
+        Tensor<Acc0DataType> acc0_g_m_n({BatchCount, q_sequence_length, kv_sequence_length}); // scratch object after gemm0
+        Tensor<ADataType> a1_g_m_n({BatchCount, q_sequence_length, kv_sequence_length});      // scratch object after softmax
+        Tensor<CDataType> c_g_m_o_host_result({BatchCount, q_sequence_length, head_dim});     // scratch object after gemm1

         // permute
         a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
-            a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+            a_g_m_k(idx[0] * head_num + idx[1], idx[2], idx[3]) = self(idx);
         });
         b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
-            b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+            b0_g_k_n(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx);
         });
         b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
-            b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+            b1_g_n_o(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx);
         });

         // gemm 0
@@ -264,7 +302,7 @@ int run(int argc, char* argv[])
         ref_gemm0_invoker.Run(ref_gemm0_argument);

         // masking
-        const auto mask = typename DeviceMHAInstance::C0MatrixMask(N);
+        const auto mask = typename DeviceMHAInstance::C0MatrixMask(kv_sequence_length);
         acc0_g_m_n.ForEach([&](auto& self, auto idx) {
             if(mask.IsMaskedElement(idx[1], idx[2]))
                 self(idx) = -ck::NumericLimits<float>::Infinity();
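Overwriting masked scores with negative infinity before the reference softmax is what removes those positions from the attention: since

    \operatorname{softmax}(x)_j = e^{x_j} / \sum_k e^{x_k} \quad\text{and}\quad e^{-\infty} = 0,

a masked column gets zero weight while the remaining weights still sum to one.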
@@ -294,7 +332,7 @@ int run(int argc, char* argv[])
             const size_t& g0 = idx[0];
             const size_t& g1 = idx[1];

-            const size_t g = g0 * G1 + g1;
+            const size_t g = g0 * head_num + g1;

             self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
         });
@@ -330,8 +368,10 @@ int run(int argc, char* argv[])
     std::cout << "---------------------------------------------------------------------------------"
                  "-----------"
               << std::endl;
-    std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M
-              << ", N: " << N << ", K: " << K << ", O: " << O << std::endl;
+    std::cout << "Problem Size: BatchCount: " << batch_size << ", HeadNum: " << head_num
+              << ", q_sequence_length: " << q_sequence_length
+              << ", kv_sequence_length: " << kv_sequence_length << ", head_dim: " << head_dim
+              << std::endl;
     std::cout << "---------------------------------------------------------------------------------"
                  "-----------"
               << std::endl;
...
@@ -9,20 +9,17 @@ int run(int argc, char* argv[])
     // GEMM shape for A/B0/B1/C
     // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
-    ck::index_t M = 256;
-    ck::index_t N = 256;
-    ck::index_t K = 80;
-    ck::index_t O = 80;
+    ck::index_t sequence_length = 256;
+    ck::index_t head_dim = 80;

-    // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
-    // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
-    // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
-    ck::index_t G0 = 2;
-    ck::index_t G1 = 8;
+    // Output shape C[batch_size, sequence_length, head_num, head_dim]. Batch dim, outer dim, inner
+    // dim must match GEMM shape C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) C_g0_m_g1_o =
+    // permute(C_g0_g1_m_o, [0, 2, 1, 3])
+    ck::index_t batch_size = 2;
+    ck::index_t head_num = 8;

     float alpha = 1;

-    bool input_permute = false;
+    bool input_permute = true;
     bool output_permute = true;

     if(argc == 1)
@@ -35,58 +32,81 @@ int run(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
-    else if(argc == 13)
+    else if(argc == 9)
     {
         do_verification = std::stoi(argv[1]);
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
-        M = std::stoi(argv[4]);
-        N = std::stoi(argv[5]);
-        K = std::stoi(argv[6]);
-        O = std::stoi(argv[7]);
-        G0 = std::stoi(argv[8]);
-        G1 = std::stoi(argv[9]);
-        alpha = std::stof(argv[10]);
-        input_permute = std::stoi(argv[11]);
-        output_permute = std::stoi(argv[12]);
+        sequence_length = std::stoi(argv[4]);
+        head_dim = std::stoi(argv[5]);
+        batch_size = std::stoi(argv[6]);
+        head_num = std::stoi(argv[7]);
+        alpha = std::stof(argv[8]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 11: M, N, K, O, G0, G1\n");
-        printf("arg10: scale (alpha)\n");
-        printf("arg11 to 12: input / output permute\n");
+        printf("arg4 to 7: sequence_length, head_dim, batch_size, head_num\n");
+        printf("arg8: scale (alpha)\n");
         exit(0);
     }

-    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
+    std::vector<ck::index_t> a_gs_ms_ks_lengths{batch_size, head_num, sequence_length, head_dim};
     std::vector<ck::index_t> a_gs_ms_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
-            : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
+            ? std::vector<ck::index_t>{sequence_length * head_num * head_dim, head_dim, head_num * head_dim, 1} // A layout [batch_size, sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * sequence_length * head_dim, sequence_length * head_dim, head_dim, 1}; // A layout [batch_size, head_num, sequence_length, head_dim]

-    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
+    std::vector<ck::index_t> b0_gs_ns_ks_lengths{batch_size, head_num, sequence_length, head_dim};
     std::vector<ck::index_t> b0_gs_ns_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
-            : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
+            ? std::vector<ck::index_t>{sequence_length * head_num * head_dim, head_dim, head_num * head_dim, 1} // B0 layout [batch_size, sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * sequence_length * head_dim, sequence_length * head_dim, head_dim, 1}; // B0 layout [batch_size, head_num, sequence_length, head_dim]

-    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
+    std::vector<ck::index_t> b1_gs_os_ns_lengths{batch_size, head_num, head_dim, sequence_length};
     std::vector<ck::index_t> b1_gs_os_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
-            : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
+            ? std::vector<ck::index_t>{sequence_length * head_num * head_dim, head_dim, 1, head_num * head_dim} // B1 layout [batch_size, sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * sequence_length * head_dim, sequence_length * head_dim, 1, head_dim}; // B1 layout [batch_size, head_num, sequence_length, head_dim]

-    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
+    std::vector<ck::index_t> c_gs_ms_os_lengths{batch_size, head_num, sequence_length, head_dim};
     std::vector<ck::index_t> c_gs_ms_os_strides =
         output_permute
-            ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
-            : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
+            ? std::vector<ck::index_t>{sequence_length * head_num * head_dim, head_dim, head_num * head_dim, 1} // C layout [batch_size, sequence_length, head_num, head_dim]
+            : std::vector<ck::index_t>{head_num * sequence_length * head_dim, sequence_length * head_dim, head_dim, 1}; // C layout [batch_size, head_num, sequence_length, head_dim]

     Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
     Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
@@ -158,9 +178,14 @@ int run(int argc, char* argv[])
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
     }
-    std::vector<ck::index_t> qkv_gs_ms_ks_lengths{G0, G1, M, 3, K};
+    std::vector<ck::index_t> qkv_gs_ms_ks_lengths{batch_size, head_num, sequence_length, 3, head_dim};
     std::vector<ck::index_t> qkv_gs_ms_ks_strides = std::vector<ck::index_t>{
-        M * G1 * 3 * K, 3 * K, G1 * 3 * K, K, 1}; // qkv layout [G0, M, G1, 3, K]
+        sequence_length * head_num * 3 * head_dim,
+        3 * head_dim,
+        head_num * 3 * head_dim,
+        head_dim,
+        1}; // qkv layout [batch_size, sequence_length, head_num, 3, head_dim]
     Tensor<ADataType> qkv_gs_ms_ks(qkv_gs_ms_ks_lengths, qkv_gs_ms_ks_strides);
     // merge qkv into a packed pointer sent to the device
     a_gs_ms_ks.ForEach(
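In the packed qkv buffer each (batch, token, head) slot is 3 * head_dim contiguous elements, one third per projection; judging from the merge loops, Q, K and V presumably occupy slots 0, 1 and 2 of the packed axis. A sketch of the base-offset arithmetic implied by qkv_gs_ms_ks_strides; the helper name is illustrative, not from the example:

#include <cstdint>

// Start of the 3 * head_dim slot that holds the Q, K and V rows for one
// (batch b, token s, head h), i.e. physical order [batch][token][head][{q,k,v}][dim].
// Element d of packed-axis slot j then lives at packed_qkv_base(...) + j * head_dim + d.
inline int64_t packed_qkv_base(int64_t b, int64_t s, int64_t h,
                               int64_t seq, int64_t head_num, int64_t head_dim)
{
    return ((b * seq + s) * head_num + h) * 3 * head_dim;
}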
@@ -198,10 +223,10 @@ int run(int argc, char* argv[])
         auto argument =
             gemm.MakeSelfAttnArgument(static_cast<ADataType*>(qkv_device_buf.GetDeviceBuffer()),
                                       static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-                                      G0,
-                                      M,
-                                      G1,
-                                      K,
+                                      batch_size,
+                                      sequence_length,
+                                      head_num,
+                                      head_dim,
                                       alpha);

         // if(!gemm.IsSupportedArgument(argument))
@@ -211,13 +236,17 @@ int run(int argc, char* argv[])
         //     return 0;
         // }

-        ck::index_t BatchCount = G0 * G1;
+        ck::index_t BatchCount = batch_size * head_num;

         float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

-        std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
-        std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
-                                 sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
+        std::size_t flop = (size_t(sequence_length) * sequence_length * head_dim * 2 +
+                            size_t(sequence_length) * sequence_length * head_dim * 2) *
+                           BatchCount;
+        std::size_t num_btype = (sizeof(ADataType) * sequence_length * head_dim +
+                                 sizeof(B0DataType) * head_dim * sequence_length +
+                                 sizeof(B1DataType) * sequence_length * head_dim +
+                                 sizeof(CDataType) * sequence_length * head_dim) *
                                 BatchCount;

         float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -236,22 +265,25 @@ int run(int argc, char* argv[])
     {
         c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());

-        Tensor<ADataType> a_g_m_k({BatchCount, M, K});
-        Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
-        Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
-        Tensor<Acc0DataType> acc0_g_m_n({BatchCount, M, N});       // scratch object after gemm0
-        Tensor<ADataType> a1_g_m_n({BatchCount, M, N});            // scratch object after softmax
-        Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
+        Tensor<ADataType> a_g_m_k({BatchCount, sequence_length, head_dim});
+        Tensor<B0DataType> b0_g_k_n({BatchCount, head_dim, sequence_length});
+        Tensor<B1DataType> b1_g_n_o({BatchCount, sequence_length, head_dim});
+        Tensor<Acc0DataType> acc0_g_m_n({BatchCount, sequence_length, sequence_length}); // scratch object after gemm0
+        Tensor<ADataType> a1_g_m_n({BatchCount, sequence_length, sequence_length});      // scratch object after softmax
+        Tensor<CDataType> c_g_m_o_host_result({BatchCount, sequence_length, head_dim});  // scratch object after gemm1

         // permute
         a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
-            a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
+            a_g_m_k(idx[0] * head_num + idx[1], idx[2], idx[3]) = self(idx);
         });
         b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
-            b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+            b0_g_k_n(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx);
         });
         b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
-            b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+            b1_g_n_o(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx);
        });

         // gemm 0
@@ -263,7 +295,7 @@ int run(int argc, char* argv[])
         ref_gemm0_invoker.Run(ref_gemm0_argument);

         // masking
-        const auto mask = typename DeviceMHAInstance::C0MatrixMask(N);
+        const auto mask = typename DeviceMHAInstance::C0MatrixMask(sequence_length);
         acc0_g_m_n.ForEach([&](auto& self, auto idx) {
             if(mask.IsMaskedElement(idx[1], idx[2]))
                 self(idx) = -ck::NumericLimits<float>::Infinity();
@@ -293,7 +325,7 @@ int run(int argc, char* argv[])
             const size_t& g0 = idx[0];
             const size_t& g1 = idx[1];

-            const size_t g = g0 * G1 + g1;
+            const size_t g = g0 * head_num + g1;

             self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
         });
@@ -329,8 +361,9 @@ int run(int argc, char* argv[])
     std::cout << "---------------------------------------------------------------------------------"
                  "-----------"
               << std::endl;
-    std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M
-              << ", N: " << N << ", K: " << K << ", O: " << O << std::endl;
+    std::cout << "Problem Size: BatchCount: " << batch_size << ", HeadNum: " << head_num
+              << ", sequence_length: " << sequence_length << ", head_dim: " << head_dim
+              << std::endl;
     std::cout << "---------------------------------------------------------------------------------"
                  "-----------"
               << std::endl;
...
@@ -83,12 +83,34 @@ using DeviceMHAFactory =
         GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
         32,
         // Gemm 0
-        16, 128, 64, 8, 8,
+        16, 32, 160, 8, 8,
         // Gemm 1
-        64, 64, 8,
+        80, 32, 8,
         16, 16, 16,
+        // Per repeat = wave_m = wave_num, wave_n = 1
+        1, 2, 5,
+        // ABlockTransfer MK -> K0 M K1
+        S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        // B0BlockTransfer LK -> K0 L K1
+        S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        // B1BlockTransfer NL -> L0 N L1
+        S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
+        // CShuffleBlockTransfer MN
+        1, 1, S<1, 16, 1, 2>, 8,
+        MaskingSpec>,
+    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
+        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
+        ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
+        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
+        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
+        32,
+        // Gemm 0
+        16, 64, 80, 8, 8,
+        // Gemm 1
+        80, 64, 8,
+        16, 16, 16,
         // Per repeat = wave_m = wave_num, wave_n = 1
-        1, 8, 4,
+        1, 4, 5,
         // ABlockTransfer MK -> K0 M K1
         S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B0BlockTransfer LK -> K0 L K1
@@ -105,12 +127,12 @@ using DeviceMHAFactory =
         GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
         32,
         // Gemm 0
-        16, 64, 64, 8, 8,
+        16, 64, 48, 8, 8,
         // Gemm 1
-        64, 64, 8,
+        48, 64, 8,
         16, 16, 16,
         // Per repeat = wave_m = wave_num, wave_n = 1
-        1, 4, 4,
+        1, 4, 3,
         // ABlockTransfer MK -> K0 M K1
         S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B0BlockTransfer LK -> K0 L K1
@@ -129,16 +151,16 @@ using DeviceMHAFactory =
         GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
         64,
         // Gemm 0
-        32, 128, 64, 8, 8,
+        32, 64, 48, 8, 8,
         // Gemm 1
-        64, 64, 8,
+        48, 64, 8,
         16, 16, 16,
         // Per repeat = wave_m = wave_num, wave_n = 1
-        1, 8, 4,
+        1, 4, 3,
         // ABlockTransfer MK -> K0 M K1
         S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B0BlockTransfer LK -> K0 L K1
-        S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B1BlockTransfer NL -> L0 N L1
         S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
         // CShuffleBlockTransfer MN
@@ -151,16 +173,38 @@ using DeviceMHAFactory =
         GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
         64,
         // Gemm 0
-        32, 64, 64, 8, 8,
+        32, 64, 80, 8, 8,
         // Gemm 1
-        64, 64, 8,
+        80, 64, 8,
         16, 16, 16,
         // Per repeat = wave_m = wave_num, wave_n = 1
-        1, 4, 4,
+        1, 4, 5,
         // ABlockTransfer MK -> K0 M K1
         S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B0BlockTransfer LK -> K0 L K1
-        S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        // B1BlockTransfer NL -> L0 N L1
+        S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
+        // CShuffleBlockTransfer MN
+        1, 1, S<1, 32, 1, 2>, 8,
+        MaskingSpec>,
+    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
+        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
+        ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
+        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
+        GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
+        64,
+        // Gemm 0
+        32, 32, 160, 8, 8,
+        // Gemm 1
+        80, 32, 8,
+        16, 16, 16,
+        // Per repeat = wave_m = wave_num, wave_n = 1
+        1, 2, 5,
+        // ABlockTransfer MK -> K0 M K1
+        S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        // B0BlockTransfer LK -> K0 L K1
+        S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B1BlockTransfer NL -> L0 N L1
         S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
         // CShuffleBlockTransfer MN
@@ -175,20 +219,20 @@ using DeviceMHAFactory =
         GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
         128,
         // Gemm 0
-        64, 128, 64, 8, 8,
+        64, 128, 80, 8, 8,
         // Gemm 1
-        64, 64, 8,
+        80, 64, 8,
         16, 16, 16,
         // Per repeat = wave_m = wave_num, wave_n = 1
-        1, 8, 4,
+        1, 8, 5,
         // ABlockTransfer MK -> K0 M K1
         S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B0BlockTransfer LK -> K0 L K1
-        S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B1BlockTransfer NL -> L0 N L1
         S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
         // CShuffleBlockTransfer MN
         1, 1, S<1, 64, 1, 2>, 8,
         MaskingSpec>,
     ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
         NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
@@ -197,45 +241,45 @@ using DeviceMHAFactory =
         GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
         128,
         // Gemm 0
-        64, 64, 64, 8, 8,
+        64, 192, 48, 8, 8,
         // Gemm 1
-        64, 64, 8,
+        48, 64, 8,
         16, 16, 16,
         // Per repeat = wave_m = wave_num, wave_n = 1
-        1, 4, 4,
+        1, 12, 3,
         // ABlockTransfer MK -> K0 M K1
         S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B0BlockTransfer LK -> K0 L K1
-        S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B1BlockTransfer NL -> L0 N L1
         S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
         // CShuffleBlockTransfer MN
         1, 1, S<1, 64, 1, 2>, 8,
         MaskingSpec>,
-#endif
-#ifdef CK_MHA_USE_WAVE_8
     ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
         NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
         ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
         AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
         GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
-        256,
+        128,
         // Gemm 0
-        128, 128, 64, 8, 8,
+        64, 64, 48, 8, 8,
         // Gemm 1
-        64, 64, 8,
+        48, 64, 8,
         16, 16, 16,
         // Per repeat = wave_m = wave_num, wave_n = 1
-        1, 8, 4,
+        1, 4, 3,
         // ABlockTransfer MK -> K0 M K1
-        S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B0BlockTransfer LK -> K0 L K1
-        S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B1BlockTransfer NL -> L0 N L1
-        S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
+        S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
         // CShuffleBlockTransfer MN
-        1, 1, S<1, 128, 1, 2>, 8,
+        1, 1, S<1, 64, 1, 2>, 8,
         MaskingSpec>,
+#endif
+#ifdef CK_MHA_USE_WAVE_8
     ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
         NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
         ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
@@ -243,18 +287,18 @@ using DeviceMHAFactory =
         GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
         256,
         // Gemm 0
-        128, 128, 64, 8, 8,
+        128, 192, 48, 8, 4,
         // Gemm 1
-        64, 64, 8,
+        48, 64, 8,
         16, 16, 16,
         // Per repeat = wave_m = wave_num, wave_n = 1
-        1, 8, 4,
+        1, 12, 3,
         // ABlockTransfer MK -> K0 M K1
         S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
         // B0BlockTransfer LK -> K0 L K1
-        S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+        S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true,
         // B1BlockTransfer NL -> L0 N L1
         S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
         // CShuffleBlockTransfer MN
         1, 1, S<1, 128, 1, 2>, 8,
         MaskingSpec>
...