Commit 8a6e65a3 authored by aska-0096's avatar aska-0096
Browse files

update self-attention and cross-attention

parent b62926dc
......@@ -301,6 +301,28 @@ using DeviceMHAFactory =
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 128, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
256,
// Gemm 0
128, 64, 48, 8,4,
// Gemm 1
48, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 3,
// ABlockTransfer MK -> K0 M K1
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 128, 1, 2>, 8,
MaskingSpec>
#endif
>;
......
......@@ -9,20 +9,18 @@ int run(int argc, char* argv[])
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 256;
ck::index_t N = 64;
ck::index_t K = 80;
ck::index_t O = 80;
ck::index_t q_sequence_length = 256;
ck::index_t kv_sequence_length = 64;
ck::index_t head_dim = 80;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t G0 = 2;
ck::index_t G1 = 8;
// Output shape C[batch_size, q_sequence_length, head_num, head_dim]. Batch dim, outer dim,
// inner dim must match GEMM shape C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) C_g0_m_g1_o =
// permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t batch_size = 2;
ck::index_t head_num = 8;
float alpha = 1;
bool input_permute = false;
bool input_permute = true;
bool output_permute = true;
if(argc == 1)
......@@ -35,58 +33,85 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 13)
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]);
q_sequence_length = std::stoi(argv[4]);
kv_sequence_length = std::stoi(argv[5]);
head_dim = std::stoi(argv[6]);
batch_size = std::stoi(argv[7]);
head_num = std::stoi(argv[8]);
input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]);
alpha = std::stof(argv[9]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n");
printf("arg10: scale (alpha)\n");
printf("arg11 to 12: input / output permute\n");
printf(
"arg4 to 8: q_sequence_length, kv_sequence_length, head_dim, batch_size, head_num\n");
printf("arg9: scale (alpha)\n");
exit(0);
}
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_lengths{batch_size, head_num, q_sequence_length, head_dim};
std::vector<ck::index_t> a_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
input_permute ? std::vector<ck::index_t>{q_sequence_length * head_num * head_dim,
head_dim,
head_num * head_dim,
1}
// A layout [batch_size, q_sequence_length, head_num, head_dim]
: std::vector<ck::index_t>{
head_num * q_sequence_length * head_dim,
q_sequence_length * head_dim,
head_dim,
1}; // A layout [batch_size, head_num, q_sequence_length, head_dim]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{
batch_size, head_num, kv_sequence_length, head_dim};
std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
input_permute ? std::vector<ck::index_t>{kv_sequence_length * head_num * head_dim,
head_dim,
head_num * head_dim,
1}
// B0 layout [batch_size, kv_sequence_length, head_num, head_dim]
: std::vector<ck::index_t>{
head_num * kv_sequence_length * head_dim,
kv_sequence_length * head_dim,
head_dim,
1}; // B0 layout [batch_size, head_num, kv_sequence_length, head_dim]
std::vector<ck::index_t> b1_gs_os_ns_lengths{
batch_size, head_num, head_dim, kv_sequence_length};
std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
? std::vector<ck::index_t>{kv_sequence_length * head_num * head_dim,
head_dim,
1,
head_num * head_dim}
// B1 layout [batch_size, kv_sequence_length, head_num, head_dim]
: std::vector<ck::index_t>{
head_num * kv_sequence_length * head_dim,
kv_sequence_length * head_dim,
1,
head_dim}; // B1 layout [batch_size, head_num, kv_sequence_length, head_dim]
std::vector<ck::index_t> c_gs_ms_os_lengths{batch_size, head_num, q_sequence_length, head_dim};
std::vector<ck::index_t> c_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
output_permute ? std::vector<ck::index_t>{q_sequence_length * head_num * head_dim,
head_dim,
head_num * head_dim,
1}
// C layout [batch_size, q_sequence_length, head_num, head_dim]
: std::vector<ck::index_t>{
head_num * q_sequence_length * head_dim,
q_sequence_length * head_dim,
head_dim,
1}; // C layout [batch_size, head_num, q_sequence_length, head_dim]
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
......@@ -158,9 +183,14 @@ int run(int argc, char* argv[])
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
std::vector<ck::index_t> kv_gs_ns_ks_lengths{G0, G1, N, 2, K};
std::vector<ck::index_t> kv_gs_ns_ks_lengths{
batch_size, head_num, kv_sequence_length, 2, head_dim};
std::vector<ck::index_t> kv_gs_ns_ks_strides = std::vector<ck::index_t>{
N * G1 * 2 * K, 2 * K, G1 * 2 * K, K, 1}; // kv layout [G0, M, G1, 2, K]
kv_sequence_length * head_num * 2 * head_dim,
2 * head_dim,
head_num * 2 * head_dim,
head_dim,
1}; // kv layout [batch_size, q_sequence_length, head_num, 2, head_dim]
Tensor<ADataType> kv_gs_ns_ks(kv_gs_ns_ks_lengths, kv_gs_ns_ks_strides);
// merge kv into a packed pointer send to device
b0_gs_ns_ks.ForEach(
......@@ -189,20 +219,20 @@ int run(int argc, char* argv[])
printf("Verification: %s\n", do_verification ? "ON" : "OFF");
// TODO ANT: replace array with vector?
ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{});
const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>;
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
auto gemm = DeviceMHAInstance{};
auto invoker = gemm.MakeCrossAttnInvoker();
auto argument =
gemm.MakeCrossAttnArgument(static_cast<ADataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(kv_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
G0,
M,
N,
G1,
K,
batch_size,
q_sequence_length,
kv_sequence_length,
head_num,
head_dim,
alpha);
// if(!gemm.IsSupportedArgument(argument))
......@@ -212,13 +242,17 @@ int run(int argc, char* argv[])
// return 0;
// }
ck::index_t BatchCount = G0 * G1;
ck::index_t BatchCount = batch_size * head_num;
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
std::size_t flop = (size_t(q_sequence_length) * kv_sequence_length * head_dim * 2 +
size_t(q_sequence_length) * kv_sequence_length * head_dim * 2) *
BatchCount;
std::size_t num_btype = (sizeof(ADataType) * q_sequence_length * head_dim +
sizeof(B0DataType) * head_dim * kv_sequence_length +
sizeof(B1DataType) * kv_sequence_length * head_dim +
sizeof(CDataType) * q_sequence_length * head_dim) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
......@@ -237,22 +271,26 @@ int run(int argc, char* argv[])
{
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
Tensor<ADataType> a_g_m_k({BatchCount, M, K});
Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
Tensor<Acc0DataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
Tensor<ADataType> a_g_m_k({BatchCount, q_sequence_length, head_dim});
Tensor<B0DataType> b0_g_k_n({BatchCount, head_dim, kv_sequence_length});
Tensor<B1DataType> b1_g_n_o({BatchCount, kv_sequence_length, head_dim});
Tensor<Acc0DataType> acc0_g_m_n(
{BatchCount, q_sequence_length, kv_sequence_length}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({BatchCount,
q_sequence_length,
kv_sequence_length}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result(
{BatchCount, q_sequence_length, head_dim}); // scratch object after gemm1
// permute
a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
a_g_m_k(idx[0] * head_num + idx[1], idx[2], idx[3]) = self(idx);
});
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
b0_g_k_n(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx);
});
b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
b1_g_n_o(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx);
});
// gemm 0
......@@ -264,7 +302,7 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
const auto mask = typename DeviceMHAInstance::C0MatrixMask(N);
const auto mask = typename DeviceMHAInstance::C0MatrixMask(kv_sequence_length);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
......@@ -294,7 +332,7 @@ int run(int argc, char* argv[])
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * head_num + g1;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
});
......@@ -330,8 +368,10 @@ int run(int argc, char* argv[])
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M
<< ", N: " << N << ", K: " << K << ", O: " << O << std::endl;
std::cout << "Problem Size: BatchCount: " << batch_size << ", HeadNum: " << head_num
<< ", q_sequence_length: " << q_sequence_length
<< ", kv_sequence_length: " << kv_sequence_length << ", head_dim: " << head_dim
<< std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
......
......@@ -9,20 +9,17 @@ int run(int argc, char* argv[])
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 256;
ck::index_t N = 256;
ck::index_t K = 80;
ck::index_t O = 80;
ck::index_t sequence_length = 256;
ck::index_t head_dim = 80;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t G0 = 2;
ck::index_t G1 = 8;
// Output shape C[batch_size, sequence_length, head_num, head_dim]. Batch dim, outer dim, inner
// dim must match GEMM shape C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) C_g0_m_g1_o =
// permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t batch_size = 2;
ck::index_t head_num = 8;
float alpha = 1;
bool input_permute = false;
bool input_permute = true;
bool output_permute = true;
if(argc == 1)
......@@ -35,58 +32,81 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 13)
else if(argc == 9)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]);
sequence_length = std::stoi(argv[4]);
head_dim = std::stoi(argv[5]);
batch_size = std::stoi(argv[6]);
head_num = std::stoi(argv[7]);
input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]);
alpha = std::stof(argv[8]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n");
printf("arg10: scale (alpha)\n");
printf("arg11 to 12: input / output permute\n");
printf("arg4 to 7: sequence_length, head_dim, batch_size, head_num\n");
printf("arg8: scale (alpha)\n");
exit(0);
}
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_lengths{batch_size, head_num, sequence_length, head_dim};
std::vector<ck::index_t> a_gs_ms_ks_strides =
input_permute
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
input_permute ? std::vector<ck::index_t>{sequence_length * head_num * head_dim,
head_dim,
head_num * head_dim,
1}
// A layout [batch_size, sequence_length, head_num, head_dim]
: std::vector<ck::index_t>{
head_num * sequence_length * head_dim,
sequence_length * head_dim,
head_dim,
1}; // A layout [batch_size, head_num, sequence_length, head_dim]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{batch_size, head_num, sequence_length, head_dim};
std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
input_permute ? std::vector<ck::index_t>{sequence_length * head_num * head_dim,
head_dim,
head_num * head_dim,
1}
// B0 layout [batch_size, sequence_length, head_num, head_dim]
: std::vector<ck::index_t>{
head_num * sequence_length * head_dim,
sequence_length * head_dim,
head_dim,
1}; // B0 layout [batch_size, head_num, sequence_length, head_dim]
std::vector<ck::index_t> b1_gs_os_ns_lengths{batch_size, head_num, head_dim, sequence_length};
std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
? std::vector<ck::index_t>{sequence_length * head_num * head_dim,
head_dim,
1,
head_num * head_dim}
// B1 layout [batch_size, sequence_length, head_num, head_dim]
: std::vector<ck::index_t>{
head_num * sequence_length * head_dim,
sequence_length * head_dim,
1,
head_dim}; // B1 layout [batch_size, head_num, sequence_length, head_dim]
std::vector<ck::index_t> c_gs_ms_os_lengths{batch_size, head_num, sequence_length, head_dim};
std::vector<ck::index_t> c_gs_ms_os_strides =
output_permute
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
output_permute ? std::vector<ck::index_t>{sequence_length * head_num * head_dim,
head_dim,
head_num * head_dim,
1}
// C layout [batch_size, sequence_length, head_num, head_dim]
: std::vector<ck::index_t>{
head_num * sequence_length * head_dim,
sequence_length * head_dim,
head_dim,
1}; // C layout [batch_size, head_num, sequence_length, head_dim]
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
......@@ -158,9 +178,14 @@ int run(int argc, char* argv[])
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
std::vector<ck::index_t> qkv_gs_ms_ks_lengths{G0, G1, M, 3, K};
std::vector<ck::index_t> qkv_gs_ms_ks_lengths{
batch_size, head_num, sequence_length, 3, head_dim};
std::vector<ck::index_t> qkv_gs_ms_ks_strides = std::vector<ck::index_t>{
M * G1 * 3 * K, 3 * K, G1 * 3 * K, K, 1}; // qkv layout [G0, M, G1, 3, K]
sequence_length * head_num * 3 * head_dim,
3 * head_dim,
head_num * 3 * head_dim,
head_dim,
1}; // qkv layout [batch_size, sequence_length, head_num, 3, head_dim]
Tensor<ADataType> qkv_gs_ms_ks(qkv_gs_ms_ks_lengths, qkv_gs_ms_ks_strides);
// merge qkv into a packed pointer send to device
a_gs_ms_ks.ForEach(
......@@ -198,10 +223,10 @@ int run(int argc, char* argv[])
auto argument =
gemm.MakeSelfAttnArgument(static_cast<ADataType*>(qkv_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
G0,
M,
G1,
K,
batch_size,
sequence_length,
head_num,
head_dim,
alpha);
// if(!gemm.IsSupportedArgument(argument))
......@@ -211,13 +236,17 @@ int run(int argc, char* argv[])
// return 0;
// }
ck::index_t BatchCount = G0 * G1;
ck::index_t BatchCount = batch_size * head_num;
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
std::size_t flop = (size_t(sequence_length) * sequence_length * head_dim * 2 +
size_t(sequence_length) * sequence_length * head_dim * 2) *
BatchCount;
std::size_t num_btype = (sizeof(ADataType) * sequence_length * head_dim +
sizeof(B0DataType) * head_dim * sequence_length +
sizeof(B1DataType) * sequence_length * head_dim +
sizeof(CDataType) * sequence_length * head_dim) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
......@@ -236,22 +265,25 @@ int run(int argc, char* argv[])
{
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
Tensor<ADataType> a_g_m_k({BatchCount, M, K});
Tensor<B0DataType> b0_g_k_n({BatchCount, K, N});
Tensor<B1DataType> b1_g_n_o({BatchCount, N, O});
Tensor<Acc0DataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1
Tensor<ADataType> a_g_m_k({BatchCount, sequence_length, head_dim});
Tensor<B0DataType> b0_g_k_n({BatchCount, head_dim, sequence_length});
Tensor<B1DataType> b1_g_n_o({BatchCount, sequence_length, head_dim});
Tensor<Acc0DataType> acc0_g_m_n(
{BatchCount, sequence_length, sequence_length}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n(
{BatchCount, sequence_length, sequence_length}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result(
{BatchCount, sequence_length, head_dim}); // scratch object after gemm1
// permute
a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
a_g_m_k(idx[0] * head_num + idx[1], idx[2], idx[3]) = self(idx);
});
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
b0_g_k_n(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx);
});
b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
b1_g_n_o(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx);
});
// gemm 0
......@@ -263,7 +295,7 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
const auto mask = typename DeviceMHAInstance::C0MatrixMask(N);
const auto mask = typename DeviceMHAInstance::C0MatrixMask(sequence_length);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
......@@ -293,7 +325,7 @@ int run(int argc, char* argv[])
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
const size_t g = g0 * head_num + g1;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
});
......@@ -329,8 +361,9 @@ int run(int argc, char* argv[])
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M
<< ", N: " << N << ", K: " << K << ", O: " << O << std::endl;
std::cout << "Problem Size: BatchCount: " << batch_size << ", HeadNum: " << head_num
<< ", sequence_length: " << sequence_length << ", head_dim: " << head_dim
<< std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
......
......@@ -83,12 +83,12 @@ using DeviceMHAFactory =
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
32,
// Gemm 0
16, 128, 64, 8, 8,
16, 32, 160, 8, 8,
// Gemm 1
64, 64, 8,
80, 32, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
1, 2, 5,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
......@@ -105,12 +105,34 @@ using DeviceMHAFactory =
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
32,
// Gemm 0
16, 64, 64, 8, 8,
16, 64, 80, 8, 8,
// Gemm 1
64, 64, 8,
80, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
1, 4, 5,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
32,
// Gemm 0
16, 64, 48, 8, 8,
// Gemm 1
48, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 3,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
......@@ -129,16 +151,16 @@ using DeviceMHAFactory =
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
64,
// Gemm 0
32, 128, 64, 8, 8,
32, 64, 48, 8, 8,
// Gemm 1
64, 64, 8,
48, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
1, 4, 3,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
......@@ -151,16 +173,38 @@ using DeviceMHAFactory =
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
64,
// Gemm 0
32, 64, 64, 8, 8,
32, 64, 80, 8, 8,
// Gemm 1
64, 64, 8,
80, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
1, 4, 5,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 32, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
64,
// Gemm 0
32, 32, 160, 8, 8,
// Gemm 1
80, 32, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 2, 5,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
......@@ -175,16 +219,16 @@ using DeviceMHAFactory =
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
128,
// Gemm 0
64, 128, 64, 8, 8,
64, 128, 80, 8, 8,
// Gemm 1
64, 64, 8,
80, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
1, 8, 5,
// ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
......@@ -197,45 +241,45 @@ using DeviceMHAFactory =
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
128,
// Gemm 0
64, 64, 64, 8, 8,
64, 192, 48, 8, 8,
// Gemm 1
64, 64, 8,
48, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
1, 12, 3,
// ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_8
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
256,
128,
// Gemm 0
128, 128, 64, 8, 8,
64, 64, 48, 8, 8,
// Gemm 1
64, 64, 8,
48, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
1, 4, 3,
// ABlockTransfer MK -> K0 M K1
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 128, 1, 2>, 8,
1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_8
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
......@@ -243,16 +287,16 @@ using DeviceMHAFactory =
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
256,
// Gemm 0
128, 128, 64, 8, 8,
128, 192, 48, 8,4,
// Gemm 1
64, 64, 8,
48, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
1, 12, 3,
// ABlockTransfer MK -> K0 M K1
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
// CShuffleBlockTransfer MN
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment