Commit c54b7bc9 authored by Chao Liu's avatar Chao Liu
Browse files

gMerge remote-tracking branch 'origin/develop' into group_norm

parents 9a8967a4 f584ab0c
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp) add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_padded_batched_gemm_scale_softmax_gemm_xdl_fp16 padded_batched_gemm_scale_softmax_gemm_xdl_fp16.cpp) add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_custom_target(example_batched_gemm_scale_softmax_gemm) add_custom_target(example_gemm_scale_softmax_gemm)
add_dependencies(example_batched_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16) add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
add_dependencies(example_batched_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16) add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_batched_gemm_scale_softmax_gemm example_padded_batched_gemm_scale_softmax_gemm_xdl_fp16) add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16)
...@@ -58,7 +58,7 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; ...@@ -58,7 +58,7 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough; using B1ElementOp = PassThrough;
using CElementOp = PassThrough; using CElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNOPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding;
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
...@@ -149,8 +149,8 @@ int main(int argc, char* argv[]) ...@@ -149,8 +149,8 @@ int main(int argc, char* argv[])
// GEMM shape for A/B0/B1/C // GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 128; ck::index_t M = 120;
ck::index_t N = 1024; ck::index_t N = 1000;
ck::index_t K = 64; ck::index_t K = 64;
ck::index_t O = 128; ck::index_t O = 128;
ck::index_t StrideA = -1; ck::index_t StrideA = -1;
......
...@@ -55,7 +55,7 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; ...@@ -55,7 +55,7 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough; using B1ElementOp = PassThrough;
using CElementOp = PassThrough; using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<
ALayout, ALayout,
...@@ -73,7 +73,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftma ...@@ -73,7 +73,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftma
Acc0ElementOp, Acc0ElementOp,
B1ElementOp, B1ElementOp,
CElementOp, CElementOp,
GemmDefault, GemmSpec,
1, 1,
256, 256,
128, // MPerBlock 128, // MPerBlock
...@@ -144,8 +144,8 @@ int main(int argc, char* argv[]) ...@@ -144,8 +144,8 @@ int main(int argc, char* argv[])
bool time_kernel = false; bool time_kernel = false;
// GEMM shape // GEMM shape
ck::index_t M = 1024; ck::index_t M = 1020;
ck::index_t N = 1024; ck::index_t N = 1020;
ck::index_t K = 64; ck::index_t K = 64;
ck::index_t O = 128; ck::index_t O = 128;
ck::index_t BatchCount = 4; ck::index_t BatchCount = 4;
......
...@@ -16,7 +16,8 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g ...@@ -16,7 +16,8 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
...@@ -47,7 +48,9 @@ using CDataType = F16; ...@@ -47,7 +48,9 @@ using CDataType = F16;
using ALayout = Row; using ALayout = Row;
using B0Layout = Col; using B0Layout = Col;
using B1Layout = Row; using B1Layout = Row;
using CLayout = Row;
using CPermuteNumDims_G_M_O =
S<1, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_M_O
using AElementOp = PassThrough; using AElementOp = PassThrough;
using B0ElementOp = PassThrough; using B0ElementOp = PassThrough;
...@@ -55,65 +58,66 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; ...@@ -55,65 +58,66 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough; using B1ElementOp = PassThrough;
using CElementOp = PassThrough; using CElementOp = PassThrough;
static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< using DeviceGemmInstance =
ALayout, ck::tensor_operation::device::DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle<
B0Layout, ALayout,
B1Layout, B0Layout,
CLayout, B1Layout,
ADataType, CPermuteNumDims_G_M_O,
B0DataType, ADataType,
B1DataType, B0DataType,
CDataType, B1DataType,
AccDataType, CDataType,
CShuffleDataType, AccDataType,
AElementOp, CShuffleDataType,
B0ElementOp, AElementOp,
Acc0ElementOp, B0ElementOp,
B1ElementOp, Acc0ElementOp,
CElementOp, B1ElementOp,
MNPadding, CElementOp,
1, GemmSpec,
256, 1,
128, // MPerBlock 256,
128, // NPerBlock 128, // MPerBlock
32, // KPerBlock 128, // NPerBlock
64, // Gemm1NPerBlock 32, // KPerBlock
32, // Gemm1KPerBlock 64, // Gemm1NPerBlock
8, // AK1 32, // Gemm1KPerBlock
8, // BK1 8, // AK1
2, // B1K1 8, // BK1
32, // MPerXDL 2, // B1K1
32, // NPerXDL 32, // MPerXDL
1, // MXdlPerWave 32, // NPerXDL
4, // NXdlPerWave 1, // MXdlPerWave
2, // Gemm1NXdlPerWave 4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransfer 2, // Gemm1NXdlPerWave
S<1, 0, 2>, S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
2, S<1, 0, 2>,
8, 2,
8, 8,
true, 8,
S<4, 64, 1>, // BBlockTransfer true,
S<1, 0, 2>, S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>, S<1, 0, 2>,
2, S<1, 0, 2>,
8, 2,
8, 8,
true, 8,
S<16, 16, 1>, // B1BlockTransfer true,
S<0, 2, 1>, S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>, S<0, 2, 1>,
1, S<0, 2, 1>,
4, 1,
2, 4,
false, 2,
1, // CShuffleMXdlPerWavePerShuffle false,
2, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 2, // CShuffleNXdlPerWavePerShuffle
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
// Ref Gemm0: fp16 in, fp32 out // Ref Gemm0: fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType, using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
...@@ -143,22 +147,6 @@ int main(int argc, char* argv[]) ...@@ -143,22 +147,6 @@ int main(int argc, char* argv[])
int init_method = 1; int init_method = 1;
bool time_kernel = false; bool time_kernel = false;
// GEMM shape
ck::index_t M = 1020;
ck::index_t N = 1020;
ck::index_t K = 64;
ck::index_t O = 128;
ck::index_t BatchCount = 4;
ck::index_t StrideA = -1;
ck::index_t StrideB0 = -1;
ck::index_t StrideB1 = -1;
ck::index_t StrideC = -1;
ck::index_t BatchStrideA = -1;
ck::index_t BatchStrideB0 = -1;
ck::index_t BatchStrideB1 = -1;
ck::index_t BatchStrideC = -1;
float alpha = 1;
if(argc == 1) if(argc == 1)
{ {
// use default case // use default case
...@@ -169,74 +157,58 @@ int main(int argc, char* argv[]) ...@@ -169,74 +157,58 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]); time_kernel = std::stoi(argv[3]);
} }
else if(argc == 9)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
BatchCount = std::stoi(argv[8]);
}
else if(argc == 18)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
BatchCount = std::stoi(argv[8]);
StrideA = std::stoi(argv[9]);
StrideB0 = std::stoi(argv[10]);
StrideB1 = std::stoi(argv[11]);
StrideC = std::stoi(argv[12]);
BatchStrideA = std::stoi(argv[13]);
BatchStrideB0 = std::stoi(argv[14]);
BatchStrideB1 = std::stoi(argv[15]);
BatchStrideC = std::stoi(argv[16]);
alpha = std::stof(argv[17]);
}
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 16: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, "
"BatchStrideB0, BatchStrideB1, BatchStrideC\n");
printf("arg17: scale (alpha)\n");
exit(0); exit(0);
} }
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M; float alpha = 1; // scaling after 1st gemm
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? O : M;
StrideA = (StrideA < 0) ? DefaultStrideA : StrideA; std::size_t group_count = 13;
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
StrideC = (StrideC < 0) ? DefaultStrideC : StrideC;
const int DefaultBatchStrideA = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA; // Problem descs
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0; std::vector<DeviceGemmInstance::ProblemDesc> problem_descs;
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1; std::vector<const void*> p_a;
const int DefaultBatchStrideC = (ck::is_same_v<CLayout, Col> ? O : M) * StrideC; std::vector<const void*> p_b0;
std::vector<const void*> p_b1;
std::vector<void*> p_c;
BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA; for(std::size_t i = 0; i < group_count; i++)
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0; {
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1; int M = 128 * (rand() % 8 + 1);
BatchStrideC = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC; int N = 128 * (rand() % 8 + 1);
int K = 64;
int O = 64 * (rand() % 2 + 1);
int Batch = rand() % 8 + 1;
const int StrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int StrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
const int StrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
const int BatchStrideA = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA;
const int BatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
const int BatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
std::vector<ck::index_t> c_gs_ms_os_lengths{Batch, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides{O, Batch * O, 1};
problem_descs.push_back({M,
N,
K,
O,
Batch,
StrideA,
StrideB0,
StrideB1,
BatchStrideA,
BatchStrideB0,
BatchStrideB1,
c_gs_ms_os_lengths,
c_gs_ms_os_strides});
}
auto f_host_tensor_descriptor = [](std::size_t batch_count, auto f_host_tensor_descriptor = [](std::size_t batch_count,
std::size_t row, std::size_t row,
...@@ -256,56 +228,108 @@ int main(int argc, char* argv[]) ...@@ -256,56 +228,108 @@ int main(int argc, char* argv[])
} }
}; };
// C_m_o = A_m_k * B0_k_n * B1_n_o std::vector<Tensor<ADataType>> a_tensors;
Tensor<ADataType> a_g_m_k( std::vector<Tensor<B0DataType>> b0_tensors;
f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{})); std::vector<Tensor<B1DataType>> b1_tensors;
Tensor<B0DataType> b0_g_k_n( std::vector<Tensor<CDataType>> c_tensors;
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
Tensor<B1DataType> b1_g_n_o( using DeviceMemPtr = std::unique_ptr<DeviceMem>;
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<CDataType> c_g_m_o_host_result( std::vector<DeviceMemPtr> a_tensors_device;
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); std::vector<DeviceMemPtr> b0_tensors_device;
Tensor<CDataType> c_g_m_o_device_result( std::vector<DeviceMemPtr> b1_tensors_device;
f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); std::vector<DeviceMemPtr> c_tensors_device;
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; std::size_t flop = 0, num_byte = 0;
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl; std::cout << "group count " << group_count << ". printing first 4 groups\n";
std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl; for(std::size_t i = 0; i < group_count; i++)
switch(init_method)
{ {
case 0: break; const auto& M = problem_descs[i].M;
case 1: const auto& N = problem_descs[i].N;
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); const auto& K = problem_descs[i].K;
b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5}); const auto& O = problem_descs[i].O;
b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5}); const auto& Batch = problem_descs[i].Batch;
break; const auto& StrideA = problem_descs[i].StrideA;
case 2: const auto& StrideB0 = problem_descs[i].StrideB0;
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); const auto& StrideB1 = problem_descs[i].StrideB1;
b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0}); const auto& BatchStrideA = problem_descs[i].BatchStrideA;
b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5}); const auto& BatchStrideB0 = problem_descs[i].BatchStrideB0;
break; const auto& BatchStrideB1 = problem_descs[i].BatchStrideB1;
case 3: const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths;
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2}); const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides;
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); // C_m_o = A_m_k * B0_k_n * B1_n_o
break; Tensor<ADataType> a_g_m_k(
default: f_host_tensor_descriptor(Batch, M, K, StrideA, BatchStrideA, ALayout{}));
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1}); Tensor<B0DataType> b0_g_k_n(
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); f_host_tensor_descriptor(Batch, K, N, StrideB0, BatchStrideB0, B0Layout{}));
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); Tensor<B1DataType> b1_g_n_o(
} f_host_tensor_descriptor(Batch, N, O, StrideB1, BatchStrideB1, B1Layout{}));
Tensor<CDataType> c_gs_ms_os_device_result(
std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
Batch;
if(i < 4)
{
std::cout << "a_g_m_k[" << i << "]: " << a_g_m_k.mDesc << ", "
<< "b0_g_k_n[" << i << "]: " << b0_g_k_n.mDesc << ", "
<< "b1_g_n_o[" << i << "]: " << b1_g_n_o.mDesc << ", "
<< "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << std::endl;
}
DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); switch(init_method)
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize()); {
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize()); case 0: break;
DeviceMem c_g_m_o_device_buf(sizeof(CDataType) * case 1:
c_g_m_o_device_result.mDesc.GetElementSpaceSize()); a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 2:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
break;
case 3:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data()); a_tensors.push_back(a_g_m_k);
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data()); b0_tensors.push_back(b0_g_k_n);
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data()); b1_tensors.push_back(b1_g_n_o);
c_tensors.push_back(c_gs_ms_os_device_result);
a_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()));
b0_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize()));
b1_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize()));
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()));
a_tensors_device[i]->ToDevice(a_g_m_k.mData.data());
b0_tensors_device[i]->ToDevice(b0_g_k_n.mData.data());
b1_tensors_device[i]->ToDevice(b1_g_n_o.mData.data());
p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
}
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{}; auto b0_element_op = B0ElementOp{};
...@@ -314,31 +338,23 @@ int main(int argc, char* argv[]) ...@@ -314,31 +338,23 @@ int main(int argc, char* argv[])
auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
// do GEMM // do GEMM
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = auto argument = gemm.MakeArgument(p_a,
gemm.MakeArgument(static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()), p_b0,
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()), p_b1,
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()), p_c,
static_cast<CDataType*>(c_g_m_o_device_buf.GetDeviceBuffer()), problem_descs,
M, a_element_op,
N, b0_element_op,
K, acc0_element_op,
O, b1_element_op,
BatchCount, c_element_op);
StrideA,
StrideB0, // specify workspace for problem_desc
StrideB1, DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
StrideC,
BatchStrideA, gemm.SetWorkSpacePointer(&argument, problem_desc_workspace.GetDeviceBuffer());
BatchStrideB0,
BatchStrideB1,
BatchStrideC,
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
...@@ -349,49 +365,79 @@ int main(int argc, char* argv[]) ...@@ -349,49 +365,79 @@ int main(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl; << gemm.GetTypeString() << std::endl;
c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data()); bool pass = true;
if(do_verification) if(do_verification)
{ {
// Output of Gemm0 is input A of Gemm1 for(std::size_t i = 0; i < group_count; i++)
Tensor<AccDataType> acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); {
const auto& M = problem_descs[i].M;
const auto& N = problem_descs[i].N;
const auto& O = problem_descs[i].O;
const auto& Batch = problem_descs[i].Batch;
const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths;
const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides;
const auto& a_g_m_k = a_tensors[i];
const auto& b0_g_k_n = b0_tensors[i];
const auto& b1_g_n_o = b1_tensors[i];
auto& c_gs_ms_os_device_result = c_tensors[i];
auto& c_gs_ms_os_device_buf = *c_tensors_device[i];
Tensor<CDataType> c_gs_ms_os_host_result(
std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
auto ref_gemm0 = ReferenceGemm0Instance{}; // Output of Gemm0 is input A of Gemm1
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); Tensor<AccDataType> acc0_m_n(f_host_tensor_descriptor(Batch, M, N, N, M * N, Row{}));
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
ref_gemm0_invoker.Run(ref_gemm0_argument); Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(Batch, M, N, N, M * N, Row{}));
auto ref_softmax = ReferenceSoftmaxInstance{}; Tensor<CDataType> c_g_m_o_host_result(std::vector<int>{Batch, M, O},
auto ref_softmax_invoker = ref_softmax.MakeInvoker(); std::vector<int>{M * O, O, 1});
auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2});
ref_softmax_invoker.Run(ref_softmax_argument); auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_g_m_k, b0_g_k_n, acc0_m_n, a_element_op, b0_element_op, acc0_element_op);
auto ref_gemm1 = ReferenceGemm1Instance{}; ref_gemm0_invoker.Run(ref_gemm0_argument);
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op);
ref_gemm1_invoker.Run(ref_gemm1_argument); auto ref_softmax = ReferenceSoftmaxInstance{};
auto ref_softmax_invoker = ref_softmax.MakeInvoker();
auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_m_n, a1_g_m_n, 1, 0, {2});
return ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData) ? 0 : 1; ref_softmax_invoker.Run(ref_softmax_argument);
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n,
b1_g_n_o,
c_g_m_o_host_result,
PassThrough{},
b1_element_op,
c_element_op);
ref_gemm1_invoker.Run(ref_gemm1_argument);
// Note: in this example, we merely permute the dimensions by changing underlying
// strides so we simply access data as-is
c_gs_ms_os_host_result.ForEach(
[&](auto& self, auto idx) { self(idx) = c_g_m_o_host_result(idx); });
bool pass_ =
ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData);
pass &= pass_;
}
} }
return 0; return pass ? 0 : 1;
} }
add_custom_target(example_permute)
add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp)
add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp)
add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp)
add_dependencies(example_permute example_permute_1xHxW_fp16)
add_dependencies(example_permute example_permute_NxHxW_fp16)
add_dependencies(example_permute example_permute_HxWx4_fp16)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <iterator>
#include <numeric>
#include <type_traits>
#include <utility>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/utility/type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
using F16 = ck::half_t;
using F32 = float;
using F64 = double;
struct Problem final
{
static constexpr std::size_t NumDim = 3;
using Shape = std::array<std::size_t, NumDim>;
using Axes = Shape;
Problem() = delete;
explicit Problem(const Shape& default_shape, const Axes& default_axes)
: shape(default_shape), axes(default_axes)
{
}
Shape shape;
Axes axes;
};
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
namespace detail {
template <typename Array, std::size_t Difference>
struct enlarge_array_size;
template <typename T, std::size_t Size, std::size_t Difference>
struct enlarge_array_size<std::array<T, Size>, Difference>
{
using type = std::array<T, Size + Difference>;
};
template <typename Array, std::size_t Difference>
using enlarge_array_size_t = typename enlarge_array_size<Array, Difference>::type;
template <typename Array>
struct get_array_size;
template <typename T, std::size_t Size>
struct get_array_size<std::array<T, Size>> : std::integral_constant<std::size_t, Size>
{
};
template <typename Array>
inline constexpr std::size_t get_array_size_v = get_array_size<Array>::value;
template <typename T, typename = void>
struct is_iterator : std::false_type
{
};
template <typename T>
struct is_iterator<T,
std::void_t<decltype(*std::declval<T>()),
decltype(++std::declval<std::add_lvalue_reference_t<T>>()),
decltype(std::declval<std::add_lvalue_reference_t<T>>()++)>>
: std::true_type
{
};
template <typename T>
inline constexpr bool is_iterator_v = is_iterator<T>::value;
struct Placeholder final
{
template <typename T>
constexpr inline operator T() const noexcept;
};
template <typename Iterator, typename = void>
struct is_output_iterator : std::false_type
{
};
template <typename Iterator>
struct is_output_iterator<
Iterator,
std::void_t<decltype(*std::declval<Iterator>() = std::declval<Placeholder>())>>
: std::bool_constant<is_iterator_v<Iterator>>
{
};
template <typename T>
inline constexpr bool is_output_iterator_v = is_output_iterator<T>::value;
template <typename Iterator, typename = void>
struct is_bidirectional_iterator : std::false_type
{
};
template <typename Iterator>
struct is_bidirectional_iterator<
Iterator,
std::void_t<decltype(--std::declval<std::add_lvalue_reference_t<Iterator>>()),
decltype(std::declval<std::add_lvalue_reference_t<Iterator>>()--)>>
: std::bool_constant<is_iterator_v<Iterator>>
{
};
template <typename Iterator>
inline constexpr bool is_bidirectional_iterator_v = is_bidirectional_iterator<Iterator>::value;
template <typename Iterator, typename = void>
struct is_random_access_iterator : std::false_type
{
};
template <typename Iterator>
struct is_random_access_iterator<Iterator,
std::void_t<decltype(std::declval<Iterator>() + 1),
decltype(std::declval<Iterator>() - 1),
decltype(std::declval<Iterator>()[1])>>
: std::bool_constant<is_iterator_v<Iterator>>
{
};
template <typename Iterator>
inline constexpr bool is_random_access_iterator_v = is_random_access_iterator<Iterator>::value;
template <typename T, typename = void>
struct is_range : std::false_type
{
};
template <typename T>
struct is_range<T,
std::void_t<decltype(begin(std::declval<T>())),
decltype(end(std::declval<T>())),
decltype(begin(std::declval<T>()) != end(std::declval<T>()))>>
: std::bool_constant<is_iterator_v<ck::remove_cvref_t<decltype(begin(std::declval<T>()))>>>
{
};
template <typename T>
inline constexpr bool is_range_v = is_range<T>::value;
template <typename Range, typename = void>
struct is_sized_range : std::false_type
{
};
template <typename Range>
struct is_sized_range<Range, std::void_t<decltype(size(std::declval<Range>()))>>
: std::bool_constant<is_range_v<Range>>
{
};
template <typename Range>
inline constexpr bool is_sized_range_v = is_sized_range<Range>::value;
template <typename Range, typename = void>
struct is_bidirectional_range : std::false_type
{
};
template <typename Range>
struct is_bidirectional_range<Range, std::void_t<>>
: std::bool_constant<
is_range_v<Range> &&
is_bidirectional_iterator_v<ck::remove_cvref_t<decltype(begin(std::declval<Range>()))>>>
{
};
template <typename Range>
inline constexpr bool is_bidirectional_range_v = is_bidirectional_range<Range>::value;
template <typename Range, typename = void>
struct is_random_access_range : std::false_type
{
};
template <typename Range>
struct is_random_access_range<Range, std::void_t<>>
: std::bool_constant<
is_range_v<Range> &&
is_random_access_iterator_v<ck::remove_cvref_t<decltype(begin(std::declval<Range>()))>>>
{
};
template <typename Range>
inline constexpr bool is_random_access_range_v = is_random_access_range<Range>::value;
template <typename Range>
class to_array_proxy
{
static_assert(is_range_v<Range>);
public:
explicit to_array_proxy(const Range& source) noexcept : source_(source) {}
template <typename T, std::size_t Size>
operator std::array<T, Size>() const
{
std::array<T, Size> destination;
std::copy_n(std::begin(source_),
std::min<std::size_t>(Size, std::size(source_)),
std::begin(destination));
return destination;
}
private:
const Range& source_;
};
} // namespace detail
template <typename Range>
inline auto to_array(Range& range) noexcept
-> std::enable_if_t<detail::is_range_v<Range>,
detail::to_array_proxy<ck::remove_cvref_t<Range>>>
{
return detail::to_array_proxy<ck::remove_cvref_t<Range>>{range};
}
namespace ranges {
template <typename InputRange, typename OutputIterator>
inline auto copy(InputRange&& range, OutputIterator iter)
-> decltype(std::copy(std::begin(std::forward<InputRange>(range)),
std::end(std::forward<InputRange>(range)),
iter))
{
return std::copy(std::begin(std::forward<InputRange>(range)),
std::end(std::forward<InputRange>(range)),
iter);
}
} // namespace ranges
template <typename Axes>
inline auto is_valid_axes(const Axes& axes)
-> std::enable_if_t<detail::is_random_access_range_v<Axes>, bool>
{
using std::empty;
if(empty(axes))
{
return false;
}
using std::begin, std::end;
std::vector<std::size_t> sorted_axes(begin(axes), end(axes));
std::sort(begin(sorted_axes), end(sorted_axes));
const auto last = std::unique(begin(sorted_axes), end(sorted_axes));
return (last == end(sorted_axes)) && (*begin(sorted_axes) == 0) &&
(*std::prev(last) == size(axes) - 1);
}
template <typename Shape>
inline auto is_valid_shape(const Shape& shape) -> std::enable_if_t<detail::is_range_v<Shape>, bool>
{
static_assert(std::is_unsigned_v<ck::remove_cvref_t<decltype(*std::begin(shape))>>);
using std::begin, std::end;
using std::empty;
return !empty(shape) && std::all_of(begin(shape), end(shape), [](auto dim) { return 0 < dim; });
}
template <typename Shape, typename Indices>
inline auto is_valid_indices(const Shape& shape, const Indices& indices)
-> std::enable_if_t<detail::is_sized_range_v<Shape> && detail::is_sized_range_v<Indices>, bool>
{
static_assert(std::is_unsigned_v<ck::remove_cvref_t<decltype(*std::begin(indices))>>);
if(!is_valid_shape(shape))
{
return false;
}
using std::empty;
if(empty(indices))
{
return false;
}
using std::size;
if(size(shape) != size(indices))
{
return false;
}
using std::begin, std::end;
auto dim = begin(shape);
auto idx = begin(indices);
for(; dim != end(shape) && idx != end(indices); ++dim, ++idx)
{
if(*dim <= *idx)
{
return false;
}
}
return true;
}
template <std::size_t Size>
std::array<std::size_t, Size> transpose(const std::array<std::size_t, Size>& shape,
const std::array<std::size_t, Size>& axes)
{
assert(is_valid_shape(shape) && is_valid_axes(axes));
std::array<std::size_t, Size> transposed;
auto iter = std::begin(transposed);
for(const auto axis : axes)
{
*iter++ = shape[axis];
}
return transposed;
}
auto extend_shape(const Problem::Shape& shape, std::size_t new_dim)
{
detail::enlarge_array_size_t<Problem::Shape, 1> extended_shape;
using std::begin, std::end;
std::copy(begin(shape), end(shape), begin(extended_shape));
extended_shape.back() = new_dim;
return extended_shape;
}
auto extend_axes(const Problem::Axes& axes)
{
detail::enlarge_array_size_t<Problem::Axes, 1> extended_axes;
using std::begin, std::end;
std::copy(begin(axes), end(axes), begin(extended_axes));
extended_axes.back() = detail::get_array_size_v<Problem::Axes>;
return extended_axes;
}
template <typename Shape, typename Indices>
auto advance_indices(const Shape& shape, Indices& indices) -> std::enable_if_t<
detail::is_bidirectional_range_v<Shape> && detail::is_sized_range_v<Shape> &&
detail::is_bidirectional_range_v<Indices> && detail::is_sized_range_v<Indices>,
bool>
{
using std::size;
if(!(is_valid_shape(shape) && is_valid_indices(shape, indices) && size(shape) == size(indices)))
{
return false;
}
bool carry = true;
using std::rbegin, std::rend;
auto dim = rbegin(shape);
auto idx = rbegin(indices);
for(; carry && dim != rend(shape) && idx != rend(indices); ++dim, ++idx)
{
*idx = (*idx + carry);
carry = ((*idx == *dim) ? (*idx = 0, true) : false);
}
return !carry;
}
template <typename Src, typename Axes, typename Functor, typename Dest>
auto host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<Dest>& dest)
-> std::enable_if_t<detail::is_random_access_range_v<Axes> && detail::is_sized_range_v<Axes> &&
std::is_invocable_v<Functor,
std::add_lvalue_reference_t<Dest>,
std::add_lvalue_reference_t<Src>>,
bool>
{
const auto& shape = src.mDesc.GetLengths();
const auto& transposed_shape = dest.mDesc.GetLengths();
if(!(is_valid_shape(shape) && is_valid_shape(transposed_shape)))
{
return false;
}
using std::size;
if(!is_valid_axes(axes))
{
return false;
}
static_assert(detail::is_sized_range_v<ck::remove_cvref_t<decltype(shape)>> &&
detail::is_sized_range_v<ck::remove_cvref_t<decltype(transposed_shape)>>);
if(size(shape) != size(transposed_shape))
{
return false;
}
static_assert(detail::is_random_access_range_v<ck::remove_cvref_t<decltype(shape)>> &&
detail::is_random_access_range_v<ck::remove_cvref_t<decltype(transposed_shape)>>);
{
for(std::size_t idx = 0; idx < size(shape); ++idx)
{
if(transposed_shape[idx] != shape[axes[idx]])
{
return false;
}
}
}
std::vector<std::size_t> indices(size(shape), 0);
if(!is_valid_indices(shape, indices))
{
return false;
}
switch(size(shape))
{
case 3: {
do
{
Dest output = 0;
functor(output, src(indices[0], indices[1], indices[2]));
dest(indices[axes[0]], indices[axes[1]], indices[axes[2]]) = output;
} while(advance_indices(shape, indices));
}
break;
case 4: {
do
{
Dest output = 0;
functor(output, src(indices[0], indices[1], indices[2], indices[3]));
dest(indices[axes[0]], indices[axes[1]], indices[axes[2]], indices[axes[3]]) = output;
} while(advance_indices(shape, indices));
}
break;
default: return false;
}
return true;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
using InDataType = F16;
using OutDataType = F16;
// clang-format off
using DevicePermuteInstance = ck::tensor_operation::device::DevicePermuteImpl
// ######| NumDim| InData| OutData| Elementwise| Block| NPer| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst|
// ######| | Type| Type| Operation| Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
// ######| | | | | | | | | | | | | | | |
// ######| | | | | | | | | | | | | | | |
< 3, InDataType, OutDataType, PassThrough, 256, 1, 32, 32, 3, S<1, 32, 8>, S<0, 1, 2>, 2, 1, 2, 1>;
// clang-format on
#include "run_permute_element_example.inc"
int main() { return !run_permute_element_example({1, 32000, 80}, {0, 2, 1}); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
using DataType = F16;
using BundleType = F64;
static_assert(sizeof(BundleType) % sizeof(DataType) == 0);
// clang-format off
using DevicePermuteInstance = ck::tensor_operation::device::DevicePermuteImpl
// ######| NumDim| InData| OutData| Elementwise| Block| NPer| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst|
// ######| | Type| Type| Operation| Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
// ######| | | | | | | | | | | | | | | |
// ######| | | | | | | | | | | | | | | |
< 3, BundleType, BundleType, PassThrough, 256, 1, 32, 32, 5, S<1, 32, 8>, S<0, 1, 2>, 2, 1, 4, 1>;
// clang-format on
#include "run_permute_bundle_example.inc"
int main() { return !run_permute_bundle_example({1, 80, 32000}, {0, 2, 1}); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
using InDataType = F16;
using OutDataType = F16;
// clang-format off
using DevicePermuteInstance = ck::tensor_operation::device::DevicePermuteImpl
// ######| NumDim| InData| OutData| Elementwise| Block| NPer| HPer| WPer| InBlock| InBlockTransfer| InBlockTransfer| Src| Dst| Src| Dst|
// ######| | Type| Type| Operation| Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
// ######| | | | | | | | | | | | | | | |
// ######| | | | | | | | | | | | | | | |
< 3, InDataType, OutDataType, PassThrough, 128, 4, 16, 8, 6, S<2, 16, 4>, S<0, 1, 2>, 2, 1, 2, 1>;
// clang-format on
#include "run_permute_element_example.inc"
int main() { return !run_permute_element_example({121, 768, 80}, {0, 2, 1}); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
bool run_permute_bundle(const Problem& problem)
{
const auto& input_bundle_shape = problem.shape;
const auto& input_bundle_axes = problem.axes;
const auto output_bundle_shape = transpose(input_bundle_shape, input_bundle_axes);
Tensor<BundleType> input_bundle_tensor(input_bundle_shape);
Tensor<BundleType> output_bundle_tensor(output_bundle_shape);
// initialize tensor by assigning DataType values
ck::utils::FillUniformDistribution<DataType>{-1.f, 1.f}(input_bundle_tensor.AsSpan<DataType>());
DeviceMem input_device_buf(input_bundle_tensor.GetElementSpaceSizeInBytes());
DeviceMem output_device_buf(output_bundle_tensor.GetElementSpaceSizeInBytes());
using std::data;
input_device_buf.ToDevice(data(input_bundle_tensor));
static_assert(std::is_default_constructible_v<DevicePermuteInstance>);
auto permute = DevicePermuteInstance{};
auto argument = permute.MakeArgument(to_array(input_bundle_shape),
to_array(input_bundle_tensor.GetStrides()),
to_array(output_bundle_shape),
to_array(output_bundle_tensor.GetStrides()),
input_device_buf.GetDeviceBuffer(),
output_device_buf.GetDeviceBuffer(),
PassThrough{});
if(!permute.IsSupportedArgument(argument))
{
std::cerr << "The runtime parameters seems not supported by the device instance, exiting!"
<< std::endl;
return false;
};
auto invoker = permute.MakeInvoker();
float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
std::cout << "Perf: " << ave_time << " ms" << std::endl;
output_device_buf.FromDevice(data(output_bundle_tensor));
constexpr std::size_t NumElemsInBundle = sizeof(BundleType) / sizeof(DataType);
// extend tensor shape from [N, H, W] to [N, H, W, NumElemsInBundle]
// axes from [0, 2, 1] to [0, 2, 1, 3]
const auto input_shape = extend_shape(input_bundle_shape, NumElemsInBundle);
const auto input_axes = extend_axes(input_bundle_axes);
using std::begin;
Tensor<DataType> input_tensor(input_shape);
ranges::copy(input_bundle_tensor.AsSpan<const DataType>(), begin(input_tensor));
Tensor<DataType> output_tensor(transpose(input_shape, input_axes));
if(!host_permute(input_tensor, input_axes, PassThrough{}, output_tensor))
{
return false;
}
return ck::utils::check_err(output_bundle_tensor.AsSpan<const DataType>(),
output_tensor.AsSpan<const DataType>(),
"Error: incorrect results in output tensor",
1e-6,
1e-6);
}
bool run_permute_bundle_example(const Problem::Shape& shape, const Problem::Axes& axes)
{
return run_permute_bundle(Problem{shape, axes});
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
bool run_permute_element(const Problem& problem)
{
const auto& input_shape = problem.shape;
const auto& input_axes = problem.axes;
const auto output_shape = transpose(input_shape, input_axes);
Tensor<InDataType> input_tensor(input_shape);
Tensor<OutDataType> output_tensor(output_shape);
ck::utils::FillUniformDistribution<InDataType>{-1.f, 1.f}(input_tensor);
DeviceMem input_device_buf(input_tensor.GetElementSpaceSizeInBytes());
DeviceMem output_device_buf(output_tensor.GetElementSpaceSizeInBytes());
using std::data;
input_device_buf.ToDevice(data(input_tensor));
static_assert(std::is_default_constructible_v<DevicePermuteInstance>);
auto permute = DevicePermuteInstance{};
auto argument = permute.MakeArgument(to_array(input_shape),
to_array(input_tensor.GetStrides()),
to_array(output_shape),
to_array(output_tensor.GetStrides()),
input_device_buf.GetDeviceBuffer(),
output_device_buf.GetDeviceBuffer(),
PassThrough{});
if(!permute.IsSupportedArgument(argument))
{
std::cerr << "The runtime parameters seems not supported by the device instance, exiting!"
<< std::endl;
return false;
};
auto invoker = permute.MakeInvoker();
float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
std::cout << "Perf: " << ave_time << " ms" << std::endl;
output_device_buf.FromDevice(data(output_tensor));
Tensor<OutDataType> output_tensor_host(output_shape);
if(!host_permute(input_tensor, input_axes, PassThrough{}, output_tensor_host))
{
return false;
}
return ck::utils::check_err(output_tensor.AsSpan<const OutDataType>(),
output_tensor_host.AsSpan<const OutDataType>(),
"Error: incorrect results in output tensor",
1e-6,
1e-6);
}
bool run_permute_element_example(const Problem::Shape& shape, const Problem::Axes& axes)
{
return run_permute_element(Problem{shape, axes});
}
...@@ -649,6 +649,9 @@ struct BlockwiseGemmXdlops_v2 ...@@ -649,6 +649,9 @@ struct BlockwiseGemmXdlops_v2
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
static_assert(KPerThread % KPack == 0,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc, FloatAcc,
MRepeat * NRepeat, MRepeat * NRepeat,
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#pragma once #pragma once
#include <cmath>
#include <string> #include <string>
#include "ck/stream_config.hpp" #include "ck/stream_config.hpp"
......
...@@ -503,13 +503,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout ...@@ -503,13 +503,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, if(!DeviceOp::IsSupportedArgument(arg))
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! unsupported argument");
} }
const index_t grid_size = const index_t grid_size =
......
...@@ -222,14 +222,9 @@ struct DeviceElementwise ...@@ -222,14 +222,9 @@ struct DeviceElementwise
} }
}; };
bool IsSupportedArgument(const BaseArgument* p_arg) override static bool IsSupportedArgument(const Argument& arg)
{ {
const Argument* pArg = dynamic_cast<const Argument*>(p_arg); if(arg.lengths_.back() % MPerThread != 0)
if(pArg == nullptr)
return false;
if(pArg->lengths_.back() % MPerThread != 0)
return false; return false;
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths, auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
...@@ -247,19 +242,40 @@ struct DeviceElementwise ...@@ -247,19 +242,40 @@ struct DeviceElementwise
bool valid = true; bool valid = true;
static_for<0, NumInput, 1>{}([&](auto I) { static_for<0, NumInput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid( if(!IsScalarPerVectorValid(
pArg->lengths_, pArg->inStridesArray_[I.value], InScalarPerVectorSeq::At(I))) arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
valid = false; valid = false;
}); });
static_for<0, NumOutput, 1>{}([&](auto I) { static_for<0, NumOutput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid( if(!IsScalarPerVectorValid(
pArg->lengths_, pArg->outStridesArray_[I.value], OutScalarPerVectorSeq::At(I))) arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
valid = false; valid = false;
}); });
return valid; return valid;
}; };
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto
MakeArgument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op)
{
return Argument{lengths,
inStridesArray,
outStridesArray,
in_dev_buffers,
out_dev_buffers,
elementwise_op};
}
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, NumDim> lengths, MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray, const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename B0Layout,
typename B1Layout,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<>
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename Acc0ElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
{
struct ProblemDesc
{
// Overall problem shape
index_t M;
index_t N;
index_t K;
index_t O;
index_t Batch;
// Stride for A/B0/B1; layout determined by template args
index_t StrideA;
index_t StrideB0;
index_t StrideB1;
index_t BatchStrideA;
index_t BatchStrideB0;
index_t BatchStrideB1;
// Lengths and strides for output C
std::vector<index_t> c_gs_ms_os_lengths;
std::vector<index_t> c_gs_ms_os_strides;
};
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*> p_a_vec,
std::vector<const void*> p_b0_vec,
std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec,
std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op,
B0ElementwiseOperation b0_element_op,
Acc0ElementwiseOperation acc0_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename GroupKernelArg,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1(
const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
const auto arg_ptr = reinterpret_cast<const GroupKernelArg*>(
cast_pointer_to_generic_address_space(group_kernel_args));
index_t left = 0;
index_t right = group_count;
index_t group_id = index_t((left + right) / 2);
while((!(block_id >= arg_ptr[group_id].block_start_ &&
block_id < arg_ptr[group_id].block_end_)) &&
left <= right)
{
if(block_id < arg_ptr[group_id].block_start_)
{
right = group_id;
}
else
{
left = group_id;
}
group_id = index_t((left + right) / 2);
}
// per-group batch offset
const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
const index_t g_idx = __builtin_amdgcn_readfirstlane(
(block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(g_idx)));
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
GridwiseGemm::template Run<HasMainKBlockLoop>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_c_grid_ + c_batch_offset,
p_shared,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].block_2_ctile_map_);
#else
ignore = group_kernel_args;
ignore = group_count;
ignore = a_element_op;
ignore = b_element_op;
ignore = acc_element_op;
ignore = b1_element_op;
ignore = c_element_op;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
// Computes C = A * B0 * B1
// ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1)
template <typename ALayout,
typename BLayout, // B0Layout
typename B1Layout,
typename CPermuteNumDims_G_M_Gemm1N, // Sequence<NumDimG, NumDimM, NumDimGemm1N>
typename ADataType,
typename BDataType,
typename B1DataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock, // Gemm0NPerBlock
index_t KPerBlock, // Gemm0KPerBlock
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t AK1,
index_t BK1,
index_t B1K1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
index_t Gemm1NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
index_t B1BlockTransferSrcVectorDim,
index_t B1BlockTransferSrcScalarPerVector,
index_t B1BlockTransferDstScalarPerVector_BK1,
bool B1BlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
: public DeviceGroupedGemmSoftmaxGemmPermute<ALayout,
BLayout,
B1Layout,
CPermuteNumDims_G_M_Gemm1N,
ADataType,
BDataType,
B1DataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation>
{
using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle;
using ProblemDesc =
typename DeviceGroupedGemmSoftmaxGemmPermute<ALayout,
BLayout,
B1Layout,
CPermuteNumDims_G_M_Gemm1N,
ADataType,
BDataType,
B1DataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation>::ProblemDesc;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto matrix_padder =
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
// FIXME: pad K
static_assert(!matrix_padder.PadK, "KPadding is currently not supported");
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// Args: Gemm1KRaw, Gemm1NRaw, StrideB1
static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
{
const auto b1_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, B1Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B1Layout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw);
const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto K = b1_grid_desc_n_k.GetLength(I1);
const auto B1K0 = K / B1K1;
return transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
const std::vector<index_t>& c_gs_ms_ns_strides_vec)
{
constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0);
constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1);
constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N
assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto c_ms_ns_lengths = to_tuple(
c_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto c_ms_ns_strides = to_tuple(
c_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for M0, M1, ...
constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds =
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(c_ms_ns_lengths, mDimIds);
// lengths for K0, K1, ...
const auto nLengths = get_container_subset(c_ms_ns_lengths, nDimIds);
// naive tensor C[M0, M1, M2, ..., N0, N1, N2...]
const auto c_grid_desc_ms_ns =
make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);
// transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
const auto c_grid_desc_mraw_nraw = transform_tensor_descriptor(
c_grid_desc_ms_ns,
make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
make_tuple(mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
}
// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
const std::vector<index_t>& c_gs_ms_ns_strides_vec)
{
constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0);
constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1);
constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N
assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
const auto to_tuple = [&](auto& vec, auto start, auto end) {
return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
};
const auto c_gs_ms_ns_lengths =
to_tuple(c_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
const auto c_gs_ms_ns_strides =
to_tuple(c_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
// dimension Ids for G0, G1, ...
constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
// dimension Ids for M0, M1, ...
constexpr auto mDimIds =
typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};
// dimension Ids for N0, N1, ...
constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
NumDimG + NumDimM + NumDimN,
1>::type{};
// lengths for G0, G1, ...
const auto gLengths = get_container_subset(c_gs_ms_ns_lengths, gDimIds);
// lengths for M0, M1, ...
const auto mLengths = get_container_subset(c_gs_ms_ns_lengths, mDimIds);
// lengths for K0, K1, ...
const auto nLengths = get_container_subset(c_gs_ms_ns_lengths, nDimIds);
// naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
const auto c_grid_desc_gs_ms_ns =
make_naive_tensor_descriptor(c_gs_ms_ns_lengths, c_gs_ms_ns_strides);
// transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
// N2 * ...]
const auto c_grid_desc_g_mraw_nraw =
transform_tensor_descriptor(c_grid_desc_gs_ms_ns,
make_tuple(make_merge_transform(gLengths),
make_merge_transform(mLengths),
make_merge_transform(nLengths)),
make_tuple(gDimIds, mDimIds, nDimIds),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// this desc is only for calculating batch offset so no padding needed
return c_grid_desc_g_mraw_nraw;
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N({}, {}));
using CGridDesc_G_M_N = decltype(MakeCGridDescriptor_G_M_N({}, {}));
struct ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
index_t BatchStrideB1,
CGridDesc_G_M_N c_grid_desc_g_m_n)
: BatchStrideA_(BatchStrideA),
BatchStrideB_(BatchStrideB),
BatchStrideB1_(BatchStrideB1),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
{
}
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
}
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
{
return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
private:
index_t BatchStrideA_;
index_t BatchStrideB_;
index_t BatchStrideB1_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
B1GridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
AK1,
BK1,
B1K1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
Gemm1NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
true,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
true,
BBlockLdsExtraN,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
false,
B1BlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
matrix_padder.PadN>;
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
struct GroupKernelArg
{
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
const B1DataType* p_b1_grid_;
CDataType* p_c_grid_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
// batch & stride
index_t num_blocks_per_batch_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
// block-to-c-tile map
Block2CTileMap block_2_ctile_map_;
index_t block_start_, block_end_;
};
struct GroupDeviceArg
{
// problem definiton
index_t M;
index_t N;
index_t K;
index_t O;
// Strides for the last dimensions of C for sanity check of vector load/store
index_t c_extent_lowest_;
index_t c_stride_lowest_;
CGridDesc_M_N c_grid_desc_m_n_;
};
// Argument
// FIXME: constness
struct Argument : public BaseArgument
{
Argument(std::vector<const void*> p_a_vec,
std::vector<const void*> p_b_vec,
std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec,
std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
: a_element_op_{a_element_op},
b_element_op_{b_element_op},
acc_element_op_{acc_element_op},
b1_element_op_{b1_element_op},
c_element_op_{c_element_op}
{
group_count_ = problem_desc_vec.size();
if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() &&
group_count_ == p_b1_vec.size() && group_count_ == p_c_vec.size()))
{
throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
}
grid_size_ = 0;
for(std::size_t i = 0; i < group_count_; i++)
{
const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]);
const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]);
const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
const auto p_c_grid = static_cast<CDataType*>(p_c_vec[i]);
const auto a_grid_desc_ak0_m_ak1 = DeviceOp::MakeAGridDescriptor_AK0_M_AK1(
problem_desc_vec[i].M, problem_desc_vec[i].K, problem_desc_vec[i].StrideA);
const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
problem_desc_vec[i].K, problem_desc_vec[i].N, problem_desc_vec[i].StrideB0);
const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
problem_desc_vec[i].N, problem_desc_vec[i].O, problem_desc_vec[i].StrideB1);
const auto c_grid_desc_m_n = DeviceOp::MakeCGridDescriptor_M_N(
problem_desc_vec[i].c_gs_ms_os_lengths, problem_desc_vec[i].c_gs_ms_os_strides);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n);
const index_t BlockStart = grid_size_;
const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart);
const index_t grid_size_grp = block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) *
problem_desc_vec[i].Batch;
const index_t BlockEnd = grid_size_ + grid_size_grp;
// batch stride
// TODO ANT: only keep batch stride in tensor desc to reduce scalar cache pressure
const auto c_grid_desc_g_m_n = DeviceOp::MakeCGridDescriptor_G_M_N(
problem_desc_vec[i].c_gs_ms_os_lengths, problem_desc_vec[i].c_gs_ms_os_strides);
const auto compute_base_ptr_of_batch =
ComputeBasePtrOfStridedBatch(problem_desc_vec[i].BatchStrideA,
problem_desc_vec[i].BatchStrideB0,
problem_desc_vec[i].BatchStrideB1,
c_grid_desc_g_m_n);
grid_size_ += grid_size_grp;
group_kernel_args_.push_back({p_a_grid,
p_b_grid,
p_b1_grid,
p_c_grid,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n),
compute_base_ptr_of_batch,
block_2_ctile_map,
BlockStart,
BlockEnd});
group_device_args_.push_back({problem_desc_vec[i].M,
problem_desc_vec[i].N,
problem_desc_vec[i].K,
problem_desc_vec[i].O,
problem_desc_vec[i].c_gs_ms_os_lengths.back(),
problem_desc_vec[i].c_gs_ms_os_strides.back(),
c_grid_desc_m_n});
}
}
std::vector<GroupKernelArg> group_kernel_args_;
std::vector<GroupDeviceArg> group_device_args_;
std::size_t group_count_;
index_t grid_size_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
AccElementwiseOperation acc_element_op_;
B1ElementwiseOperation b1_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!DeviceOp::IsSupportedArgument(arg))
{
throw std::runtime_error("wrong! unsupported argument");
}
bool all_has_main_k_block_loop = true;
bool some_has_main_k_block_loop = false;
for(std::size_t i = 0; i < arg.group_count_; i++)
{
const auto K = arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K);
all_has_main_k_block_loop &= y;
some_has_main_k_block_loop |= y;
}
hipGetErrorString(hipMemcpy(arg.p_workspace_,
arg.group_kernel_args_.data(),
arg.group_kernel_args_.size() * sizeof(GroupKernelArg),
hipMemcpyHostToDevice));
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel =
kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1<GridwiseGemm,
GroupKernelArg,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
has_main_k_block_loop_>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.group_count_,
arg.a_element_op_,
arg.b_element_op_,
arg.acc_element_op_,
arg.b1_element_op_,
arg.c_element_op_);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to concern Gemm0's loop
if(all_has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{});
}
else if(!some_has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, false>{});
}
else
{
throw std::runtime_error("wrong! all gemm problems have to simultaneously meet "
"has_main_k_block_loop or no_main_k_block_loop");
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
{
return false;
}
bool all_has_main_k_block_loop = true;
bool some_has_main_k_block_loop = false;
for(std::size_t i = 0; i < arg.group_count_; i++)
{
const auto& kernel_arg = arg.group_kernel_args_[i];
const auto& device_arg = arg.group_device_args_[i];
// Check if C permute dimension matches GEMM + GEMM shape
const index_t c_m = device_arg.c_grid_desc_m_n_.GetLength(I0);
const index_t c_gemm1n = device_arg.c_grid_desc_m_n_.GetLength(I1);
const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
if(!(c_m == a_m && c_gemm1n == b1_gemm1n))
{
return false;
}
// Check if having main loop
const auto K = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) *
kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K);
all_has_main_k_block_loop &= y;
some_has_main_k_block_loop |= y;
// Note: we need raw lengths since threadwise copy can not handle vector load when
// part of vector is out of bounds
const auto MRaw = device_arg.M;
const auto NRaw = device_arg.N;
const auto KRaw = device_arg.K;
const auto Gemm1NRaw = device_arg.O;
// Check scalar per vector requirement
const auto a_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
const auto b_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
const auto b1_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
const auto c_extent_lowest = device_arg.c_extent_lowest_;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
// Check vector store requirement; assumes last dimension in N to be contiguous
if(device_arg.c_stride_lowest_ != 1)
{
return false;
}
if(!GridwiseGemm::CheckValidity(kernel_arg.a_grid_desc_ak0_m_ak1_,
kernel_arg.b_grid_desc_bk0_n_bk1_,
kernel_arg.b1_grid_desc_bk0_n_bk1_,
device_arg.c_grid_desc_m_n_,
kernel_arg.block_2_ctile_map_))
{
return false;
}
}
// all gemm problems have to simultaneously meet has_main_k_block_loop or
// no_main_k_block_loop
if(!(all_has_main_k_block_loop || !some_has_main_k_block_loop))
{
return false;
}
return true;
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(std::vector<const void*> p_a_vec,
std::vector<const void*> p_b_vec,
std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec,
std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a_vec,
p_b_vec,
p_b1_vec,
p_c_vec,
problem_desc_vec,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*> p_a_vec,
std::vector<const void*> p_b_vec,
std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec,
std::vector<ProblemDesc> problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
AccElementwiseOperation acc_element_op,
B1ElementwiseOperation b1_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(p_a_vec,
p_b_vec,
p_b1_vec,
p_c_vec,
problem_desc_vec,
a_element_op,
b_element_op,
acc_element_op,
b1_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerBlock << ", "
<< Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", "
<< B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ">";
// clang-format on
return str.str();
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GroupKernelArg);
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <cmath>
#include <memory>
#include <type_traits>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t NumDim, typename InDataType, typename OutDataType, typename ElementwiseOperation>
struct DevicePermute : BaseOperator
{
using Lengths = std::array<index_t, NumDim>;
using Strides = Lengths;
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const Lengths& in_lengths,
const Strides& in_strides,
const Lengths& out_lengths,
const Strides& out_strides,
const void* in_dev_buffer,
void* out_dev_buffer,
ElementwiseOperation elementwise_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <memory>
#include <utility>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_permute.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_permute.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Swap last 2 dimensions
// input shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], d[NumDim-1]]
// ^^^^^^^^^^^
// output shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-1], d[NumDim-2]]
// ^^^^^^^^^^^
template <index_t NumDim,
typename InDataType,
typename OutDataType,
typename ElementwiseOperation,
index_t BlockSize,
index_t NPerBlock,
index_t HPerBlock,
index_t WPerBlock,
index_t InBlockLdsExtraW,
typename InBlockTransferThreadClusterLengths,
typename InBlockTransferThreadClusterArrangeOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector>
struct DevicePermuteImpl : DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>
{
using BaseType = DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>;
using typename BaseType::Lengths;
using typename BaseType::Strides;
static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor");
static_assert((NumDim - 2) <= SrcVectorDim && SrcVectorDim < NumDim);
static_assert((NumDim - 2) <= DstVectorDim && DstVectorDim < NumDim);
static_assert(SrcVectorDim != DstVectorDim);
template <index_t N = NumDim>
static auto ConvertArrayToTuple(const std::array<index_t, NumDim>& array)
{
static_assert(1 <= N && N <= NumDim);
return generate_tuple([&](auto I) { return array[I]; }, Number<N>{});
}
static auto MakeDescriptor_N_H_W(const Lengths& lengths, const Strides& stride)
{
// create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2],
// d[NumDim-1]]
const auto desc =
make_naive_tensor_descriptor(ConvertArrayToTuple(lengths), ConvertArrayToTuple(stride));
// merge nd to 3d descriptor, shape: [(d[0] * d[1] * d[2] * ... * d[NumDim-3]), d[NumDim-2],
// d[NumDim-1]]
// => [N, H, W]
const index_t H = *std::next(rbegin(lengths));
const index_t W = *rbegin(lengths);
const auto desc_n_h_w = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(ConvertArrayToTuple<NumDim - 2>(lengths)),
make_pass_through_transform(H),
make_pass_through_transform(W)),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim - 2>{}),
Sequence<NumDim - 2>{},
Sequence<NumDim - 1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return PadTensorDescriptor(
desc_n_h_w, make_tuple(NPerBlock, HPerBlock, WPerBlock), Sequence<true, true, true>{});
}
using InGridDesc = decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1}));
using OutGridDesc = InGridDesc;
using GridwisePermute = GridwisePermute<
InGridDesc,
OutGridDesc,
InDataType,
OutDataType,
ElementwiseOperation,
BlockSize,
NPerBlock,
HPerBlock,
WPerBlock,
InBlockLdsExtraW,
InBlockTransferThreadClusterLengths,
InBlockTransferThreadClusterArrangeOrder,
SrcVectorDim - (NumDim - 3), // calculate new SrcVectorDim for the merged descriptor
DstVectorDim - (NumDim - 3), // calculate new DstVectorDim for the merged descriptor
SrcScalarPerVector,
DstScalarPerVector>;
using Block2TileMap = typename GridwisePermute::DefaultBlock2TileMap;
struct Argument : public BaseArgument
{
Argument(const Lengths& in_lengths,
const Strides& in_strides,
const Lengths& out_lengths,
const Strides& out_strides,
const void* in_dev_buffer,
void* out_dev_buffer,
ElementwiseOperation elementwise_op)
: in_dev_buffer_(static_cast<const InDataType*>(in_dev_buffer)),
out_dev_buffer_(static_cast<OutDataType*>(out_dev_buffer)),
in_grid_desc_(MakeDescriptor_N_H_W(in_lengths, in_strides)),
out_grid_desc_(MakeDescriptor_N_H_W(out_lengths, out_strides)),
in_lengths_(in_lengths),
in_strides_(in_strides),
out_lengths_(out_lengths),
out_strides_(out_strides),
elementwise_op_(elementwise_op),
block_2_tile_map_(GridwisePermute::MakeDefaultBlock2TileMap(in_grid_desc_))
{
}
const InDataType* in_dev_buffer_;
OutDataType* out_dev_buffer_;
InGridDesc in_grid_desc_;
OutGridDesc out_grid_desc_;
Lengths in_lengths_;
Strides in_strides_;
Lengths out_lengths_;
Strides out_strides_;
ElementwiseOperation elementwise_op_;
Block2TileMap block_2_tile_map_;
};
struct Invoker : BaseInvoker
{
static float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const index_t grid_size = arg.block_2_tile_map_.CalculateGridSize(arg.in_grid_desc_);
const auto kernel = kernel_nd_permute<GridwisePermute,
InGridDesc,
OutGridDesc,
InDataType,
OutDataType,
ElementwiseOperation,
Block2TileMap>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.in_grid_desc_,
arg.out_grid_desc_,
arg.in_dev_buffer_,
arg.out_dev_buffer_,
arg.elementwise_op_,
arg.block_2_tile_map_);
return elapsed_time;
}
float Run(const BaseArgument* arg,
const StreamConfig& stream_config = StreamConfig{}) override final
{
const auto* const argument = dynamic_cast<const Argument*>(arg);
if(!argument)
{
return NAN;
}
return Run(*argument, stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
constexpr auto GetPaddedLength = [](index_t length, index_t tile_length) {
return math::integer_divide_ceil(length, tile_length) * tile_length;
};
constexpr auto IsScalarPerVectorValid =
[](index_t length, index_t stride, index_t scalar_per_vector) {
if(stride == 1 && length % scalar_per_vector == 0)
{
return true;
}
else if(stride != 1 && scalar_per_vector == 1)
{
return true;
}
return false;
};
return IsScalarPerVectorValid(arg.in_lengths_[SrcVectorDim],
arg.in_strides_[SrcVectorDim],
SrcScalarPerVector) &&
IsScalarPerVectorValid(
GetPaddedLength(arg.in_lengths_[SrcVectorDim],
(SrcVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
arg.in_strides_[SrcVectorDim],
SrcScalarPerVector) &&
IsScalarPerVectorValid(arg.out_lengths_[DstVectorDim],
arg.out_strides_[DstVectorDim],
DstScalarPerVector) &&
IsScalarPerVectorValid(
GetPaddedLength(arg.out_lengths_[DstVectorDim],
(DstVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
arg.in_strides_[DstVectorDim],
DstScalarPerVector) &&
GridwisePermute::CheckValidity(arg.in_grid_desc_, arg.out_grid_desc_);
};
// override methods inherited from 'BaseOperator'
bool IsSupportedArgument(const BaseArgument* arg) override final
{
const auto* const argument = dynamic_cast<const Argument*>(arg);
if(!argument)
{
return false;
}
return IsSupportedArgument(*argument);
}
// override methods inherited from 'DevicePermute'
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const Lengths& in_lengths,
const Strides& in_strides,
const Lengths& out_lengths,
const Strides& out_strides,
const void* in_dev_buffer,
void* out_dev_buffer,
ElementwiseOperation elementwise_op) override final
{
return std::make_unique<Argument>(in_lengths,
in_strides,
out_lengths,
out_strides,
in_dev_buffer,
out_dev_buffer,
elementwise_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override final
{
return std::make_unique<Invoker>();
};
// other constructor methods
template <typename... Args>
static std::enable_if_t<std::is_constructible_v<Argument, Args...>, Argument>
MakeArgument(Args&&... args) noexcept(std::is_nothrow_constructible_v<Argument, Args...>)
{
return Argument{std::forward<Args>(args)...};
}
static std::enable_if_t<std::is_default_constructible_v<Invoker>, Invoker>
MakeInvoker() noexcept(std::is_nothrow_default_constructible_v<Invoker>)
{
return Invoker{};
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -486,4 +486,48 @@ __host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx, ...@@ -486,4 +486,48 @@ __host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx,
return is_valid; return is_valid;
} }
// This wrapper class is for grouped gemm where it subtracts blockIdx by a value so that the
// workgroups assigned to a given gemm problem have top index offsetted to range [0,
// grid_size_per_gemm]
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMap
{
using underlying_type = UnderlyingBlockToCTileMap;
OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start)
{
block_to_ctile_map_ = block_to_ctile_map;
block_start_ = block_start;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
return block_to_ctile_map_.CalculateBottomIndex(
make_multi_index(idx_top[Number<0>{}] - block_start_));
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
}
template <typename CGridDesc_M_N>
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
}
UnderlyingBlockToCTileMap block_to_ctile_map_;
index_t block_start_;
};
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment