Commit ccaea50e authored by Jing Zhang's avatar Jing Zhang
Browse files

merge navi31_rel

parents 0b914465 10127959
...@@ -72,5 +72,18 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -72,5 +72,18 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif() endif()
endforeach() endforeach()
add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp) add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
list(APPEND gpu_list gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
set(target 1)
endif()
endforeach()
...@@ -49,7 +49,7 @@ struct ProblemSizeStreamK final ...@@ -49,7 +49,7 @@ struct ProblemSizeStreamK final
struct ExecutionConfig final struct ExecutionConfig final
{ {
bool do_verification = true; bool do_verification = true;
int init_method = 1; int init_method = 2;
bool time_kernel = false; bool time_kernel = false;
}; };
......
...@@ -20,14 +20,18 @@ using BElementOp = PassThrough; ...@@ -20,14 +20,18 @@ using BElementOp = PassThrough;
using CElementOp = PassThrough; using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto LoopSched = ck::make_default_loop_scheduler();
static constexpr auto PipelineVer = ck::PipelineVersion::v1;
using ComputeTypeA = ck::f8_t;
using ComputeTypeB = ck::f8_t;
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| Compute| Compute|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| TypeA| TypeB|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
......
...@@ -27,10 +27,10 @@ using ComputeTypeB = ck::bf8_t; ...@@ -27,10 +27,10 @@ using ComputeTypeB = ck::bf8_t;
// clang-format off // clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| Compute| Compute|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| TypeA| TypeB|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
// clang-format on // clang-format on
......
...@@ -85,8 +85,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -85,8 +85,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-2.f, 2.f}(b_k_n); ck::utils::FillUniformDistributionIntegerValue<BDataType>{-2.f, 2.f}(b_k_n);
break; break;
default: default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k); ck::utils::FillUniformDistribution<ADataType>{-0.1f, 0.1f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n); ck::utils::FillUniformDistribution<BDataType>{-0.1f, 0.1f}(b_k_n);
} }
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
...@@ -256,7 +256,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -256,7 +256,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
#else #else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); return ck::utils::check_err(
c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", 1e-1, 1e-1);
#endif #endif
} }
......
add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") if(GPU_TARGETS MATCHES "gfx11")
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
endif() endif()
...@@ -279,9 +279,8 @@ bool run_grouped_conv_fwd_bias_relu_add_example(int argc, char* argv[]) ...@@ -279,9 +279,8 @@ bool run_grouped_conv_fwd_bias_relu_add_example(int argc, char* argv[])
switch(conv_param.num_dim_spatial_) switch(conv_param.num_dim_spatial_)
{ {
// case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param); // case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param);
case 2: case 2: return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param);
return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param); // case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param);
// case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param);
} }
return false; return false;
......
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") if(GPU_TARGETS MATCHES "gfx11")
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp) add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp) add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp) add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp)
......
...@@ -301,6 +301,28 @@ using DeviceMHAFactory = ...@@ -301,6 +301,28 @@ using DeviceMHAFactory =
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
// CShuffleBlockTransfer MN // CShuffleBlockTransfer MN
1, 1, S<1, 128, 1, 2>, 8, 1, 1, S<1, 128, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
256,
// Gemm 0
128, 64, 48, 8,4,
// Gemm 1
48, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 3,
// ABlockTransfer MK -> K0 M K1
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 128, 1, 2>, 8,
MaskingSpec> MaskingSpec>
#endif #endif
>; >;
......
...@@ -182,9 +182,9 @@ int run(int argc, char* argv[]) ...@@ -182,9 +182,9 @@ int run(int argc, char* argv[])
printf("Verification: %s\n", do_verification ? "ON" : "OFF"); printf("Verification: %s\n", do_verification ? "ON" : "OFF");
// TODO ANT: replace array with vector? // TODO ANT: replace array with vector?
ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void { ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{}); const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>; using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
auto gemm = DeviceMHAInstance{}; auto gemm = DeviceMHAInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
......
...@@ -9,20 +9,18 @@ int run(int argc, char* argv[]) ...@@ -9,20 +9,18 @@ int run(int argc, char* argv[])
// GEMM shape for A/B0/B1/C // GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 256; ck::index_t q_sequence_length = 256;
ck::index_t N = 64; ck::index_t kv_sequence_length = 64;
ck::index_t K = 80; ck::index_t head_dim = 80;
ck::index_t O = 80;
// Output shape C[batch_size, q_sequence_length, head_num, head_dim]. Batch dim, outer dim,
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape // inner dim must match GEMM shape C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) C_g0_m_g1_o =
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) // permute(C_g0_g1_m_o, [0, 2, 1, 3])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3]) ck::index_t batch_size = 2;
ck::index_t G0 = 2; ck::index_t head_num = 8;
ck::index_t G1 = 8;
float alpha = 1;
float alpha = 1; bool input_permute = true;
bool input_permute = false;
bool output_permute = true; bool output_permute = true;
if(argc == 1) if(argc == 1)
...@@ -35,58 +33,85 @@ int run(int argc, char* argv[]) ...@@ -35,58 +33,85 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 13) else if(argc == 10)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); q_sequence_length = std::stoi(argv[4]);
N = std::stoi(argv[5]); kv_sequence_length = std::stoi(argv[5]);
K = std::stoi(argv[6]); head_dim = std::stoi(argv[6]);
O = std::stoi(argv[7]); batch_size = std::stoi(argv[7]);
G0 = std::stoi(argv[8]); head_num = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]); alpha = std::stof(argv[9]);
input_permute = std::stoi(argv[11]);
output_permute = std::stoi(argv[12]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n"); printf(
printf("arg10: scale (alpha)\n"); "arg4 to 8: q_sequence_length, kv_sequence_length, head_dim, batch_size, head_num\n");
printf("arg11 to 12: input / output permute\n"); printf("arg9: scale (alpha)\n");
exit(0); exit(0);
} }
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> a_gs_ms_ks_lengths{batch_size, head_num, q_sequence_length, head_dim};
std::vector<ck::index_t> a_gs_ms_ks_strides = std::vector<ck::index_t> a_gs_ms_ks_strides =
input_permute input_permute ? std::vector<ck::index_t>{q_sequence_length * head_num * head_dim,
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] head_dim,
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] head_num * head_dim,
1}
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K}; // A layout [batch_size, q_sequence_length, head_num, head_dim]
: std::vector<ck::index_t>{
head_num * q_sequence_length * head_dim,
q_sequence_length * head_dim,
head_dim,
1}; // A layout [batch_size, head_num, q_sequence_length, head_dim]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{
batch_size, head_num, kv_sequence_length, head_dim};
std::vector<ck::index_t> b0_gs_ns_ks_strides = std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute input_permute ? std::vector<ck::index_t>{kv_sequence_length * head_num * head_dim,
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K] head_dim,
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] head_num * head_dim,
1}
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N}; // B0 layout [batch_size, kv_sequence_length, head_num, head_dim]
: std::vector<ck::index_t>{
head_num * kv_sequence_length * head_dim,
kv_sequence_length * head_dim,
head_dim,
1}; // B0 layout [batch_size, head_num, kv_sequence_length, head_dim]
std::vector<ck::index_t> b1_gs_os_ns_lengths{
batch_size, head_num, head_dim, kv_sequence_length};
std::vector<ck::index_t> b1_gs_os_ns_strides = std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O] ? std::vector<ck::index_t>{kv_sequence_length * head_num * head_dim,
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] head_dim,
1,
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O}; head_num * head_dim}
// B1 layout [batch_size, kv_sequence_length, head_num, head_dim]
: std::vector<ck::index_t>{
head_num * kv_sequence_length * head_dim,
kv_sequence_length * head_dim,
1,
head_dim}; // B1 layout [batch_size, head_num, kv_sequence_length, head_dim]
std::vector<ck::index_t> c_gs_ms_os_lengths{batch_size, head_num, q_sequence_length, head_dim};
std::vector<ck::index_t> c_gs_ms_os_strides = std::vector<ck::index_t> c_gs_ms_os_strides =
output_permute output_permute ? std::vector<ck::index_t>{q_sequence_length * head_num * head_dim,
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] head_dim,
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] head_num * head_dim,
1}
// C layout [batch_size, q_sequence_length, head_num, head_dim]
: std::vector<ck::index_t>{
head_num * q_sequence_length * head_dim,
q_sequence_length * head_dim,
head_dim,
1}; // C layout [batch_size, head_num, q_sequence_length, head_dim]
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
...@@ -158,9 +183,14 @@ int run(int argc, char* argv[]) ...@@ -158,9 +183,14 @@ int run(int argc, char* argv[])
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
} }
std::vector<ck::index_t> kv_gs_ns_ks_lengths{G0, G1, N, 2, K}; std::vector<ck::index_t> kv_gs_ns_ks_lengths{
batch_size, head_num, kv_sequence_length, 2, head_dim};
std::vector<ck::index_t> kv_gs_ns_ks_strides = std::vector<ck::index_t>{ std::vector<ck::index_t> kv_gs_ns_ks_strides = std::vector<ck::index_t>{
N * G1 * 2 * K, 2 * K, G1 * 2 * K, K, 1}; // kv layout [G0, M, G1, 2, K] kv_sequence_length * head_num * 2 * head_dim,
2 * head_dim,
head_num * 2 * head_dim,
head_dim,
1}; // kv layout [batch_size, q_sequence_length, head_num, 2, head_dim]
Tensor<ADataType> kv_gs_ns_ks(kv_gs_ns_ks_lengths, kv_gs_ns_ks_strides); Tensor<ADataType> kv_gs_ns_ks(kv_gs_ns_ks_lengths, kv_gs_ns_ks_strides);
// merge kv into a packed pointer send to device // merge kv into a packed pointer send to device
b0_gs_ns_ks.ForEach( b0_gs_ns_ks.ForEach(
...@@ -189,20 +219,20 @@ int run(int argc, char* argv[]) ...@@ -189,20 +219,20 @@ int run(int argc, char* argv[])
printf("Verification: %s\n", do_verification ? "ON" : "OFF"); printf("Verification: %s\n", do_verification ? "ON" : "OFF");
// TODO ANT: replace array with vector? // TODO ANT: replace array with vector?
ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void { ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{}); const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>; using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
auto gemm = DeviceMHAInstance{}; auto gemm = DeviceMHAInstance{};
auto invoker = gemm.MakeCrossAttnInvoker(); auto invoker = gemm.MakeCrossAttnInvoker();
auto argument = auto argument =
gemm.MakeCrossAttnArgument(static_cast<ADataType*>(q_device_buf.GetDeviceBuffer()), gemm.MakeCrossAttnArgument(static_cast<ADataType*>(q_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(kv_device_buf.GetDeviceBuffer()), static_cast<B0DataType*>(kv_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
G0, batch_size,
M, q_sequence_length,
N, kv_sequence_length,
G1, head_num,
K, head_dim,
alpha); alpha);
// if(!gemm.IsSupportedArgument(argument)) // if(!gemm.IsSupportedArgument(argument))
...@@ -212,13 +242,17 @@ int run(int argc, char* argv[]) ...@@ -212,13 +242,17 @@ int run(int argc, char* argv[])
// return 0; // return 0;
// } // }
ck::index_t BatchCount = G0 * G1; ck::index_t BatchCount = batch_size * head_num;
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; std::size_t flop = (size_t(q_sequence_length) * kv_sequence_length * head_dim * 2 +
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + size_t(q_sequence_length) * kv_sequence_length * head_dim * 2) *
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * q_sequence_length * head_dim +
sizeof(B0DataType) * head_dim * kv_sequence_length +
sizeof(B1DataType) * kv_sequence_length * head_dim +
sizeof(CDataType) * q_sequence_length * head_dim) *
BatchCount; BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -237,22 +271,26 @@ int run(int argc, char* argv[]) ...@@ -237,22 +271,26 @@ int run(int argc, char* argv[])
{ {
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
Tensor<ADataType> a_g_m_k({BatchCount, M, K}); Tensor<ADataType> a_g_m_k({BatchCount, q_sequence_length, head_dim});
Tensor<B0DataType> b0_g_k_n({BatchCount, K, N}); Tensor<B0DataType> b0_g_k_n({BatchCount, head_dim, kv_sequence_length});
Tensor<B1DataType> b1_g_n_o({BatchCount, N, O}); Tensor<B1DataType> b1_g_n_o({BatchCount, kv_sequence_length, head_dim});
Tensor<Acc0DataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0 Tensor<Acc0DataType> acc0_g_m_n(
Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax {BatchCount, q_sequence_length, kv_sequence_length}); // scratch object after gemm0
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1 Tensor<ADataType> a1_g_m_n({BatchCount,
q_sequence_length,
kv_sequence_length}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result(
{BatchCount, q_sequence_length, head_dim}); // scratch object after gemm1
// permute // permute
a_gs_ms_ks.ForEach([&](auto& self, auto idx) { a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); a_g_m_k(idx[0] * head_num + idx[1], idx[2], idx[3]) = self(idx);
}); });
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); b0_g_k_n(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx);
}); });
b1_gs_os_ns.ForEach([&](auto& self, auto idx) { b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); b1_g_n_o(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx);
}); });
// gemm 0 // gemm 0
...@@ -264,7 +302,7 @@ int run(int argc, char* argv[]) ...@@ -264,7 +302,7 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking // masking
const auto mask = typename DeviceMHAInstance::C0MatrixMask(N); const auto mask = typename DeviceMHAInstance::C0MatrixMask(kv_sequence_length);
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2])) if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<float>::Infinity();
...@@ -294,7 +332,7 @@ int run(int argc, char* argv[]) ...@@ -294,7 +332,7 @@ int run(int argc, char* argv[])
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * head_num + g1;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]); self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
}); });
...@@ -330,8 +368,10 @@ int run(int argc, char* argv[]) ...@@ -330,8 +368,10 @@ int run(int argc, char* argv[])
std::cout << "---------------------------------------------------------------------------------" std::cout << "---------------------------------------------------------------------------------"
"-----------" "-----------"
<< std::endl; << std::endl;
std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M std::cout << "Problem Size: BatchCount: " << batch_size << ", HeadNum: " << head_num
<< ", N: " << N << ", K: " << K << ", O: " << O << std::endl; << ", q_sequence_length: " << q_sequence_length
<< ", kv_sequence_length: " << kv_sequence_length << ", head_dim: " << head_dim
<< std::endl;
std::cout << "---------------------------------------------------------------------------------" std::cout << "---------------------------------------------------------------------------------"
"-----------" "-----------"
<< std::endl; << std::endl;
......
...@@ -185,9 +185,9 @@ int run(int argc, char* argv[]) ...@@ -185,9 +185,9 @@ int run(int argc, char* argv[])
printf("Verification: %s\n", do_verification ? "ON" : "OFF"); printf("Verification: %s\n", do_verification ? "ON" : "OFF");
// TODO ANT: replace array with vector? // TODO ANT: replace array with vector?
ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void { ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{}); const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>; using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
auto gemm = DeviceMHAInstance{}; auto gemm = DeviceMHAInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
......
...@@ -185,9 +185,9 @@ int run(int argc, char* argv[]) ...@@ -185,9 +185,9 @@ int run(int argc, char* argv[])
printf("Verification: %s\n", do_verification ? "ON" : "OFF"); printf("Verification: %s\n", do_verification ? "ON" : "OFF");
// TODO ANT: replace array with vector? // TODO ANT: replace array with vector?
ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void { ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{}); const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>; using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
auto gemm = DeviceMHAInstance{}; auto gemm = DeviceMHAInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
......
...@@ -9,20 +9,17 @@ int run(int argc, char* argv[]) ...@@ -9,20 +9,17 @@ int run(int argc, char* argv[])
// GEMM shape for A/B0/B1/C // GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 256; ck::index_t sequence_length = 256;
ck::index_t N = 256; ck::index_t head_dim = 80;
ck::index_t K = 80;
ck::index_t O = 80;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape // Output shape C[batch_size, sequence_length, head_num, head_dim]. Batch dim, outer dim, inner
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) // dim must match GEMM shape C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) C_g0_m_g1_o =
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3]) // permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck::index_t G0 = 2; ck::index_t batch_size = 2;
ck::index_t G1 = 8; ck::index_t head_num = 8;
float alpha = 1; float alpha = 1;
bool input_permute = true;
bool input_permute = false;
bool output_permute = true; bool output_permute = true;
if(argc == 1) if(argc == 1)
...@@ -35,58 +32,81 @@ int run(int argc, char* argv[]) ...@@ -35,58 +32,81 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 13) else if(argc == 9)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]); sequence_length = std::stoi(argv[4]);
N = std::stoi(argv[5]); head_dim = std::stoi(argv[5]);
K = std::stoi(argv[6]); batch_size = std::stoi(argv[6]);
O = std::stoi(argv[7]); head_num = std::stoi(argv[7]);
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]);
input_permute = std::stoi(argv[11]); alpha = std::stof(argv[8]);
output_permute = std::stoi(argv[12]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 11: M, N, K, O, G0, G1\n"); printf("arg4 to 7: sequence_length, head_dim, batch_size, head_num\n");
printf("arg10: scale (alpha)\n"); printf("arg8: scale (alpha)\n");
printf("arg11 to 12: input / output permute\n");
exit(0); exit(0);
} }
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K}; std::vector<ck::index_t> a_gs_ms_ks_lengths{batch_size, head_num, sequence_length, head_dim};
std::vector<ck::index_t> a_gs_ms_ks_strides = std::vector<ck::index_t> a_gs_ms_ks_strides =
input_permute input_permute ? std::vector<ck::index_t>{sequence_length * head_num * head_dim,
? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] head_dim,
: std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] head_num * head_dim,
1}
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K}; // A layout [batch_size, sequence_length, head_num, head_dim]
: std::vector<ck::index_t>{
head_num * sequence_length * head_dim,
sequence_length * head_dim,
head_dim,
1}; // A layout [batch_size, head_num, sequence_length, head_dim]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{batch_size, head_num, sequence_length, head_dim};
std::vector<ck::index_t> b0_gs_ns_ks_strides = std::vector<ck::index_t> b0_gs_ns_ks_strides =
input_permute input_permute ? std::vector<ck::index_t>{sequence_length * head_num * head_dim,
? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K] head_dim,
: std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] head_num * head_dim,
1}
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N}; // B0 layout [batch_size, sequence_length, head_num, head_dim]
: std::vector<ck::index_t>{
head_num * sequence_length * head_dim,
sequence_length * head_dim,
head_dim,
1}; // B0 layout [batch_size, head_num, sequence_length, head_dim]
std::vector<ck::index_t> b1_gs_os_ns_lengths{batch_size, head_num, head_dim, sequence_length};
std::vector<ck::index_t> b1_gs_os_ns_strides = std::vector<ck::index_t> b1_gs_os_ns_strides =
input_permute input_permute
? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O] ? std::vector<ck::index_t>{sequence_length * head_num * head_dim,
: std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] head_dim,
1,
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O}; head_num * head_dim}
// B1 layout [batch_size, sequence_length, head_num, head_dim]
: std::vector<ck::index_t>{
head_num * sequence_length * head_dim,
sequence_length * head_dim,
1,
head_dim}; // B1 layout [batch_size, head_num, sequence_length, head_dim]
std::vector<ck::index_t> c_gs_ms_os_lengths{batch_size, head_num, sequence_length, head_dim};
std::vector<ck::index_t> c_gs_ms_os_strides = std::vector<ck::index_t> c_gs_ms_os_strides =
output_permute output_permute ? std::vector<ck::index_t>{sequence_length * head_num * head_dim,
? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] head_dim,
: std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] head_num * head_dim,
1}
// C layout [batch_size, sequence_length, head_num, head_dim]
: std::vector<ck::index_t>{
head_num * sequence_length * head_dim,
sequence_length * head_dim,
head_dim,
1}; // C layout [batch_size, head_num, sequence_length, head_dim]
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
...@@ -158,9 +178,14 @@ int run(int argc, char* argv[]) ...@@ -158,9 +178,14 @@ int run(int argc, char* argv[])
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
} }
std::vector<ck::index_t> qkv_gs_ms_ks_lengths{G0, G1, M, 3, K}; std::vector<ck::index_t> qkv_gs_ms_ks_lengths{
batch_size, head_num, sequence_length, 3, head_dim};
std::vector<ck::index_t> qkv_gs_ms_ks_strides = std::vector<ck::index_t>{ std::vector<ck::index_t> qkv_gs_ms_ks_strides = std::vector<ck::index_t>{
M * G1 * 3 * K, 3 * K, G1 * 3 * K, K, 1}; // qkv layout [G0, M, G1, 3, K] sequence_length * head_num * 3 * head_dim,
3 * head_dim,
head_num * 3 * head_dim,
head_dim,
1}; // qkv layout [batch_size, sequence_length, head_num, 3, head_dim]
Tensor<ADataType> qkv_gs_ms_ks(qkv_gs_ms_ks_lengths, qkv_gs_ms_ks_strides); Tensor<ADataType> qkv_gs_ms_ks(qkv_gs_ms_ks_lengths, qkv_gs_ms_ks_strides);
// merge qkv into a packed pointer send to device // merge qkv into a packed pointer send to device
a_gs_ms_ks.ForEach( a_gs_ms_ks.ForEach(
...@@ -190,18 +215,18 @@ int run(int argc, char* argv[]) ...@@ -190,18 +215,18 @@ int run(int argc, char* argv[])
printf("Verification: %s\n", do_verification ? "ON" : "OFF"); printf("Verification: %s\n", do_verification ? "ON" : "OFF");
// TODO ANT: replace array with vector? // TODO ANT: replace array with vector?
ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void { ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
const auto device_conv_mha_instance = std::get<i>(DeviceMHAFactory{}); const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_conv_mha_instance)>; using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
auto gemm = DeviceMHAInstance{}; auto gemm = DeviceMHAInstance{};
auto invoker = gemm.MakeSelfAttnInvoker(); auto invoker = gemm.MakeSelfAttnInvoker();
auto argument = auto argument =
gemm.MakeSelfAttnArgument(static_cast<ADataType*>(qkv_device_buf.GetDeviceBuffer()), gemm.MakeSelfAttnArgument(static_cast<ADataType*>(qkv_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
G0, batch_size,
M, sequence_length,
G1, head_num,
K, head_dim,
alpha); alpha);
// if(!gemm.IsSupportedArgument(argument)) // if(!gemm.IsSupportedArgument(argument))
...@@ -211,13 +236,17 @@ int run(int argc, char* argv[]) ...@@ -211,13 +236,17 @@ int run(int argc, char* argv[])
// return 0; // return 0;
// } // }
ck::index_t BatchCount = G0 * G1; ck::index_t BatchCount = batch_size * head_num;
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; std::size_t flop = (size_t(sequence_length) * sequence_length * head_dim * 2 +
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + size_t(sequence_length) * sequence_length * head_dim * 2) *
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * sequence_length * head_dim +
sizeof(B0DataType) * head_dim * sequence_length +
sizeof(B1DataType) * sequence_length * head_dim +
sizeof(CDataType) * sequence_length * head_dim) *
BatchCount; BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -236,22 +265,25 @@ int run(int argc, char* argv[]) ...@@ -236,22 +265,25 @@ int run(int argc, char* argv[])
{ {
c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
Tensor<ADataType> a_g_m_k({BatchCount, M, K}); Tensor<ADataType> a_g_m_k({BatchCount, sequence_length, head_dim});
Tensor<B0DataType> b0_g_k_n({BatchCount, K, N}); Tensor<B0DataType> b0_g_k_n({BatchCount, head_dim, sequence_length});
Tensor<B1DataType> b1_g_n_o({BatchCount, N, O}); Tensor<B1DataType> b1_g_n_o({BatchCount, sequence_length, head_dim});
Tensor<Acc0DataType> acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0 Tensor<Acc0DataType> acc0_g_m_n(
Tensor<ADataType> a1_g_m_n({BatchCount, M, N}); // scratch object after softmax {BatchCount, sequence_length, sequence_length}); // scratch object after gemm0
Tensor<CDataType> c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1 Tensor<ADataType> a1_g_m_n(
{BatchCount, sequence_length, sequence_length}); // scratch object after softmax
Tensor<CDataType> c_g_m_o_host_result(
{BatchCount, sequence_length, head_dim}); // scratch object after gemm1
// permute // permute
a_gs_ms_ks.ForEach([&](auto& self, auto idx) { a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); a_g_m_k(idx[0] * head_num + idx[1], idx[2], idx[3]) = self(idx);
}); });
b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); b0_g_k_n(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx);
}); });
b1_gs_os_ns.ForEach([&](auto& self, auto idx) { b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); b1_g_n_o(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx);
}); });
// gemm 0 // gemm 0
...@@ -263,7 +295,7 @@ int run(int argc, char* argv[]) ...@@ -263,7 +295,7 @@ int run(int argc, char* argv[])
ref_gemm0_invoker.Run(ref_gemm0_argument); ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking // masking
const auto mask = typename DeviceMHAInstance::C0MatrixMask(N); const auto mask = typename DeviceMHAInstance::C0MatrixMask(sequence_length);
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2])) if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity(); self(idx) = -ck::NumericLimits<float>::Infinity();
...@@ -293,7 +325,7 @@ int run(int argc, char* argv[]) ...@@ -293,7 +325,7 @@ int run(int argc, char* argv[])
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1; const size_t g = g0 * head_num + g1;
self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]); self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]);
}); });
...@@ -329,8 +361,9 @@ int run(int argc, char* argv[]) ...@@ -329,8 +361,9 @@ int run(int argc, char* argv[])
std::cout << "---------------------------------------------------------------------------------" std::cout << "---------------------------------------------------------------------------------"
"-----------" "-----------"
<< std::endl; << std::endl;
std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M std::cout << "Problem Size: BatchCount: " << batch_size << ", HeadNum: " << head_num
<< ", N: " << N << ", K: " << K << ", O: " << O << std::endl; << ", sequence_length: " << sequence_length << ", head_dim: " << head_dim
<< std::endl;
std::cout << "---------------------------------------------------------------------------------" std::cout << "---------------------------------------------------------------------------------"
"-----------" "-----------"
<< std::endl; << std::endl;
......
...@@ -83,12 +83,34 @@ using DeviceMHAFactory = ...@@ -83,12 +83,34 @@ using DeviceMHAFactory =
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
32, 32,
// Gemm 0 // Gemm 0
16, 128, 64, 8, 8, 16, 32, 160, 8, 8,
// Gemm 1 // Gemm 1
64, 64, 8, 80, 32, 8,
16, 16, 16, 16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 2, 5,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
32,
// Gemm 0
16, 64, 80, 8, 8,
// Gemm 1
80, 64, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1 // Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4, 1, 4, 5,
// ABlockTransfer MK -> K0 M K1 // ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1 // B0BlockTransfer LK -> K0 L K1
...@@ -105,12 +127,12 @@ using DeviceMHAFactory = ...@@ -105,12 +127,12 @@ using DeviceMHAFactory =
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
32, 32,
// Gemm 0 // Gemm 0
16, 64, 64, 8, 8, 16, 64, 48, 8, 8,
// Gemm 1 // Gemm 1
64, 64, 8, 48, 64, 8,
16, 16, 16, 16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1 // Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4, 1, 4, 3,
// ABlockTransfer MK -> K0 M K1 // ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1 // B0BlockTransfer LK -> K0 L K1
...@@ -129,16 +151,16 @@ using DeviceMHAFactory = ...@@ -129,16 +151,16 @@ using DeviceMHAFactory =
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
64, 64,
// Gemm 0 // Gemm 0
32, 128, 64, 8, 8, 32, 64, 48, 8, 8,
// Gemm 1 // Gemm 1
64, 64, 8, 48, 64, 8,
16, 16, 16, 16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1 // Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4, 1, 4, 3,
// ABlockTransfer MK -> K0 M K1 // ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1 // B0BlockTransfer LK -> K0 L K1
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1 // B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN // CShuffleBlockTransfer MN
...@@ -151,16 +173,38 @@ using DeviceMHAFactory = ...@@ -151,16 +173,38 @@ using DeviceMHAFactory =
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
64, 64,
// Gemm 0 // Gemm 0
32, 64, 64, 8, 8, 32, 64, 80, 8, 8,
// Gemm 1 // Gemm 1
64, 64, 8, 80, 64, 8,
16, 16, 16, 16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1 // Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4, 1, 4, 5,
// ABlockTransfer MK -> K0 M K1 // ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1 // B0BlockTransfer LK -> K0 L K1
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 32, 1, 2>, 8,
MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
64,
// Gemm 0
32, 32, 160, 8, 8,
// Gemm 1
80, 32, 8,
16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 2, 5,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1 // B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN // CShuffleBlockTransfer MN
...@@ -175,20 +219,20 @@ using DeviceMHAFactory = ...@@ -175,20 +219,20 @@ using DeviceMHAFactory =
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
128, 128,
// Gemm 0 // Gemm 0
64, 128, 64, 8, 8, 64, 128, 80, 8, 8,
// Gemm 1 // Gemm 1
64, 64, 8, 80, 64, 8,
16, 16, 16, 16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1 // Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4, 1, 8, 5,
// ABlockTransfer MK -> K0 M K1 // ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1 // B0BlockTransfer LK -> K0 L K1
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1 // B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN // CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8, 1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>, MaskingSpec>,
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
...@@ -197,45 +241,45 @@ using DeviceMHAFactory = ...@@ -197,45 +241,45 @@ using DeviceMHAFactory =
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
128, 128,
// Gemm 0 // Gemm 0
64, 64, 64, 8, 8, 64, 192, 48, 8, 8,
// Gemm 1 // Gemm 1
64, 64, 8, 48, 64, 8,
16, 16, 16, 16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1 // Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4, 1, 12, 3,
// ABlockTransfer MK -> K0 M K1 // ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1 // B0BlockTransfer LK -> K0 L K1
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1 // B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN // CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8, 1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>, MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_8
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
256, 128,
// Gemm 0 // Gemm 0
128, 128, 64, 8, 8, 64, 64, 48, 8, 8,
// Gemm 1 // Gemm 1
64, 64, 8, 48, 64, 8,
16, 16, 16, 16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1 // Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4, 1, 4, 3,
// ABlockTransfer MK -> K0 M K1 // ABlockTransfer MK -> K0 M K1
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1 // B0BlockTransfer LK -> K0 L K1
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1 // B1BlockTransfer NL -> L0 N L1
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN // CShuffleBlockTransfer MN
1, 1, S<1, 128, 1, 2>, 8, 1, 1, S<1, 64, 1, 2>, 8,
MaskingSpec>, MaskingSpec>,
#endif
#ifdef CK_MHA_USE_WAVE_8
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
...@@ -243,18 +287,18 @@ using DeviceMHAFactory = ...@@ -243,18 +287,18 @@ using DeviceMHAFactory =
GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1,
256, 256,
// Gemm 0 // Gemm 0
128, 128, 64, 8, 8, 128, 192, 48, 8,4,
// Gemm 1 // Gemm 1
64, 64, 8, 48, 64, 8,
16, 16, 16, 16, 16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1 // Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4, 1, 12, 3,
// ABlockTransfer MK -> K0 M K1 // ABlockTransfer MK -> K0 M K1
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1 // B0BlockTransfer LK -> K0 L K1
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true,
// B1BlockTransfer NL -> L0 N L1 // B1BlockTransfer NL -> L0 N L1
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
// CShuffleBlockTransfer MN // CShuffleBlockTransfer MN
1, 1, S<1, 128, 1, 2>, 8, 1, 1, S<1, 128, 1, 2>, 8,
MaskingSpec> MaskingSpec>
......
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") if(GPU_TARGETS MATCHES "gfx11")
add_custom_target(example_fpAintB_gemm_wmma) add_custom_target(example_fpAintB_gemm_wmma)
add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp) add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp)
add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma) add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma)
......
...@@ -56,8 +56,7 @@ __global__ void ...@@ -56,8 +56,7 @@ __global__ void
bool input_permute, bool input_permute,
bool output_permute) bool output_permute)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
defined(__gfx1102__))
// clang-format off // clang-format off
// *************************************************** // ***************************************************
...@@ -162,7 +161,7 @@ __global__ void ...@@ -162,7 +161,7 @@ __global__ void
ignore = G1; ignore = G1;
ignore = input_permute; ignore = input_permute;
ignore = output_permute; ignore = output_permute;
#endif // end of if (defined(__gfx1100__)) #endif // end of if (defined(__gfx11__))
} }
// Self-Attention // Self-Attention
...@@ -188,8 +187,7 @@ __global__ void ...@@ -188,8 +187,7 @@ __global__ void
index_t head_size, index_t head_size,
float alpha) float alpha)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
defined(__gfx1102__))
// clang-format off // clang-format off
// *************************************************** // ***************************************************
...@@ -294,7 +292,7 @@ __global__ void ...@@ -294,7 +292,7 @@ __global__ void
ignore = head_count; ignore = head_count;
ignore = head_size; ignore = head_size;
ignore = alpha; ignore = alpha;
#endif // end of if (defined(__gfx1100__)) #endif // end of if (defined(__gfx11__))
} }
// Cross-Attention // Cross-Attention
// Self-Attention // Self-Attention
...@@ -323,8 +321,7 @@ __global__ void ...@@ -323,8 +321,7 @@ __global__ void
index_t head_size, index_t head_size,
float alpha) float alpha)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
defined(__gfx1102__))
// clang-format off // clang-format off
// *************************************************** // ***************************************************
...@@ -435,7 +432,7 @@ __global__ void ...@@ -435,7 +432,7 @@ __global__ void
ignore = head_count; ignore = head_count;
ignore = head_size; ignore = head_size;
ignore = alpha; ignore = alpha;
#endif // end of if (defined(__gfx1100__)) #endif // end of if (defined(__gfx11__))
} }
// Computes C = A * B0 * B1 // Computes C = A * B0 * B1
// MN = MK * KL * LN // MN = MK * KL * LN
...@@ -861,8 +858,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -861,8 +858,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static bool IsSupportedArgument(const RawArg& arg) static bool IsSupportedArgument(const RawArg& arg)
{ {
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || if(ck::is_navi3_supported())
ck::get_device_name() == "gfx1102")
{ {
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>)) if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{ {
...@@ -1439,8 +1435,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle ...@@ -1439,8 +1435,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
#if 0 #if 0
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || if(ck::is_navi3_supported())
ck::get_device_name() == "gfx1102")
{ {
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>)) if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{ {
......
...@@ -509,8 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout, ...@@ -509,8 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || if(ck::is_navi3_supported())
ck::get_device_name() == "gfx1102")
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> || if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
is_same_v<AccDataType, int32_t>)) is_same_v<AccDataType, int32_t>))
......
...@@ -498,94 +498,95 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -498,94 +498,95 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
} }
}; };
static bool IsSupportedArgument(const Argument& arg) static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
{ {
if(!ck::is_xdl_supported())
{
return false;
}
// check vector load/store // check vector load/store
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// check vector load of A
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
{ {
using Row = ck::tensor_layout::gemm::RowMajor; if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
using Col = ck::tensor_layout::gemm::ColumnMajor;
// check vector load of A
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
{
// FIXME: not rigorous
if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{ {
return false; return false;
} }
}
// check vector laod of B else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2) {
// FIXME: not rigorous
if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
{ {
if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0) return false;
{
return false;
}
} }
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1) }
else
{
return false;
}
// check vector laod of B
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
{
if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
{ {
// FIXME: not rigorous return false;
if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
} }
else }
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
{
// FIXME: not rigorous
if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
{ {
return false; return false;
} }
}
else
{
return false;
}
// check vector load of Ds // check vector load of Ds
// only support RowMajor for now // only support RowMajor for now
bool all_valid = true; bool all_valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
if constexpr(!is_same_v<DLayout, Row>) static_for<0, NumDTensor, 1>{}([&](auto i) {
{ using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
all_valid = false;
}
});
if(!all_valid) if constexpr(!is_same_v<DLayout, Row>)
{ {
return false; all_valid = false;
} }
});
// check vector store of E if(!all_valid)
// only support RowMajor for now {
if constexpr(is_same_v<ELayout, Row>) return false;
{ }
if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{ // check vector store of E
return false; // only support RowMajor for now
} if constexpr(is_same_v<ELayout, Row>)
} {
else if(NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{ {
return false; return false;
} }
} }
else
{
return false;
}
return true;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
return IsSupported(arg.MRaw_, arg.NRaw_, arg.KRaw_) and
GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_, arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_, arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_, arg.e_grid_desc_m_n_,
...@@ -708,6 +709,178 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -708,6 +709,178 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return str.str(); return str.str();
} }
// Compile-time-constructible "argument" bundle for the device GEMM: given raw
// A/B/Ds/E tensor descriptors, pre-computes every padded and block-tiled
// descriptor plus the block-to-E-tile map that the gridwise GEMM kernel
// consumes, so a __device__ caller can hand the whole thing to Run().
template <class ADesc, class BDesc, class DsDesc, class EDesc>
struct Descriptor
{
// Pads each D descriptor in the tuple to M/N block boundaries (type-level
// helper; used below to derive DsGridDesc_M_N via decltype).
static constexpr auto ds_tuple()
{
return transform_tuples(
[&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
DsDesc{});
}
// Padded problem-definition descriptor types, derived from the raw
// descriptor types via the device op's matrix padder.
using AGridDesc_M_K =
remove_cvref_t<decltype(DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{}))>;
using BGridDesc_N_K =
remove_cvref_t<decltype(DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{}))>;
using DsGridDesc_M_N = remove_cvref_t<decltype(ds_tuple())>;
using EGridDesc_M_N =
remove_cvref_t<decltype(DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{}))>;
// Block/thread-wise copy descriptor types: the padded descriptors further
// transformed into the K0/M/K1 (resp. MBlock/NBlock) layouts expected by
// GridwiseGemm.
using AGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{})))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{})))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_tuple()))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
using Block2ETileMap = remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(
DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
// tensor descriptors for problem definition
AGridDesc_M_K a_grid_desc_m_k;
BGridDesc_N_K b_grid_desc_n_k;
DsGridDesc_M_N ds_grid_desc_m_n;
EGridDesc_M_N e_grid_desc_m_n;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;
// block-to-e-tile map
Block2ETileMap block_2_etile_map;
// element-wise op
AElementwiseOperation a_element_op;
BElementwiseOperation b_element_op;
CDEElementwiseOperation cde_element_op;
// for checking vector load/store
index_t MRaw;
index_t NRaw;
index_t KRaw;
bool has_main_k_block_loop = true;
// Builds every derived descriptor from the raw ones. MRaw/NRaw are read
// from E's two lengths and KRaw from A's second length — this assumes A is
// an (M, K) descriptor and E an (M, N) descriptor; TODO(review) confirm
// against the descriptor conventions of the enclosing device op.
constexpr Descriptor(ADesc a,
BDesc b,
DsDesc ds,
EDesc e,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CDEElementwiseOperation cde_element_op_)
: a_grid_desc_m_k{DeviceOp::matrix_padder.PadADescriptor_M_K(a)},
b_grid_desc_n_k{DeviceOp::matrix_padder.PadBDescriptor_N_K(b)},
ds_grid_desc_m_n{transform_tuples(
[&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
ds)},
e_grid_desc_m_n{DeviceOp::matrix_padder.PadCDescriptor_M_N(e)},
a_grid_desc_ak0_m_ak1{
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k)},
b_grid_desc_bk0_n_bk1{
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k)},
ds_grid_desc_mblock_mperblock_nblock_nperblock{
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
transform_tuples(
[&](auto d) constexpr {
return DeviceOp::matrix_padder.PadCDescriptor_M_N(d);
},
ds))},
e_grid_desc_mblock_mperblock_nblock_nperblock{
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n)},
block_2_etile_map{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n)},
// total (padded) K length = AK0 * AK1 decides the kernel's main-loop variant
has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
cde_element_op{cde_element_op_},
MRaw{e.GetLength(I0)},
NRaw{e.GetLength(I1)},
KRaw{a.GetLength(I1)}
{
}
// True when both the gridwise-GEMM tile validity checks and the raw-size
// vector load/store checks (IsSupported) pass for this problem.
constexpr bool IsValid() const
{
return GridwiseGemm::CheckValidity(a_grid_desc_m_k,
b_grid_desc_n_k,
ds_grid_desc_m_n,
e_grid_desc_m_n,
block_2_etile_map) and
IsSupported(MRaw, NRaw, KRaw);
}
// Launch geometry helpers for the caller that configures the kernel.
constexpr index_t GetBlockSize() const { return BlockSize; }
constexpr index_t GetGridSize() const
{
return block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);
}
};
// Factory helper: deduces the Descriptor's template arguments from the raw
// tensor descriptors and returns a fully-initialized Descriptor; the
// element-wise operations default to their value-initialized state.
template <class ADesc, class BDesc, class DsDesc, class EDesc>
static constexpr auto
make_descriptor(ADesc a,
                BDesc b,
                DsDesc ds,
                EDesc e,
                AElementwiseOperation a_element_op = AElementwiseOperation{},
                BElementwiseOperation b_element_op = BElementwiseOperation{},
                CDEElementwiseOperation cde_element_op = CDEElementwiseOperation{})
{
    using ResultDesc = Descriptor<ADesc, BDesc, DsDesc, EDesc>;
    return ResultDesc{a, b, ds, e, a_element_op, b_element_op, cde_element_op};
}
// Device-side entry point: runs the gridwise GEMM described by `desc` on the
// given global-memory tensors. The runtime flag desc.has_main_k_block_loop is
// dispatched to GridwiseGemm::Run's compile-time bool so both variants are
// instantiated but only one executes.
template <class Desc, class DsPointer>
__device__ static void Run(const Desc& desc,
const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid)
{
// LDS scratch sized by the gridwise GEMM; shared by the whole workgroup.
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// Debug-only guard (compiled out under NDEBUG): the descriptor must have
// passed its validity checks before the kernel is launched.
assert(desc.IsValid());
if(desc.has_main_k_block_loop)
{
GridwiseGemm::template Run<true>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared_block,
desc.a_element_op,
desc.b_element_op,
desc.cde_element_op,
desc.a_grid_desc_ak0_m_ak1,
desc.b_grid_desc_bk0_n_bk1,
desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
desc.block_2_etile_map);
}
else
{
GridwiseGemm::template Run<false>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared_block,
desc.a_element_op,
desc.b_element_op,
desc.cde_element_op,
desc.a_grid_desc_ak0_m_ak1,
desc.b_grid_desc_bk0_n_bk1,
desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
desc.block_2_etile_map);
}
}
}; };
} // namespace device } // namespace device
......
...@@ -61,8 +61,7 @@ __global__ void ...@@ -61,8 +61,7 @@ __global__ void
bool input_permute, bool input_permute,
bool output_permute) bool output_permute)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
defined(__gfx1102__))
// clang-format off // clang-format off
// *************************************************** // ***************************************************
...@@ -169,7 +168,7 @@ __global__ void ...@@ -169,7 +168,7 @@ __global__ void
ignore = G1; ignore = G1;
ignore = input_permute; ignore = input_permute;
ignore = output_permute; ignore = output_permute;
#endif // end of if (defined(__gfx1100__)) #endif // end of if (defined(__gfx11__))
} }
// Computes C = A * B0 * B1 // Computes C = A * B0 * B1
...@@ -597,8 +596,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma ...@@ -597,8 +596,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
static bool IsSupportedArgument(const RawArg& arg) static bool IsSupportedArgument(const RawArg& arg)
{ {
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || if(ck::is_navi3_supported())
ck::get_device_name() == "gfx1102")
{ {
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>)) if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{ {
...@@ -960,8 +958,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma ...@@ -960,8 +958,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
#if 0 #if 0
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || if(ck::is_navi3_supported())
ck::get_device_name() == "gfx1102")
{ {
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>)) if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment