Commit aa5d4037 authored by Adam Osewski

Add example for new kernel.

parent 2fb0521d
@@ -20,11 +20,13 @@ add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
 add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
 add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
 add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
+add_example_executable(example_gemm_xdl_direct_c_write_out_fp16 gemm_xdl_direct_c_write_out_fp16.cpp)
 
 add_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
 add_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
 add_dependencies(example_gemm_xdl example_gemm_xdl_int8)
 add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
+add_dependencies(example_gemm_xdl example_gemm_xdl_direct_c_write_out_fp16)
 
 if(USE_BITINT_EXTENSION_INT4)
     add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp)
@@ -38,6 +38,7 @@ struct ExecutionConfig final
     bool do_verification = true;
     int init_method = 1;
     bool time_kernel = false;
+    bool do_log = false;
 };
 
 template <ck::index_t... Is>
@@ -55,33 +56,36 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
     {
         // use default case
     }
-    else if(argc == 4)
+    else if(argc == 5)
     {
         config.do_verification = std::stoi(argv[1]);
         config.init_method = std::stoi(argv[2]);
         config.time_kernel = std::stoi(argv[3]);
+        config.do_log = std::stoi(argv[4]);
     }
-    else if(argc == 10)
+    else if(argc == 11)
     {
         config.do_verification = std::stoi(argv[1]);
         config.init_method = std::stoi(argv[2]);
         config.time_kernel = std::stoi(argv[3]);
+        config.do_log = std::stoi(argv[4]);
 
-        problem_size.M = std::stoi(argv[4]);
-        problem_size.N = std::stoi(argv[5]);
-        problem_size.K = std::stoi(argv[6]);
-        problem_size.StrideA = std::stoi(argv[7]);
-        problem_size.StrideB = std::stoi(argv[8]);
-        problem_size.StrideC = std::stoi(argv[9]);
+        problem_size.M = std::stoi(argv[5]);
+        problem_size.N = std::stoi(argv[6]);
+        problem_size.K = std::stoi(argv[7]);
+        problem_size.StrideA = std::stoi(argv[8]);
+        problem_size.StrideB = std::stoi(argv[9]);
+        problem_size.StrideC = std::stoi(argv[10]);
     }
     else
     {
-        std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
-                  << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
-                  << std::endl
-                  << "arg3: time kernel (0=no, 1=yes)" << std::endl
-                  << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl;
+        std::cerr << "arg1: verification (0=no, 1=yes)\n"
+                  << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
+                  << "arg3: time kernel (0=no, 1=yes)\n"
+                  << "arg4: print tensor (0=no, 1=yes)\n"
+                  << "arg5 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC"
+                  << std::endl;
 
         return false;
     }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_direct_c_write_out.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = float;
using CDataType = ck::half_t;
using F16 = ck::half_t;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto LoopSchedDefault = ck::LoopScheduler::Default;
static constexpr auto GemmPipeline = ck::PipelineVersion::v1;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_DirectCWriteOut
// clang-format off
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| LoopScheduler| PipelineVersion|
// ######| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| | |
// ######| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, LoopSchedDefault, GemmPipeline>;
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, LoopSchedDefault, GemmPipeline>;
// clang-format on
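// The active instance is a small configuration: a 64-thread block computes a 32x32x32
// (MPerBlock x NPerBlock x KPerBlock) tile, with a single 32x32 XDL wave tile per wave in both
// M and N and 8-wide vectorized reads of A and B. The commented-out line above is the larger
// 256-thread, 256x128x32 tile configuration of the same kernel.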
// clang-format off
using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| PipelineVersion|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | |
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopSchedDefault, GemmPipeline>;
// clang-format on
using DeviceGemmInstance = DeviceGemmInstance;
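// DeviceGemmInstance1 (the DeviceGemm_Xdl_CShuffle variant above) is not referenced anywhere in
// this file; retargeting the DeviceGemmInstance alias on the previous line at it would run the
// same example through the CShuffle kernel for comparison.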
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
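The example itself only defines the data types, layouts, and the device instance; all of the host-side plumbing lives in the shared run_gemm_example.inc driver included above. As orientation, the sketch below shows the usual shape of that driver, assuming the standard CK DeviceGemm example interface (MakeInvoker, MakeArgument, IsSupportedArgument, Invoker::Run with a StreamConfig) and the type aliases and headers already pulled in by the example above. The helper launch_gemm_sketch and its parameter names are illustrative only, not taken from the commit.

// Illustrative sketch only: how a DeviceGemmInstance is typically driven by the example code.
template <typename DeviceOp>
float launch_gemm_sketch(DeviceOp& gemm,
                         const ADataType* p_a, // device pointer to A, M x K, row-major
                         const BDataType* p_b, // device pointer to B, K x N, column-major
                         CDataType* p_c,       // device pointer to C, M x N, row-major
                         ck::index_t M, ck::index_t N, ck::index_t K,
                         ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC,
                         bool time_kernel)
{
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(
        p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, AElementOp{}, BElementOp{}, CElementOp{});

    if(!gemm.IsSupportedArgument(argument))
    {
        std::cerr << gemm.GetTypeString() << " does not support this problem size" << std::endl;
        return 0.f;
    }

    // Returns the averaged kernel time in milliseconds when time_kernel is true.
    return invoker.Run(argument, StreamConfig{nullptr, time_kernel});
}

The do_verification, do_log, and init_method switches added by this commit all sit around this call inside run_gemm_example.inc, as the hunks below show.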
@@ -30,10 +30,31 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     switch(config.init_method)
     {
-    case 0: break;
+    case 0:
+        ck::utils::FillConstant<ADataType>{1.f}(a_m_k);
+        ck::utils::FillConstant<BDataType>{0.f}(b_k_n);
+
+        // for(ck::index_t m = 0; m < M; ++m)
+        // {
+        //     for(ck::index_t k = 0; k < K; ++k)
+        //     {
+        //         a_m_k(m, k) = (m * M + k) % 5;
+        //     }
+        // }
+
+        for(ck::index_t n = 0; n < N; ++n)
+        {
+            for(ck::index_t k = 0; k < K; ++k)
+            {
+                if(n == k)
+                    b_k_n(k, n) = n * 2;
+            }
+        }
+        break;
     case 1:
-        ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
-        ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
+        ck::utils::FillUniformDistributionIntegerValue<ADataType>{-1.f, 3.f}(a_m_k);
+        ck::utils::FillUniformDistributionIntegerValue<BDataType>{-1.f, 3.f}(b_k_n);
         break;
     default:
         ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
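A note on the new init_method 0 path: A is filled with ones and B is zero except on its diagonal, where b(n, n) = 2 * n, so c(m, n) = sum_k a(m, k) * b(k, n) = 2 * n for n < K and 0 otherwise. Every row of C should therefore read 0, 2, 4, ... (exactly so for moderate n, where 2 * n is representable in fp16), which makes the direct write-out easy to eyeball in the do_log dump. Below is a minimal host-side sketch of that expectation; check_init0_expectation is a hypothetical helper, not part of the commit.

#include <cstdint>
#include <vector>

// Expected result for init_method == 0: A = all ones (M x K), B = diag(0, 2, 4, ...) (K x N),
// so C = A * B has c(m, n) = 2 * n for n < K and 0 otherwise. Compared against the fp32 host
// reference, where these small integers are exact.
bool check_init0_expectation(const std::vector<float>& c_m_n, // row-major M x N host result
                             std::int64_t M, std::int64_t N, std::int64_t K)
{
    for(std::int64_t m = 0; m < M; ++m)
        for(std::int64_t n = 0; n < N; ++n)
        {
            const float expected = (n < K) ? 2.f * static_cast<float>(n) : 0.f;
            if(c_m_n[m * N + n] != expected)
                return false;
        }
    return true;
}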
@@ -65,6 +86,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     a_m_k_device_buf.ToDevice(a_m_k.mData.data());
     b_k_n_device_buf.ToDevice(b_k_n.mData.data());
+    c_m_n_device_buf.SetZero();
 #endif
 
     auto a_element_op = AElementOp{};
@@ -114,6 +136,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
               << gemm.GetTypeString() << std::endl;
 
+    bool result = true;
     if(config.do_verification)
     {
         auto ref_gemm = ReferenceGemmInstance{};
@@ -131,15 +154,25 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
         c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
 
-        return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
+        result = result && ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
 #else
         c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
 
-        return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
+        result = result && ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
 #endif
     }
 
-    return true;
+    if(config.do_log)
+    {
+        LogRangeAsType<float>(std::cout << "a:\n", a_m_k.mData, ",", 32) << std::endl;
+        LogRangeAsType<float>(std::cout << "b:\n", b_k_n.mData, ",", 32) << std::endl;
+        LogRangeAsType<float>(std::cout << "c_host:\n", c_m_n_host_result.mData, ",", 32)
+            << std::endl;
+        LogRangeAsType<float>(std::cout << "c_device:\n", c_m_n_device_result.mData, ",", 32)
+            << std::endl;
+    }
+
+    return result;
 }
 
 bool run_gemm_example(int argc, char* argv[])