Unverified commit 9697ad4e authored by zjing14, committed by GitHub

Merge branch 'develop' into add_int8_wmma_example_instance

parents 1c97db8a 582e31e8
@@ -64,7 +64,7 @@ int profile_groupnorm(int argc, char* argv[])
ck::DataTypeEnum data_type = ck::DataTypeEnum::Half;
bool do_verification = false;
int init_method = 0;
bool do_log = false;
bool time_kernel = true;
std::vector<index_t> length = {64, 16, 16, 32, 40};
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include <unordered_map>
#include "profiler/data_type_enum.hpp"
#include "profiler/profile_pool3d_fwd_impl.hpp"
#include "profiler_operation_registry.hpp"
using ck::index_t;
struct maxPoolFwdArgParser
{
std::unordered_map<std::string, std::vector<int>> long_opts = {
{"length", {}}, {"wsize", {}}, {"wstride", {}}, {"pad1", {}}, {"pad2", {}}};
bool parse_opt(int argc, char* argv[], const std::string& key, int i)
{
if(std::string("--") + key == argv[i])
{
int pos = i;
while(++i < argc && argv[i][0] != '-') {}
int end = i;
for(int j = pos + 1; j < end; j++)
{
long_opts[key].push_back(std::stoi(argv[j]));
}
return true;
}
return false;
}
void operator()(int argc, char* argv[])
{
for(auto& kv : long_opts)
{
for(int i = 1; i < argc; i++)
{
if(parse_opt(argc, argv, kv.first, i))
break;
}
}
}
};
void print_help_max_pool3d_fwd()
{
std::cout << "arg1: data type (0: fp16; 1: fp32)\n"
<< "arg2: verification (0: no; 1: yes)\n"
<< "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n"
<< "arg4: print tensor value (0: no; 1: yes)\n"
<< "arg5: time kernel (0=no, 1=yes)\n"
<< "arg6: return index (0=no, 1=yes)\n"
<< "--length: input tensor length for NCDHW(e.g, --length 2 32 30 30 30) \n"
<< "--wsize: window size for ZYX (e.g, --wsize 2 2 2) \n"
<< "--wstride: window stride for DHW (e.g, --wstride 2 2 2) \n"
<< "--pad1: left side of padding in DHW (e.g, --pad1 1 1 1) \n"
<< "--pad2: right side of padding in DHW (e.g, --pad2 1 1 1) \n"
<< "eg: ckProfiler max_pool3d_fwd 0 1 2 0 1 0 --length 2 32 30 30 30 --wsize 2 2 2 "
"--wstride 2 2 2 --pad1 1 1 1 --pad2 1 1 1"
<< std::endl;
}
int profile_max_pool3d_fwd(int argc, char* argv[])
{
ck::DataTypeEnum data_type = ck::DataTypeEnum::Half;
bool do_verification = true;
int init_method = 0;
bool do_log = false;
bool time_kernel = true;
bool return_index = false;
std::vector<index_t> in_length = {2, 32, 30, 30, 30};
std::vector<index_t> wsize = {2, 2, 2};
std::vector<index_t> wstride = {2, 2, 2};
std::vector<index_t> pad1 = {1, 1, 1};
std::vector<index_t> pad2 = {1, 1, 1};
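// argc == 2 runs with the defaults above; argc == 30 matches the full example
// invocation (program + op name + 6 positional args + 5 long options with 17 values).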
if(argc != 2 && argc != 30)
{
print_help_max_pool3d_fwd();
return 0;
}
else if(argc == 30)
{
data_type = static_cast<ck::DataTypeEnum>(std::stoi(argv[2]));
do_verification = std::stoi(argv[3]);
init_method = std::stoi(argv[4]);
do_log = std::stoi(argv[5]);
time_kernel = std::stoi(argv[6]);
return_index = std::stoi(argv[7]);
// parse the long options
maxPoolFwdArgParser arg_parser;
arg_parser(argc, argv);
in_length = arg_parser.long_opts["length"];
wsize = arg_parser.long_opts["wsize"];
wstride = arg_parser.long_opts["wstride"];
pad1 = arg_parser.long_opts["pad1"];
pad2 = arg_parser.long_opts["pad2"];
}
using F16 = ck::half_t;
using F32 = float;
using I32 = int32_t;
constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
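// The last template argument of profile_pool3d_fwd_impl selects whether the kernel
// also returns the argmax indices; the two branches per data type differ only in it.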
if(data_type == ck::DataTypeEnum::Half)
{
if(return_index)
ck::profiler::profile_pool3d_fwd_impl<F16, F16, F16, I32, ReduceOpId, false, true>(
do_verification,
init_method,
do_log,
time_kernel,
in_length,
wsize,
wstride,
pad1,
pad2);
else
ck::profiler::profile_pool3d_fwd_impl<F16, F16, F16, I32, ReduceOpId, false, false>(
do_verification,
init_method,
do_log,
time_kernel,
in_length,
wsize,
wstride,
pad1,
pad2);
}
else if(data_type == ck::DataTypeEnum::Float)
{
if(return_index)
ck::profiler::profile_pool3d_fwd_impl<F32, F32, F32, I32, ReduceOpId, false, true>(
do_verification,
init_method,
do_log,
time_kernel,
in_length,
wsize,
wstride,
pad1,
pad2);
else
ck::profiler::profile_pool3d_fwd_impl<F32, F32, F32, I32, ReduceOpId, false, false>(
do_verification,
init_method,
do_log,
time_kernel,
in_length,
wsize,
wstride,
pad1,
pad2);
}
else
{
throw std::runtime_error("not implemented yet");
}
return 0;
}
REGISTER_PROFILER_OPERATION("max_pool3d_fwd", "max_pool3d fwd", profile_max_pool3d_fwd);
@@ -57,6 +57,7 @@ add_subdirectory(data_type)
add_subdirectory(elementwise_normalization)
add_subdirectory(batchnorm)
add_subdirectory(contraction)
add_subdirectory(pool_fwd)
if(GPU_TARGETS MATCHES "gfx1100")
add_subdirectory(wmma_op)
endif()
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
    add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp)
    target_link_libraries(test_batched_gemm_fp16 PRIVATE utility)
    target_link_libraries(test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance)
    add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp)
    target_link_libraries(test_batched_gemm_fp32 PRIVATE utility)
    target_link_libraries(test_batched_gemm_fp32 PRIVATE device_batched_gemm_instance)
    add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp)
    target_link_libraries(test_batched_gemm_bf16 PRIVATE utility)
    target_link_libraries(test_batched_gemm_bf16 PRIVATE device_batched_gemm_instance)
    add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp)
    target_link_libraries(test_batched_gemm_int8 PRIVATE utility)
    target_link_libraries(test_batched_gemm_int8 PRIVATE device_batched_gemm_instance)
endif()
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
    add_custom_target(test_batched_gemm_gemm)
    add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp)
    target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance)
    add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16)
endif()
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
    add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp)
    target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility)
    target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance)
endif()
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
    add_custom_target(test_batched_gemm_softmax_gemm)
    add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp)
    target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance)
    add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16)
endif()
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
    add_custom_target(test_batched_gemm_softmax_gemm_permute)
    add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp)
    add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp)
    target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
    target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
    add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16)
    add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16)
    add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp)
    add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp)
    target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
    target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
    add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16)
    add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16)
endif()
add_gtest_executable(test_contraction test_contraction.cpp)
target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance)
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
    add_gtest_executable(test_contraction_interface test_contraction_interface.cpp)
    target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance)
endif()
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
    add_gtest_executable(test_convnd_bwd_data convnd_bwd_data.cpp)
    target_link_libraries(test_convnd_bwd_data PRIVATE utility device_conv1d_bwd_data_instance device_conv2d_bwd_data_instance device_conv3d_bwd_data_instance)
endif()
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
    add_gtest_executable(test_convnd_fwd convnd_fwd.cpp)
    target_link_libraries(test_convnd_fwd PRIVATE utility device_conv2d_fwd_instance)
endif()
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
    add_custom_target(test_gemm_layernorm)
    add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp)
    target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance)
    add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16)
endif()
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
    add_gtest_executable(test_gemm_splitk test_gemm_splitk.cpp)
    target_link_libraries(test_gemm_splitk PRIVATE utility device_gemm_splitk_instance)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/host_gemm.hpp"
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
KM_NK_MN, // 3
};
template <typename T>
static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
{
    // Element-wise comparison against a fixed absolute tolerance.
    constexpr double max_diff = 1e-6;
    for(std::size_t i = 0; i < ref.mData.size(); ++i)
    {
        const double diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
        if(max_diff < diff)
        {
            return false;
        }
    }
    return true;
}
struct gemmArgs
{
GemmMatrixLayout layout;
int M;
int N;
int K;
int StrideA;
int StrideB;
int StrideC;
int KBatch;
};
int test_gemm(const gemmArgs& args)
{
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
bool a_row_major, b_row_major, c_row_major;
switch(args.layout)
{
case GemmMatrixLayout::MK_KN_MN:
a_row_major = true;
b_row_major = true;
c_row_major = true;
break;
case GemmMatrixLayout::MK_NK_MN:
a_row_major = true;
b_row_major = false;
c_row_major = true;
break;
case GemmMatrixLayout::KM_KN_MN:
a_row_major = false;
b_row_major = true;
c_row_major = true;
break;
case GemmMatrixLayout::KM_NK_MN:
a_row_major = false;
b_row_major = false;
c_row_major = true;
break;
default: printf("unsupported layout\n"); return 1;
}
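// Row-major tensors use strides {stride, 1}; column-major tensors use {1, stride}.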
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, bool row_major) {
using namespace ck::literals;
if(row_major)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<float> a_m_k(f_host_tensor_descriptor(args.M, args.K, args.StrideA, a_row_major));
Tensor<float> b_k_n(f_host_tensor_descriptor(args.K, args.N, args.StrideB, b_row_major));
Tensor<float> c_m_n_host_result(
f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));
Tensor<float> c_m_n_device_result(
f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));
// init data
std::size_t num_thread = 1;
a_m_k.GenerateTensorValue(GeneratorTensor_2<float>{-5, 5}, num_thread);
b_k_n.GenerateTensorValue(GeneratorTensor_2<float>{-5, 5}, num_thread);
// set zero to c_device_buf
c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<float>{}, num_thread);
host_gemm_mk_kn_mn(a_m_k,
b_k_n,
c_m_n_host_result,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{});
DeviceMem a_device_buf(sizeof(float) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(float) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(float) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
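// Enumerate all registered DeviceGemmSplitK instances for the given layouts and run
// every instance that supports the argument; stop at the first result mismatch.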
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
bool success = false;
using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK<decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
float,
float,
float,
PassThrough,
PassThrough,
PassThrough>;
const auto gemm_ptrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(auto& gemm_ptr : gemm_ptrs)
{
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(static_cast<float*>(a_device_buf.GetDeviceBuffer()),
static_cast<float*>(b_device_buf.GetDeviceBuffer()),
static_cast<float*>(c_device_buf.GetDeviceBuffer()),
args.M,
args.N,
args.K,
args.StrideA,
args.StrideB,
args.StrideC,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
args.KBatch);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get());
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
if(!check_out(c_m_n_host_result, c_m_n_device_result))
{
success = false;
break;
}
success = true;
}
}
return success;
};
bool success = false;
if(args.layout == GemmMatrixLayout::MK_KN_MN)
{
success = test(Row{}, Row{}, Row{});
}
else if(args.layout == GemmMatrixLayout::MK_NK_MN)
{
success = test(Row{}, Col{}, Row{});
}
else if(args.layout == GemmMatrixLayout::KM_KN_MN)
{
success = test(Col{}, Row{}, Row{});
}
else
{
success = test(Col{}, Col{}, Row{});
}
auto error_code = 0;
if(success)
{
std::cout << "test split k : Pass" << std::endl;
}
else
{
std::cout << "test split k: Fail " << std::endl;
error_code = -1; // test needs to report failure
}
return error_code;
}
int main(int argc, char* argv[])
{
std::vector<gemmArgs> test_cases;
if(argc == 1)
{
test_cases = {{GemmMatrixLayout::MK_KN_MN, 1024, 1024, 1024, 1024, 1024, 1024, 2},
{GemmMatrixLayout::MK_KN_MN, 1024, 1024, 1024, 1024, 1024, 1024, 8}};
}
else if(argc == 9)
{
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
const int M = std::stoi(argv[2]);
const int N = std::stoi(argv[3]);
const int K = std::stoi(argv[4]);
const int StrideA = std::stoi(argv[5]);
const int StrideB = std::stoi(argv[6]);
const int StrideC = std::stoi(argv[7]);
const int KBatch = std::stoi(argv[8]);
test_cases = {{layout, M, N, K, StrideA, StrideB, StrideC, KBatch}};
}
else
{
printf("arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC KBatch\n");
return -1;
}
bool error = false;
for(const auto& test_case : test_cases)
{
    error |= (test_gemm(test_case) != 0);
}
return error ? 1 : 0;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include "gtest/gtest.h"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "test_gemm_splitk_util.hpp"
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
namespace {
template <typename X, typename Y>
struct tuple_concat;
template <typename... Xs, typename... Ys>
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
{
using type = std::tuple<Xs..., Ys...>;
};
} // namespace
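// Each fixture prepends its (ALayout, BLayout) pair to the data-type tuple, giving the
// TestGemmSplitK base the full (ALayout, BLayout, ADataType, BDataType, CDataType) tuple.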
template <typename Tuple>
class TestGemmSplitK_MK_KN
: public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmSplitK_MK_NK
: public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmSplitK_KM_KN
: public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmSplitK_KM_NK
: public ck::test::TestGemmSplitK<typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
{
};
// clang-format off
using KernelTypes = ::testing::Types<
// ADataType, BDataType, CDataType
std::tuple< F16, F16, F16>,
std::tuple< F32, F32, F32>
>;
// clang-format on
TYPED_TEST_SUITE(TestGemmSplitK_MK_KN, KernelTypes);
TYPED_TEST_SUITE(TestGemmSplitK_MK_NK, KernelTypes);
TYPED_TEST_SUITE(TestGemmSplitK_KM_KN, KernelTypes);
TYPED_TEST_SUITE(TestGemmSplitK_KM_NK, KernelTypes);
#include "test_gemm_splitk_ut_cases.inc"
#pragma once
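// Shared TYPED_TEST cases for every layout fixture; each test sweeps the listed M
// values, and Run() internally sweeps the fixture's k_batches_ split-K factors.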
TYPED_TEST(TestGemmSplitK_MK_KN, SmallM)
{
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideA = K;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_MK_NK, SmallM)
{
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideA = K;
constexpr int StrideB = K;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_KM_KN, SmallM)
{
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_KM_NK, SmallM)
{
std::vector<int> Ms{0, 1, 2, 3, 4, 5, 6};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideB = K;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_MK_KN, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideA = K;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_MK_NK, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideA = K;
constexpr int StrideB = K;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_KM_KN, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_KM_NK, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 512;
constexpr int K = 320;
constexpr int StrideB = K;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_MK_KN, PaddK)
{
std::vector<int> Ms{127};
constexpr int N = 512;
constexpr int K = 437;
constexpr int StrideA = K;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_MK_NK, PaddK)
{
std::vector<int> Ms{127};
constexpr int N = 512;
constexpr int K = 437;
constexpr int StrideA = K;
constexpr int StrideB = K;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_KM_KN, PaddK)
{
std::vector<int> Ms{127};
constexpr int N = 512;
constexpr int K = 437;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_KM_NK, PaddK)
{
std::vector<int> Ms{127};
constexpr int N = 512;
constexpr int K = 437;
constexpr int StrideB = K;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_MK_KN, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 512;
constexpr int K = 512;
constexpr int StrideA = K;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_MK_NK, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 512;
constexpr int K = 512;
constexpr int StrideA = K;
constexpr int StrideB = K;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_KM_KN, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 512;
constexpr int K = 512;
constexpr int StrideB = N;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC);
}
TYPED_TEST(TestGemmSplitK_KM_NK, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 512;
constexpr int K = 512;
constexpr int StrideB = K;
constexpr int StrideC = N;
for(int M : Ms)
this->Run(M, N, K, M, StrideB, StrideC);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <sstream>
#include <tuple>
#include <vector>
#include <gtest/gtest.h>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "include/ck/utility/data_type.hpp"
#include "profiler/profile_gemm_splitk_impl.hpp"
namespace ck {
namespace test {
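// Tuple is expected to be (ALayout, BLayout, ADataType, BDataType, CDataType);
// the C matrix layout is fixed to row-major.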
template <typename Tuple>
class TestGemmSplitK : public testing::Test
{
using Row = ck::tensor_layout::gemm::RowMajor;
using F32 = float;
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = Row;
using ADataType = std::tuple_element_t<2, Tuple>;
using BDataType = std::tuple_element_t<3, Tuple>;
using CDataType = std::tuple_element_t<4, Tuple>;
public:
static constexpr bool verify_ = true;
static constexpr int init_method_ = 1; // integer value initialization (2 would be decimal)
static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance
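// Split-K factors swept by Run() for every problem size.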
std::vector<int> k_batches_;
void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; }
void Run(const int M,
const int N,
const int K,
const int StrideA,
const int StrideB,
const int StrideC)
{
for(auto kb : k_batches_)
{
RunSingle(M, N, K, StrideA, StrideB, StrideC, kb);
}
}
void RunSingle(const int M,
const int N,
const int K,
const int StrideA,
const int StrideB,
const int StrideC,
int kbatch = 1)
{
bool pass = ck::profiler::profile_gemm_splitk_impl<ADataType,
BDataType,
F32,
CDataType,
ALayout,
BLayout,
CLayout>(
verify_, init_method_, log_, bench_, M, N, K, StrideA, StrideB, StrideC, kbatch);
EXPECT_TRUE(pass);
}
};
} // namespace test
} // namespace ck
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
    add_gtest_executable(test_grouped_convnd_bwd_weight grouped_convnd_bwd_weight.cpp)
    target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
endif()
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
    add_custom_target(test_grouped_gemm)
    add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp)
    add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface.cpp)
    target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance)
    target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance)
    add_dependencies(test_grouped_gemm test_grouped_gemm_splitk test_grouped_gemm_interface)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <random>
#include "profiler/profile_grouped_gemm_impl.hpp"
namespace {
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <typename ALayout, typename BLayout, typename CLayout>
bool TestGroupedGemm()
{
std::mt19937 gen(19391);
std::uniform_int_distribution<> distrib(1, 10);
int group_count = distrib(gen);
// GEMM shape
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<const void*> p_a, p_b;
std::vector<void*> p_c;
std::vector<int> Ms, Ns, Ks, StrideAs, StrideBs, StrideCs;
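// Random shapes per group; the leading dimension matches the layout: row-major
// operands stride over their row length (K, N, N), column-major over (M, K, M).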
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 256 * distrib(gen));
Ns.push_back(256 + 256 * distrib(gen));
Ks.push_back(128 + 128 * distrib(gen));
StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]);
StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]);
StrideCs.push_back(std::is_same<Row, CLayout>::value ? Ns[i] : Ms[i]);
}
return ck::profiler::profile_grouped_gemm_impl<ADataType,
BDataType,
CDataType,
AccDataType,
ALayout,
BLayout,
CLayout>(
true /*do_verification*/, 1 /*init_method*/, false /*do_log*/, 1 /*time_kernel*/, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs);
}
} // anonymous namespace
int main()
{
bool res = true;
res = res && TestGroupedGemm<Row, Row, Row>();
res = res && TestGroupedGemm<Row, Col, Row>();
res = res && TestGroupedGemm<Col, Row, Row>();
res = res && TestGroupedGemm<Col, Col, Row>();
std::cout << "TestGroupedGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res ? 0 : 1;
}