Commit ae20247a authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge remote-tracking branch 'origin' into aosewski/ggemm_multi_d2

parents d1f7a3cf a776978c
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_gemm_add_relu_impl.hpp"
#include "profiler_operation_registry.hpp"
#define OP_NAME "gemm_add_relu"
#define OP_DESC "GEMM+Add+ReLU"
using INT8 = int8_t;
using BF16 = ck::bhalf_t;
int profile_gemm_add_relu(int argc, char* argv[])
{
enum struct MatrixLayout
{
MK_KN_MN_MN, // 0
MK_NK_MN_MN, // 1
KM_KN_MN_MN, // 2
KM_NK_MN_MN, // 3
};
enum struct MatrixDataType
{
F16_INT8_F16_F16, // 0
BF16_INT8_BF16_BF16, // 1
};
if(argc != 15)
{
// clang-format off
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: f16&i8 1: bf16&i8)\n");
printf("arg3: matrix layout (0: E[m, n] = ReLU(A[m, k] * B[k, n] + D0[m, n]);\n");
printf(" 1: E[m, n] = ReLU(A[m, k] * B[n, k] + D0[m, n]);\n");
printf(" 2: E[m, n] = ReLU(A[k, m] * B[k, n] + D0[m, n]);\n");
printf(" 3: E[m, n] = ReLU(A[k, m] * B[n, k] + D0[m, n]))\n");
printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=no, 1=yes)\n");
printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideD0, StrideE\n");
// clang-format on
exit(1);
}
const auto data_type = static_cast<MatrixDataType>(std::stoi(argv[2]));
const auto layout = static_cast<MatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);
const int M = std::stoi(argv[8]);
const int N = std::stoi(argv[9]);
const int K = std::stoi(argv[10]);
const int StrideA = std::stoi(argv[11]);
const int StrideB = std::stoi(argv[12]);
const int StrideD0 = std::stoi(argv[13]);
const int StrideE = std::stoi(argv[14]);
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
// using Col = ck::tensor_layout::gemm::ColumnMajor;
auto profile = [&](auto a_type,
auto b_type,
auto acc_type,
auto d0_type,
auto e_type,
auto a_layout,
auto b_layout,
auto d0_layout,
auto e_layout) {
using ADataType = decltype(a_type);
using BDataType = decltype(b_type);
using AccDataType = decltype(acc_type);
using D0DataType = decltype(d0_type);
using EDataType = decltype(e_type);
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout);
using D0Layout = decltype(d0_layout);
using ELayout = decltype(e_layout);
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
const int DefaultStrideE = ck::is_same_v<ELayout, Row> ? N : M;
bool pass = ck::profiler::profile_gemm_add_relu_impl<ADataType,
BDataType,
AccDataType,
D0DataType,
EDataType,
ALayout,
BLayout,
D0Layout,
ELayout>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? DefaultStrideA : StrideA,
(StrideB < 0) ? DefaultStrideB : StrideB,
(StrideD0 < 0) ? DefaultStrideD0 : StrideD0,
(StrideE < 0) ? DefaultStrideE : StrideE);
return pass ? 0 : 1;
};
if(data_type == MatrixDataType::F16_INT8_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN)
{
return profile(F16{}, INT8{}, F32{}, F16{}, F16{}, Row{}, Row{}, Row{}, Row{});
}
else if(data_type == MatrixDataType::BF16_INT8_BF16_BF16 && layout == MatrixLayout::MK_KN_MN_MN)
{
return profile(BF16{}, INT8{}, F32{}, BF16{}, BF16{}, Row{}, Row{}, Row{}, Row{});
}
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
}
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_add_relu);
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_gemm_add_silu_impl.hpp"
#include "profiler_operation_registry.hpp"
#define OP_NAME "gemm_add_silu"
#define OP_DESC "GEMM+Add+SiLU"
using INT8 = int8_t;
using BF16 = ck::bhalf_t;
int profile_gemm_add_silu(int argc, char* argv[])
{
enum struct MatrixLayout
{
MK_KN_MN_MN, // 0
MK_NK_MN_MN, // 1
KM_KN_MN_MN, // 2
KM_NK_MN_MN, // 3
};
enum struct MatrixDataType
{
F16_INT8_F16_F16, // 0
BF16_INT8_BF16_BF16, // 1
};
if(argc != 15)
{
// clang-format off
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: f16&i8 1: bf16&i8)\n");
printf("arg3: matrix layout (0: E[m, n] = ReLU(A[m, k] * B[k, n] + D0[m, n]);\n");
printf(" 1: E[m, n] = ReLU(A[m, k] * B[n, k] + D0[m, n]);\n");
printf(" 2: E[m, n] = ReLU(A[k, m] * B[k, n] + D0[m, n]);\n");
printf(" 3: E[m, n] = ReLU(A[k, m] * B[n, k] + D0[m, n]))\n");
printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=no, 1=yes)\n");
printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideD0, StrideE\n");
// clang-format on
exit(1);
}
const auto data_type = static_cast<MatrixDataType>(std::stoi(argv[2]));
const auto layout = static_cast<MatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);
const int M = std::stoi(argv[8]);
const int N = std::stoi(argv[9]);
const int K = std::stoi(argv[10]);
const int StrideA = std::stoi(argv[11]);
const int StrideB = std::stoi(argv[12]);
const int StrideD0 = std::stoi(argv[13]);
const int StrideE = std::stoi(argv[14]);
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
// using Col = ck::tensor_layout::gemm::ColumnMajor;
auto profile = [&](auto a_type,
auto b_type,
auto acc_type,
auto d0_type,
auto e_type,
auto a_layout,
auto b_layout,
auto d0_layout,
auto e_layout) {
using ADataType = decltype(a_type);
using BDataType = decltype(b_type);
using AccDataType = decltype(acc_type);
using D0DataType = decltype(d0_type);
using EDataType = decltype(e_type);
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout);
using D0Layout = decltype(d0_layout);
using ELayout = decltype(e_layout);
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
const int DefaultStrideE = ck::is_same_v<ELayout, Row> ? N : M;
bool pass = ck::profiler::profile_gemm_add_silu_impl<ADataType,
BDataType,
AccDataType,
D0DataType,
EDataType,
ALayout,
BLayout,
D0Layout,
ELayout>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? DefaultStrideA : StrideA,
(StrideB < 0) ? DefaultStrideB : StrideB,
(StrideD0 < 0) ? DefaultStrideD0 : StrideD0,
(StrideE < 0) ? DefaultStrideE : StrideE);
return pass ? 0 : 1;
};
if(data_type == MatrixDataType::F16_INT8_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN)
{
return profile(F16{}, INT8{}, F32{}, F16{}, F16{}, Row{}, Row{}, Row{}, Row{});
}
else if(data_type == MatrixDataType::BF16_INT8_BF16_BF16 && layout == MatrixLayout::MK_KN_MN_MN)
{
return profile(BF16{}, INT8{}, F32{}, BF16{}, BF16{}, Row{}, Row{}, Row{}, Row{});
}
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
}
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_add_silu);
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_grouped_gemm_fixed_nk_impl.hpp"
#include "profiler_operation_registry.hpp"
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
};
enum struct GemmDataType
{
BF16_I8_BF16, // 0
F16_F16_F16, // 1
F16_F8_F16, // 2
F16_I8_F16, // 3
};
#define OP_NAME "grouped_gemm_fixed_nk"
#define OP_DESC "Grouped GEMM Fixed NK"
namespace {
std::vector<int> argToIntArray(char* input)
{
std::vector<int> out;
std::istringstream in(input);
std::string item;
while(std::getline(in, item, ','))
{
out.push_back(std::stoi(item));
}
return out;
}
int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
{
if(argc < 14)
{
std::cout
<< "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: bf16@int8; 1: fp16; 2: fp16@fp8; 3: fp16@int8)\n"
<< "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"
<< " 1: A[m, k] * B[n, k] = C[m, n];\n"
<< "arg4: verification (0: no; 1: yes)\n"
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0=n0, 1=yes)\n"
<< "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)\n"
<< "arg15: kbatch value (default 1)\n"
<< "optional:\n"
<< "arg16: number of warm-up cycles (default 1)\n"
<< "arg17: number of iterations (default 10)\n"
<< std::endl;
exit(1);
}
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);
const auto Ms = argToIntArray(argv[8]);
const auto Ns = argToIntArray(argv[9]);
const auto Ks = argToIntArray(argv[10]);
const auto StrideAs = argToIntArray(argv[11]);
const auto StrideBs = argToIntArray(argv[12]);
const auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1;
using F32 = float;
using F16 = ck::half_t;
using F8 = ck::f8_t;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
int n_warmup = 1;
int n_iter = 10;
if(argc == 17)
{
n_warmup = std::stoi(argv[16]);
n_iter = std::stoi(argv[17]);
}
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<BF16,
I8,
BF16,
F32,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<BF16,
I8,
BF16,
F32,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
#endif
#if defined(CK_ENABLE_FP16)
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
F16,
F16,
F32,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
F16,
F16,
F32,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
#endif
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8)
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
F8,
F16,
F32,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
F8,
F16,
F32,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
#endif
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_INT8)
else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
I8,
F16,
F32,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
I8,
F16,
F32,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
1,
n_warmup,
n_iter);
}
#endif
else
{
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
}
return 0;
}
} // anonymous namespace
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_gemm_fixed_nk);
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_permute_scale_impl.hpp"
#include "profiler_operation_registry.hpp"
namespace {
enum struct DataType
{
F32_F32, // 0
F16_F16 // 1
};
#define OP_NAME "permute_scale"
#define OP_DESC "Permute Scale"
static void print_helper_msg()
{
std::cout
// clang-format off
<< "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: Input fp32, Output fp32\n"
<< " 1: Input fp16, Output fp16\n"
<< "arg4: verification (0: no, 1: yes)\n"
<< "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0: no, 1: yes)\n"
<< "from arg8: tensor lengths\n"
<< " input strides\n"
<< " output strides\n" << std::endl;
// clang-format on
}
} // namespace
int profile_permute_scale(int argc, char* argv[])
{
constexpr int control_argc = 7;
const int dims_argc = argc - control_argc;
// Number of lenghs, input strides and outputs strides must be equal
if(argc < control_argc && dims_argc % 3 != 0)
{
print_helper_msg();
return 1;
}
const auto data_type = static_cast<DataType>(std::stoi(argv[2]));
const bool do_verification = std::stoi(argv[3]);
const int init_method = std::stoi(argv[4]);
const bool do_log = std::stoi(argv[5]);
const bool time_kernel = std::stoi(argv[6]);
const int num_dims = dims_argc / 3;
std::vector<ck::index_t> lengths(num_dims);
std::vector<ck::index_t> input_strides(num_dims);
std::vector<ck::index_t> output_strides(num_dims);
for(int i = 0; i < num_dims; i++)
{
lengths[i] = std::stoi(argv[control_argc + i]);
input_strides[i] = std::stoi(argv[control_argc + num_dims + i]);
output_strides[i] = std::stoi(argv[control_argc + 2 * num_dims + i]);
}
using F32 = float;
using F16 = ck::half_t;
constexpr auto I1 = ck::Number<1>{};
constexpr auto I2 = ck::Number<2>{};
constexpr auto I3 = ck::Number<3>{};
constexpr auto I4 = ck::Number<4>{};
constexpr auto I5 = ck::Number<5>{};
constexpr auto I6 = ck::Number<6>{};
auto profile = [&](auto num_dim_tmp, auto in_type, auto out_type) {
constexpr ck::index_t NDim = num_dim_tmp.value;
using InDataType = decltype(in_type);
using OutDataType = decltype(out_type);
bool pass =
ck::profiler::profile_permute_scale_impl<InDataType, OutDataType, NDim>(do_verification,
init_method,
do_log,
time_kernel,
lengths,
input_strides,
output_strides);
return pass ? 0 : 1;
};
if(num_dims == 1)
{
if(data_type == DataType::F32_F32)
{
return profile(I1, F32{}, F32{});
}
else if(data_type == DataType::F16_F16)
{
return profile(I1, F16{}, F16{});
}
}
else if(num_dims == 2)
{
if(data_type == DataType::F32_F32)
{
return profile(I2, F32{}, F32{});
}
else if(data_type == DataType::F16_F16)
{
return profile(I2, F16{}, F16{});
}
}
else if(num_dims == 3)
{
if(data_type == DataType::F32_F32)
{
return profile(I3, F32{}, F32{});
}
else if(data_type == DataType::F16_F16)
{
return profile(I3, F16{}, F16{});
}
}
else if(num_dims == 4)
{
if(data_type == DataType::F32_F32)
{
return profile(I4, F32{}, F32{});
}
else if(data_type == DataType::F16_F16)
{
return profile(I4, F16{}, F16{});
}
}
else if(num_dims == 5)
{
if(data_type == DataType::F32_F32)
{
return profile(I5, F32{}, F32{});
}
else if(data_type == DataType::F16_F16)
{
return profile(I5, F16{}, F16{});
}
}
else if(num_dims == 6)
{
if(data_type == DataType::F32_F32)
{
return profile(I6, F32{}, F32{});
}
else if(data_type == DataType::F16_F16)
{
return profile(I6, F16{}, F16{});
}
}
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_permute_scale);
......@@ -122,6 +122,7 @@ add_subdirectory(space_filling_curve)
add_subdirectory(conv_util)
add_subdirectory(reference_conv_fwd)
add_subdirectory(gemm)
add_subdirectory(gemm_add)
add_subdirectory(gemm_layernorm)
add_subdirectory(gemm_split_k)
add_subdirectory(gemm_reduce)
......
add_gtest_executable(test_gemm_add test_gemm_add.hpp)
target_link_libraries(test_gemm_add PRIVATE utility device_gemm_add_instance)
add_gtest_executable(test_gemm_add_relu test_gemm_add_relu.cpp)
target_link_libraries(test_gemm_add_relu PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance)
add_gtest_executable(test_gemm_add_silu test_gemm_add_silu.cpp)
target_link_libraries(test_gemm_add_silu PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance)
add_gtest_executable(test_gemm_add_fastgelu test_gemm_add_fastgelu.cpp)
target_link_libraries(test_gemm_add_fastgelu PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "profiler/profile_gemm_add_impl.hpp"
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using I8 = int8_t;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
template <typename Tuple>
class TestGemmAdd : public ::testing::Test
{
protected:
using ADataType = std::tuple_element_t<0, Tuple>;
using BDataType = std::tuple_element_t<1, Tuple>;
using AccDataType = std::tuple_element_t<2, Tuple>;
using D0DataType = std::tuple_element_t<3, Tuple>;
using EDataType = std::tuple_element_t<4, Tuple>;
using ALayout = std::tuple_element_t<5, Tuple>;
using BLayout = std::tuple_element_t<6, Tuple>;
using D0Layout = std::tuple_element_t<7, Tuple>;
using ELayout = std::tuple_element_t<8, Tuple>;
constexpr static auto ProfileGemmAddImpl = ck::profiler::profile_gemm_add_impl<ADataType,
BDataType,
AccDataType,
D0DataType,
EDataType,
ALayout,
BLayout,
D0Layout,
ELayout>;
virtual decltype(ProfileGemmAddImpl) GetImpl() { return ProfileGemmAddImpl; }
void Run()
{
std::vector<std::vector<ck::index_t>> lengths = {
{16, 32, 64}, {2048, 4096, 8192}, {2048, 1024, 16}};
bool all_success = true;
for(auto length : lengths)
{
int M = length[0];
int N = length[1];
int K = length[2];
int StrideA = ck::is_same_v<ALayout, Row> ? K : M;
int StrideB = ck::is_same_v<BLayout, Row> ? N : K;
int StrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
int StrideE = ck::is_same_v<ELayout, Row> ? N : M;
all_success =
all_success &
GetImpl()(true, 1, false, false, M, N, K, StrideA, StrideB, StrideD0, StrideE);
}
EXPECT_TRUE(all_success);
}
};
using KernelTypes = ::testing::Types<std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
std::tuple<BF16, I8, F32, BF16, BF16, Row, Row, Row, Row>>;
TYPED_TEST_SUITE(TestGemmAdd, KernelTypes);
TYPED_TEST(TestGemmAdd, Test_BF16FP16_INT8) { this->Run(); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "profiler/profile_gemm_add_fastgelu_impl.hpp"
#include "test_gemm_add.hpp"
template <typename Tuple>
class TestGemmAddFastgelu : public TestGemmAdd<Tuple>
{
private:
using ADataType = std::tuple_element_t<0, Tuple>;
using BDataType = std::tuple_element_t<1, Tuple>;
using AccDataType = std::tuple_element_t<2, Tuple>;
using D0DataType = std::tuple_element_t<3, Tuple>;
using EDataType = std::tuple_element_t<4, Tuple>;
using ALayout = std::tuple_element_t<5, Tuple>;
using BLayout = std::tuple_element_t<6, Tuple>;
using D0Layout = std::tuple_element_t<7, Tuple>;
using ELayout = std::tuple_element_t<8, Tuple>;
constexpr static auto ProfileGemmAddFastgeluImpl =
ck::profiler::profile_gemm_add_fastgelu_impl<ADataType,
BDataType,
AccDataType,
D0DataType,
EDataType,
ALayout,
BLayout,
D0Layout,
ELayout>;
decltype(ProfileGemmAddFastgeluImpl) GetImpl() override { return ProfileGemmAddFastgeluImpl; }
};
using KernelTypes = ::testing::Types<std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
std::tuple<BF16, I8, F32, BF16, BF16, Row, Row, Row, Row>>;
TYPED_TEST_SUITE(TestGemmAddFastgelu, KernelTypes);
TYPED_TEST(TestGemmAddFastgelu, Test_BF16FP16) { this->Run(); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "profiler/profile_gemm_add_relu_impl.hpp"
#include "test_gemm_add.hpp"
template <typename Tuple>
class TestGemmAddRelu : public TestGemmAdd<Tuple>
{
private:
using ADataType = std::tuple_element_t<0, Tuple>;
using BDataType = std::tuple_element_t<1, Tuple>;
using AccDataType = std::tuple_element_t<2, Tuple>;
using D0DataType = std::tuple_element_t<3, Tuple>;
using EDataType = std::tuple_element_t<4, Tuple>;
using ALayout = std::tuple_element_t<5, Tuple>;
using BLayout = std::tuple_element_t<6, Tuple>;
using D0Layout = std::tuple_element_t<7, Tuple>;
using ELayout = std::tuple_element_t<8, Tuple>;
constexpr static auto ProfileGemmAddReluImpl =
ck::profiler::profile_gemm_add_relu_impl<ADataType,
BDataType,
AccDataType,
D0DataType,
EDataType,
ALayout,
BLayout,
D0Layout,
ELayout>;
decltype(ProfileGemmAddReluImpl) GetImpl() override { return ProfileGemmAddReluImpl; }
};
using KernelTypes = ::testing::Types<std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
std::tuple<BF16, I8, F32, BF16, BF16, Row, Row, Row, Row>>;
TYPED_TEST_SUITE(TestGemmAddRelu, KernelTypes);
TYPED_TEST(TestGemmAddRelu, Test_BF16FP16_INT8) { this->Run(); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "profiler/profile_gemm_add_silu_impl.hpp"
#include "test_gemm_add.hpp"
template <typename Tuple>
class TestGemmAddSilu : public TestGemmAdd<Tuple>
{
private:
using ADataType = std::tuple_element_t<0, Tuple>;
using BDataType = std::tuple_element_t<1, Tuple>;
using AccDataType = std::tuple_element_t<2, Tuple>;
using D0DataType = std::tuple_element_t<3, Tuple>;
using EDataType = std::tuple_element_t<4, Tuple>;
using ALayout = std::tuple_element_t<5, Tuple>;
using BLayout = std::tuple_element_t<6, Tuple>;
using D0Layout = std::tuple_element_t<7, Tuple>;
using ELayout = std::tuple_element_t<8, Tuple>;
constexpr static auto ProfileGemmAddSiluImpl =
ck::profiler::profile_gemm_add_silu_impl<ADataType,
BDataType,
AccDataType,
D0DataType,
EDataType,
ALayout,
BLayout,
D0Layout,
ELayout>;
decltype(ProfileGemmAddSiluImpl) GetImpl() override { return ProfileGemmAddSiluImpl; }
};
using KernelTypes = ::testing::Types<std::tuple<F16, I8, F32, F16, F16, Row, Row, Row, Row>,
std::tuple<BF16, I8, F32, BF16, BF16, Row, Row, Row, Row>>;
TYPED_TEST_SUITE(TestGemmAddSilu, KernelTypes);
TYPED_TEST(TestGemmAddSilu, Test_BF16FP16_INT8) { this->Run(); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "test_permute_scale_impl.hpp"
#include "profiler/profile_permute_scale_impl.hpp"
using F16 = ck::half_t;
using F32 = float;
......@@ -15,15 +15,32 @@ class TestPermute : public ::testing::Test
using ADataType = std::tuple_element_t<0, Tuple>;
using BDataType = std::tuple_element_t<1, Tuple>;
void Run()
constexpr bool skip_case()
{
std::vector<std::vector<ck::index_t>> lengths = {
{4, 2, 1, 8}, {1, 1, 1, 1}, {16, 8, 32, 64}, {32, 64, 128, 128}};
#ifndef CK_ENABLE_FP16
if constexpr(ck::is_same_v<ADataType, F16> || ck::is_same_v<BDataType, F16>)
{
return true;
}
#endif
#ifndef CK_ENABLE_FP32
if constexpr(ck::is_same_v<ADataType, F32> || ck::is_same_v<BDataType, F32>)
{
return true;
}
#endif
return false;
}
for(auto length : lengths)
template <ck::index_t NDims>
void Run(std::vector<ck::index_t> lengths,
std::vector<ck::index_t> input_strides,
std::vector<ck::index_t> output_strides)
{
if(!skip_case())
{
bool success =
ck::test_permute_scale_impl<ADataType, BDataType, 4>(true, 2, false, false, length);
bool success = ck::profiler::profile_permute_scale_impl<ADataType, BDataType, NDims>(
true, 2, false, false, lengths, input_strides, output_strides);
EXPECT_TRUE(success);
}
}
......@@ -32,5 +49,52 @@ class TestPermute : public ::testing::Test
using KernelTypes = ::testing::Types<std::tuple<F16, F16>, std::tuple<F32, F32>>;
TYPED_TEST_SUITE(TestPermute, KernelTypes);
TYPED_TEST(TestPermute, Test_FP16) { this->Run(); }
TYPED_TEST(TestPermute, Test_FP32) { this->Run(); }
TYPED_TEST(TestPermute, Test1D)
{
constexpr ck::index_t NumDims = 1;
this->template Run<NumDims>({8}, {1}, {2});
this->template Run<NumDims>({8}, {2}, {1});
this->template Run<NumDims>({1}, {1}, {1});
}
TYPED_TEST(TestPermute, Test2D)
{
constexpr ck::index_t NumDims = 2;
this->template Run<NumDims>({8, 4}, {4, 1}, {1, 8});
this->template Run<NumDims>({8, 4}, {1, 8}, {4, 1});
this->template Run<NumDims>({1, 1}, {1, 1}, {1, 1});
}
TYPED_TEST(TestPermute, Test3D)
{
constexpr ck::index_t NumDims = 3;
this->template Run<NumDims>({2, 4, 4}, {16, 4, 1}, {1, 2, 8});
this->template Run<NumDims>({2, 4, 4}, {1, 2, 8}, {16, 4, 1});
this->template Run<NumDims>({1, 1, 1}, {1, 1, 1}, {1, 1, 1});
}
TYPED_TEST(TestPermute, Test4D)
{
constexpr ck::index_t NumDims = 4;
this->template Run<NumDims>({2, 4, 4, 4}, {64, 16, 4, 1}, {1, 2, 8, 32});
this->template Run<NumDims>({2, 4, 4, 4}, {1, 2, 8, 32}, {64, 16, 4, 1});
this->template Run<NumDims>({1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1});
}
TYPED_TEST(TestPermute, Test5D)
{
constexpr ck::index_t NumDims = 5;
this->template Run<NumDims>({2, 4, 4, 4, 4}, {256, 64, 16, 4, 1}, {1, 2, 8, 32, 128});
this->template Run<NumDims>({2, 4, 4, 4, 4}, {1, 2, 8, 32, 128}, {256, 64, 16, 4, 1});
this->template Run<NumDims>({1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1});
}
TYPED_TEST(TestPermute, Test6D)
{
constexpr ck::index_t NumDims = 6;
this->template Run<NumDims>(
{2, 4, 4, 4, 4, 4}, {1024, 256, 64, 16, 4, 1}, {1, 2, 8, 32, 128, 512});
this->template Run<NumDims>(
{2, 4, 4, 4, 4, 4}, {1, 2, 8, 32, 128, 512}, {1024, 256, 64, 16, 4, 1});
this->template Run<NumDims>({1, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 1, 1});
}
add_gtest_executable(test_layout test_layout.cpp)
target_link_libraries(test_layout PRIVATE utility)
add_gtest_executable(test_tensor test_tensor.cpp)
target_link_libraries(test_tensor PRIVATE utility)
add_gtest_executable(test_copy test_copy.cpp)
target_link_libraries(test_copy PRIVATE utility)
add_gtest_executable(test_partition test_partition.cpp)
target_link_libraries(test_partition PRIVATE utility)
add_custom_target(test_wrapper)
add_gtest_executable(test_wrapper_layout test_wrapper_layout.cpp)
target_link_libraries(test_wrapper_layout PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_layout)
add_gtest_executable(test_wrapper_tensor test_wrapper_tensor.cpp)
target_link_libraries(test_wrapper_tensor PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_tensor)
add_gtest_executable(test_wrapper_copy test_wrapper_copy.cpp)
target_link_libraries(test_wrapper_copy PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_copy)
add_gtest_executable(test_wrapper_partition test_wrapper_partition.cpp)
target_link_libraries(test_wrapper_partition PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_partition)
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR
GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR
GPU_TARGETS MATCHES "gfx942")
add_gtest_executable(test_gemm test_gemm.cpp)
target_link_libraries(test_gemm PRIVATE utility)
add_gtest_executable(test_wrapper_gemm test_wrapper_gemm.cpp)
target_link_libraries(test_wrapper_gemm PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_gemm)
endif()
......@@ -20,23 +20,25 @@
template <typename InputTensor,
typename OutputTensor,
typename BlockShape,
typename ThreadLayoutShape,
typename ThreadLayout,
bool UseOptimizedCopy>
__global__ void TestCopyDevice(const InputTensor input_tensor,
OutputTensor output_tensor,
const BlockShape tile_shape,
const ThreadLayoutShape thread_layout)
const ThreadLayout thread_layout)
{
__shared__ ck::index_t p_shared[ck::wrapper::size(tile_shape)];
const auto tensor_lds = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
p_shared, ck::wrapper::make_layout(tile_shape));
const auto block_idx = static_cast<ck::index_t>(blockIdx.x);
const auto block_idxs =
ck::make_tuple(static_cast<ck::index_t>(blockIdx.x), static_cast<ck::index_t>(blockIdx.y));
// Get local tiles for global memory
const auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idx);
const auto input_local_tile =
ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idxs);
const auto output_local_tile =
ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idx);
ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idxs);
// Get partition per thread
const auto input_local_partition =
......@@ -49,7 +51,7 @@ __global__ void TestCopyDevice(const InputTensor input_tensor,
// Allocate VGPR
auto tensor_vgpr =
ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, ck::index_t>(
layout(lds_local_partition));
ck::wrapper::make_layout(shape(lds_local_partition)));
// Perform copy
if constexpr(UseOptimizedCopy)
......@@ -99,11 +101,14 @@ void PerformCopyGlobalToGlobalViaLDS()
auto output_tensor_global = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<ck::index_t*>(out_buf.GetDeviceBuffer()), layout);
const auto thread_layout = ck::make_tuple(ck::Number<1>{}, ck::Number<32>{});
const auto tile_shape = ck::make_tuple(ck::Number<4>{}, ck::Number<64>{});
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<1>{}, ck::Number<32>{}));
const auto tile_shape = ck::make_tuple(ck::Number<4>{}, ck::Number<64>{});
const ck::index_t grid_size = ck::math::integer_divide_ceil(
ck::wrapper::size(input_tensor_global), ck::wrapper::size(tile_shape));
const ck::index_t grid_size_x = ck::math::integer_divide_ceil(
ck::wrapper::size<0>(input_tensor_global), ck::wrapper::size<0>(tile_shape));
const ck::index_t grid_size_y = ck::math::integer_divide_ceil(
ck::wrapper::size<1>(input_tensor_global), ck::wrapper::size<1>(tile_shape));
const auto kernel = TestCopyDevice<decltype(input_tensor_global),
decltype(output_tensor_global),
......@@ -112,7 +117,7 @@ void PerformCopyGlobalToGlobalViaLDS()
UseOptimizedCopy>;
launch_and_time_kernel(StreamConfig{},
kernel,
dim3(grid_size),
dim3(grid_size_x, grid_size_y, 1),
dim3(ck::wrapper::size(thread_layout)),
0,
input_tensor_global,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <numeric>
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <vector>
#include <gtest/gtest.h>
#include "ck/library/utility/host_tensor.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/wrapper/layout.hpp"
#include "ck/wrapper/tensor.hpp"
#include "ck/wrapper/operations/copy.hpp"
#include "ck/wrapper/operations/gemm.hpp"
#include "ck/wrapper/utils/kernel_utils.hpp"
template <typename DataType>
void CheckResult(const std::vector<DataType>& a_data,
const std::vector<DataType>& b_data,
std::vector<DataType>& c_m_n_device_result,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K)
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<DataType, DataType, DataType, float, PassThrough, PassThrough, PassThrough>;
Tensor<DataType> a_m_k(HostTensorDescriptor({M, K}));
Tensor<DataType> b_k_n(HostTensorDescriptor({K, N}, {1, K}));
Tensor<DataType> c_m_n_host_result(HostTensorDescriptor({M, N}));
a_m_k.mData = a_data;
b_k_n.mData = b_data;
auto ref_op = ReferenceGemmInstance{};
auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument = ref_op.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
EXPECT_TRUE(ck::utils::check_err(c_m_n_device_result, c_m_n_host_result.mData));
}
template <bool DoPad, typename Layout, typename PaddingDims>
__device__ auto ApplyPadding(const Layout& layout, const PaddingDims& padding_dims)
{
if constexpr(DoPad)
{
return ck::wrapper::pad(layout, padding_dims);
}
else
{
return layout;
}
}
template <typename DataType,
typename GemmTraits,
ck::index_t scalar_per_vector,
typename BlockShape,
typename ThreadLayout,
bool DoPadding>
__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
const void* p_b,
void* p_c,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const BlockShape tile_shape,
const ThreadLayout thread_layout)
{
constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape);
constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape);
constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape);
constexpr auto K1 = GemmTraits::K1;
constexpr auto K0PerBlock = KPerBlock / K1;
const auto K0 = ck::math::integer_divide_ceil(K, K1);
const auto tile_shape_k0_m_n_k1 = ck::make_tuple(K0PerBlock, MPerBlock, NPerBlock, K1);
const auto a_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
const auto b_global_layout =
ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1));
const auto c_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1));
auto a_padded_global_layout =
ApplyPadding<DoPadding>(a_global_layout, ck::make_tuple(MPerBlock, KPerBlock));
auto b_padded_global_layout =
ApplyPadding<DoPadding>(b_global_layout, ck::make_tuple(NPerBlock, KPerBlock));
auto c_padded_global_layout =
ApplyPadding<DoPadding>(c_global_layout, ck::make_tuple(MPerBlock, NPerBlock));
// Reshape from M,K to K0,M,K1
const auto reshaped_dims_idxs =
ck::make_tuple(ck::Number<1>{}, ck::make_tuple(ck::Number<0>{}, ck::Number<2>{}));
auto a_padded_unmerged_global_layout =
ck::wrapper::unmerge<1>(a_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs);
auto b_padded_unmerged_global_layout =
ck::wrapper::unmerge<1>(b_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs);
auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_a), a_padded_unmerged_global_layout);
auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_b), b_padded_unmerged_global_layout);
auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<DataType*>(p_c), c_padded_global_layout);
// Add extra M and N
constexpr auto a_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(K0PerBlock, MPerBlock, K1),
ck::make_tuple((MPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{}));
constexpr auto b_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(K0PerBlock, NPerBlock, K1),
ck::make_tuple((NPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{}));
__shared__ DataType lds_a[ck::wrapper::size(a_tile_layout) + NPerBlock];
__shared__ DataType lds_b[ck::wrapper::size(b_tile_layout) + NPerBlock];
auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_a), a_tile_layout);
auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_b), b_tile_layout);
const auto block_idxs = ck::make_tuple(ck::wrapper::slice(),
static_cast<ck::index_t>(blockIdx.x),
static_cast<ck::index_t>(blockIdx.y),
ck::wrapper::slice());
using DimAccessOrder = ck::Tuple<ck::Number<1>, ck::Number<0>, ck::Number<2>>;
constexpr ck::index_t vector_dim = 2;
auto c_global_local_tile =
ck::wrapper::make_local_tile(c_global_tensor,
tile_shape_k0_m_n_k1,
block_idxs,
make_tuple(ck::wrapper::slice(K0PerBlock),
ck::Number<1>{},
ck::Number<1>{},
ck::wrapper::slice(K1)));
auto c_global_local_partition =
ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
decltype(a_tile_layout),
decltype(b_tile_layout),
ck::wrapper::size(thread_layout),
GemmTraits>(c_global_local_tile);
auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
decltype(a_tile_layout),
decltype(b_tile_layout),
ck::wrapper::size(thread_layout),
GemmTraits>();
ck::wrapper::clear(c_vgpr_reg);
auto a_lds_tensor_local_partition =
ck::wrapper::make_local_partition(a_lds_tensor, thread_layout, threadIdx.x);
auto b_lds_tensor_local_partition =
ck::wrapper::make_local_partition(b_lds_tensor, thread_layout, threadIdx.x);
auto make_global_partition = [&](auto tensor, auto projection, ck::index_t i) {
const auto k_slice =
ck::make_tuple(ck::wrapper::slice(i * K0PerBlock, (i + 1) * K0PerBlock),
ck::wrapper::slice(),
ck::wrapper::slice());
auto local_tile = ck::wrapper::make_local_tile(
tensor(k_slice), tile_shape_k0_m_n_k1, block_idxs, projection);
return ck::wrapper::make_local_partition(local_tile, thread_layout, threadIdx.x);
};
auto a_global_local_partition = make_global_partition(
a_global_tensor,
make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}),
0);
auto b_global_local_partition = make_global_partition(
b_global_tensor,
make_tuple(ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}),
0);
// (row-major vgpr layout)
auto a_vgpr_tensor =
ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, DataType>(
ck::wrapper::make_layout(
shape(a_global_local_partition),
ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) *
ck::wrapper::size<2>(a_global_local_partition),
ck::wrapper::size<2>(a_global_local_partition),
ck::Number<1>{})));
auto b_vgpr_tensor =
ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, DataType>(
ck::wrapper::make_layout(
shape(b_global_local_partition),
ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) *
ck::wrapper::size<2>(a_global_local_partition),
ck::wrapper::size<2>(a_global_local_partition),
ck::Number<1>{})));
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(a_global_local_partition,
a_vgpr_tensor);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(b_global_local_partition,
b_vgpr_tensor);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(a_vgpr_tensor,
a_lds_tensor_local_partition);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(b_vgpr_tensor,
b_lds_tensor_local_partition);
const ck::index_t num_loop =
__builtin_amdgcn_readfirstlane(ck::math::integer_divide_ceil(K, KPerBlock));
if(num_loop > 1)
{
ck::index_t i = 0;
do
{
auto a_global_local_partition_i = make_global_partition(
a_global_tensor,
make_tuple(
ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}),
i + 1);
auto b_global_local_partition_i = make_global_partition(
b_global_tensor,
make_tuple(
ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}),
i + 1);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
a_global_local_partition_i, a_vgpr_tensor);
ck::block_sync_lds();
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
b_global_local_partition_i, b_vgpr_tensor);
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
ck::block_sync_lds();
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
a_vgpr_tensor, a_lds_tensor_local_partition);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
b_vgpr_tensor, b_lds_tensor_local_partition);
++i;
} while(i < (num_loop - 1));
}
ck::block_sync_lds();
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
}
template <typename DataType,
typename GemmTraits,
ck::index_t scalar_per_vector,
bool DoPadding,
typename BlockShape,
typename ThreadLayout>
void PerformGemm(const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const BlockShape& tile_shape,
const ThreadLayout& thread_layout)
{
// Global memory buffers
DeviceMem a_mem(M * K * sizeof(DataType));
DeviceMem b_mem(K * N * sizeof(DataType));
DeviceMem c_mem(M * N * sizeof(DataType));
std::vector<DataType> a_data(M * K);
std::vector<DataType> b_data(K * N);
ck::utils::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(a_data);
ck::utils::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(b_data);
a_mem.ToDevice(a_data.data());
b_mem.ToDevice(b_data.data());
c_mem.SetZero();
const ck::index_t grid_size_x =
ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape));
const ck::index_t grid_size_y =
ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape));
const auto kernel =
DeviceGemm<DataType, GemmTraits, scalar_per_vector, BlockShape, ThreadLayout, DoPadding>;
const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
kernel,
dim3(grid_size_x, grid_size_y, 1),
dim3(ck::wrapper::size(thread_layout)),
0,
a_mem.GetDeviceBuffer(),
b_mem.GetDeviceBuffer(),
c_mem.GetDeviceBuffer(),
M,
N,
K,
tile_shape,
thread_layout);
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << std::endl;
std::vector<DataType> c_data(M * N);
c_mem.FromDevice(c_data.data());
CheckResult<DataType>(a_data, b_data, c_data, M, N, K);
}
TEST(TestGemm, Float)
{
using DataType = float;
// (dim1, dim2, dim0 thread layout)
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<16>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 4, false>(
512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 1, true>(
129, 129, 67, tile_shape, thread_layout);
}
TEST(TestGemm, Int8)
{
using DataType = int8_t;
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
PerformGemm<DataType,
ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1,
16,
false>(512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1, 1, true>(
129, 129, 67, tile_shape, thread_layout);
}
TEST(TestGemm, Half)
{
using DataType = ck::half_t;
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<32>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 8, false>(
512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 1, true>(
129, 129, 67, tile_shape, thread_layout);
}
TEST(TestGemm, Float_2x4_4x2_XdlPerWave)
{
using DataType = float;
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<16>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1, 4, false>(
512, 512, 128, tile_shape, thread_layout);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
......
......@@ -29,8 +29,11 @@ TEST(TestPartition, LocalPartition)
const auto tensor =
ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Generic>(data.data(), layout);
const auto thread_steps = ck::make_tuple(ck::Number<1>{}, ck::Number<8>{}, ck::Number<1>{});
const auto thread_layout = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}, ck::Number<1>{});
const auto thread_steps = ck::make_tuple(ck::Number<1>{}, ck::Number<8>{}, ck::Number<1>{});
// row-major thread layout
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}, ck::Number<1>{}));
// 3d partition on 2d shape (calculate partition on 3d thread layout, and then skip first dim)
const auto thread_projection =
ck::make_tuple(ck::wrapper::slice(4), ck::Number<1>{}, ck::Number<1>{});
......@@ -70,29 +73,37 @@ TEST(TestPartition, LocalTile)
ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}, ck::Number<2>{}, ck::Number<2>{});
const auto block_projection =
ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(2));
constexpr ck::index_t projection_block_dim = ck::Number<2>{};
const auto num_blocks =
const auto grid_shape =
ck::make_tuple(ck::wrapper::size<0>(shape) / ck::wrapper::size<0>(block_shape),
ck::wrapper::size<1>(shape) / ck::wrapper::size<1>(block_shape),
ck::wrapper::size<2>(shape) / ck::wrapper::size<2>(block_shape));
std::vector<ck::index_t> block_idxs(ck::wrapper::size(num_blocks));
std::iota(block_idxs.begin(), block_idxs.end(), 0);
std::vector<ck::Tuple<ck::index_t, ck::index_t, ck::index_t, ck::index_t>> block_idxs;
for(int i = 0; i < ck::wrapper::size<0>(grid_shape); i++)
{
for(int j = 0; j < ck::wrapper::size<1>(grid_shape); j++)
{
for(int k = 0; k < ck::wrapper::size<2>(grid_shape); k++)
{
block_idxs.emplace_back(i, j, k, 0);
}
}
}
for(auto block_idx : block_idxs)
{
constexpr ck::index_t projection_block_dim = ck::Number<2>{};
const auto packed_tile =
ck::wrapper::make_local_tile(tensor, block_shape, block_idx, block_projection);
const auto expected_tile_size = ck::wrapper::size(block_shape) / projection_block_dim;
auto expected_tile_first_val = (block_idx % ck::wrapper::size<2>(num_blocks)) *
auto expected_tile_first_val = ck::wrapper::size<2>(block_idx) *
ck::wrapper::size<2>(block_shape) *
ck::wrapper::size<2>(strides);
block_idx /= ck::wrapper::size<2>(num_blocks);
expected_tile_first_val += (block_idx % ck::wrapper::size<1>(num_blocks)) *
expected_tile_first_val += ck::wrapper::size<1>(block_idx) *
ck::wrapper::size<1>(block_shape) *
ck::wrapper::size<1>(strides);
block_idx /= ck::wrapper::size<1>(num_blocks);
expected_tile_first_val += (block_idx % ck::wrapper::size<0>(num_blocks)) *
expected_tile_first_val += ck::wrapper::size<0>(block_idx) *
ck::wrapper::size<0>(block_shape) *
ck::wrapper::size<0>(strides);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment