"vscode:/vscode.git/clone" did not exist on "9e0cd251f3bb5d0dd489d20f92fe2e3bb436e588"
Unverified Commit 3ab20fd7 authored by Adam Osewski, committed by GitHub

GEMM batched/splitK/cgemm/grouped int4 examples (#383)



* Grouped GEMM int4.

* Formatting + fix K dimension for int8.

* Batched Gemm int4 example.

* CGEMM int4 example.

* Include .inc files in clang-format.

* SplitK int4 example

* Refactoring of performance measurement.

* Fix #ifdef statements.
Co-authored-by: Adam Osewski <aosewski@amd.com>
parent b73ae242
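All four new examples follow the same pattern: the host tensors hold ck::int4_t values, while the XDL kernels are instantiated with int8_t (the KernelADataType/KernelBDataType/KernelEDataType aliases in the files below), so data is widened element-wise before the upload and narrowed back for verification. The standalone sketch below illustrates that conversion with plain standard C++; the helper name and the std::vector types are illustrative only, not CK APIs.

// Illustrative sketch, not part of this commit: ck::int4_t (available only with
// the experimental _BitInt extension) stores one value per byte, which is what
// the static_assert(sizeof(ck::int4_t) == sizeof(int8_t)) checks below rely on.
// Widening to int8_t is therefore a plain element-wise copy with an identical
// buffer size, so the int8 kernel can consume the converted data directly.
#include <cstddef>
#include <cstdint>
#include <vector>

template <typename Int4> // stand-in for ck::int4_t
std::vector<int8_t> widen_int4_to_int8(const std::vector<Int4>& src)
{
    std::vector<int8_t> dst(src.size());
    for(std::size_t i = 0; i < src.size(); ++i)
        dst[i] = static_cast<int8_t>(src[i]); // values stay in [-8, 7]
    return dst;
}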
add_custom_target(example_grouped_gemm_xdl)
add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp)
add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp)
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
add_dependencies(example_grouped_gemm_xdl
example_grouped_gemm_xdl_fp32
example_grouped_gemm_xdl_fp16
example_grouped_gemm_xdl_bfp16
example_grouped_gemm_xdl_int8)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = ck::int4_t;
using BDataType = ck::int4_t;
using AccDataType = int32_t;
using CShuffleDataType = int32_t;
using DsDataType = ck::Tuple<>;
using EDataType = ck::int4_t;
using KernelADataType = int8_t;
using KernelBDataType = int8_t;
using KernelEDataType = int8_t;
using ALayout = Row;
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl
// clang-format off
< ALayout, //ALayout
BLayout, //BLayout
DsLayout, //DsLayout
ELayout, //ELayout
KernelADataType, //ADataType
KernelBDataType, //BDataType
AccDataType, //AccDataType
CShuffleDataType, //CShuffleDataType
DsDataType, //DsDataType
KernelEDataType, //EDataType
AElementOp, //AElementwiseOperation
BElementOp, //BElementwiseOperation
CDEElementOp, //CDEElementwiseOperation
GemmDefault, //GEMMSpecialization
1, // NumGemmKPrefetchStage
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
64, // KPerBlock
16, // AK1
16, // BK1
32, // MPerXdl
32, // NPerXdl
4, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1
S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // ABlockTransfer SrcAccessOrder
2, // ABlockTransfer SrcVectorDim
16, // ABlockTransfer SrcScalarPerVector
16, // ABlockTransfer DstScalarPerVector_K1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1
S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // BBlockTransfer SrcAccessOrder
2, // BBlockTransfer SrcVectorDim
16, // BBlockTransfer SrcScalarPerVector
16, // BBlockTransfer DstScalarPerVector_K1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl
16>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
#define BUILD_INT4_EXAMPLE
#include "run_grouped_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
...@@ -22,6 +22,12 @@ struct ExecutionConfig final
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
static_assert(sizeof(ADataType) == sizeof(KernelADataType));
static_assert(sizeof(BDataType) == sizeof(KernelBDataType));
static_assert(sizeof(EDataType) == sizeof(KernelEDataType));
#endif
int group_count = problem_size.group_count;

// GEMM shape
...@@ -61,7 +67,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<Tensor<EDataType>> c_host_tensors;
#ifdef BUILD_INT4_EXAMPLE
std::vector<Tensor<KernelEDataType>> c_device_tensors;
#else
std::vector<Tensor<EDataType>> c_device_tensors;
#endif
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
...@@ -86,9 +96,13 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
gemm_descs[i].K_, gemm_descs[i].N_, gemm_descs[i].stride_B_, BLayout{})));
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{})));
#ifdef BUILD_INT4_EXAMPLE
c_device_tensors.push_back(Tensor<KernelEDataType>(f_host_tensor_descriptor(
gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{})));
#else
c_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{})));
#endif
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
<< std::endl; << std::endl;
...@@ -124,8 +138,16 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSpaceSize()));
#ifdef BUILD_INT4_EXAMPLE
const Tensor<KernelADataType> a_converted(a_tensors[i]);
const Tensor<KernelBDataType> b_converted(b_tensors[i]);
a_tensors_device[i]->ToDevice(a_converted.mData.data());
b_tensors_device[i]->ToDevice(b_converted.mData.data());
#else
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
#endif
p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_b.push_back(b_tensors_device[i]->GetDeviceBuffer());
...@@ -156,14 +178,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
"not support this GEMM problem");
}
invoker.Run(argument, StreamConfig{nullptr, false});

bool pass = true;

if(config.do_verification)
...@@ -190,11 +205,28 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
c_element_op);

ref_invoker.Run(ref_argument);
#ifdef BUILD_INT4_EXAMPLE
const Tensor<EDataType> c_device_result_converted(c_device_tensors[i]);
pass &= ck::utils::check_err(c_device_result_converted.mData, c_host_tensors[i].mData);
#else
pass &= ck::utils::check_err(c_device_tensors[i].mData, c_host_tensors[i].mData);
#endif
}
}
if(config.time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}

return pass;
}
bool run_grouped_gemm_example(int argc, char* argv[])
...@@ -208,7 +240,7 @@ bool run_grouped_gemm_example(int argc, char* argv[])
{
problem_size.Ms.push_back(256 + 256 * i);
problem_size.Ns.push_back(128 + 128 * i);
problem_size.Ks.push_back(128 + 64 * i);
problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
......
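The "Refactoring of performance measurement" item in the commit message boils down to the pattern in the hunk above (and repeated later in the batched and splitK .inc files): the kernel is launched once untimed so its output can be verified, and only when timing is requested is it launched again with timing enabled. A generic sketch of that flow, using stand-in types rather than CK's invoker and StreamConfig, might look like this:

// Hedged sketch of the timing pattern introduced by this commit; FakeInvoker is
// a stand-in for the CK invoker, which in the real code is driven through
// StreamConfig and returns the average kernel time in milliseconds.
#include <cstddef>
#include <iostream>

struct FakeInvoker
{
    float Run(bool time_kernel) { return time_kernel ? 1.25f : 0.0f; } // dummy ms
};

void run_and_report(FakeInvoker& invoker, bool do_verification, bool time_kernel,
                    std::size_t flop, std::size_t num_bytes)
{
    invoker.Run(false); // untimed run; its output feeds the verification path

    if(do_verification)
    {
        // compare the device result against the host reference here
    }

    if(time_kernel)
    {
        const float ave_time   = invoker.Run(true); // timed run
        const float tflops     = static_cast<float>(flop) / 1.E9f / ave_time;
        const float gb_per_sec = static_cast<float>(num_bytes) / 1.E6f / ave_time;
        std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, "
                  << gb_per_sec << " GB/s" << std::endl;
    }
}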
...@@ -5,7 +5,13 @@ add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp)
add_example_executable(example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp)
add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp)
add_dependencies(example_cgemm_xdl
example_cgemm_xdl_bf16
example_cgemm_xdl_fp16
example_cgemm_xdl_fp32
example_cgemm_xdl_int8)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_cgemm_xdl_int4 cgemm_xdl_int4.cpp)
add_dependencies(example_cgemm_xdl example_cgemm_xdl_int4)
endif()
...@@ -117,7 +117,7 @@ int main(int argc, char* argv[])
exit(0);
}

return !run_cgemm_xdl<ADataType,
BDataType,
CDataType,
ALayout,
......
...@@ -21,6 +21,9 @@ using F32 = float;
using BF16 = ck::bhalf_t;
using INT8 = std::int8_t;
using INT32 = std::int32_t;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using INT4 = ck::int4_t;
#endif
template <typename ADataType,
typename BDataType,
...@@ -32,8 +35,11 @@ template <typename ADataType,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename DeviceCGemmInstance,
typename ReferenceCGemmInstance,
typename KernelADataType = ADataType,
typename KernelBDataType = BDataType,
typename KernelCDataType = CDataType>
bool run_cgemm_xdl(ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
...@@ -43,6 +49,17 @@ int run_cgemm_xdl(ck::index_t M,
int init_method,
bool time_kernel)
{
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static_assert(sizeof(ck::int4_t) == sizeof(int8_t),
"sizeof ck::int4_t and int8_t is different!");
static_assert(sizeof(ADataType) == sizeof(KernelADataType),
"sizeof ADataType and KernelADataType is different!");
static_assert(sizeof(BDataType) == sizeof(KernelBDataType),
"sizeof BDataType and KernelBDataType is different!");
static_assert(sizeof(CDataType) == sizeof(KernelCDataType),
"sizeof CDataType and KernelCDataType is different!");
#endif
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
...@@ -61,8 +78,10 @@ int run_cgemm_xdl(ck::index_t M,
Tensor<ADataType> a_m_k_imag(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n_real(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BDataType> b_k_n_imag(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<KernelCDataType> c_m_n_real_device_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<KernelCDataType> c_m_n_imag_device_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k_real: " << a_m_k_real.mDesc << std::endl; std::cout << "a_m_k_real: " << a_m_k_real.mDesc << std::endl;
std::cout << "a_m_k_imag: " << a_m_k_imag.mDesc << std::endl; std::cout << "a_m_k_imag: " << a_m_k_imag.mDesc << std::endl;
...@@ -89,20 +108,41 @@ int run_cgemm_xdl(ck::index_t M, ...@@ -89,20 +108,41 @@ int run_cgemm_xdl(ck::index_t M,
auto cgemm = DeviceCGemmInstance{};

DeviceMem a_m_k_real_device_buf(sizeof(KernelADataType) *
a_m_k_real.mDesc.GetElementSpaceSize());
DeviceMem a_m_k_imag_device_buf(sizeof(KernelADataType) *
a_m_k_imag.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_real_device_buf(sizeof(KernelBDataType) *
b_k_n_real.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_imag_device_buf(sizeof(KernelBDataType) *
b_k_n_imag.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_real_device_buf(sizeof(KernelCDataType) *
c_m_n_real_device_result.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_imag_device_buf(sizeof(KernelCDataType) *
c_m_n_imag_device_result.mDesc.GetElementSpaceSize());
DeviceMem workspace_device_buf(cgemm.GetWorkspaceSize(M, N, K, StrideA, StrideB, StrideC));
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
if constexpr(std::is_same_v<ADataType, ck::int4_t>)
{
Tensor<KernelADataType> a_m_k_real_converted(a_m_k_real);
Tensor<KernelADataType> a_m_k_imag_converted(a_m_k_imag);
Tensor<KernelBDataType> b_k_n_real_converted(b_k_n_real);
Tensor<KernelBDataType> b_k_n_imag_converted(b_k_n_imag);
a_m_k_real_device_buf.ToDevice(a_m_k_real_converted.mData.data());
a_m_k_imag_device_buf.ToDevice(a_m_k_imag_converted.mData.data());
b_k_n_real_device_buf.ToDevice(b_k_n_real_converted.mData.data());
b_k_n_imag_device_buf.ToDevice(b_k_n_imag_converted.mData.data());
}
else
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
{
a_m_k_real_device_buf.ToDevice(a_m_k_real.mData.data());
a_m_k_imag_device_buf.ToDevice(a_m_k_imag.mData.data());
b_k_n_real_device_buf.ToDevice(b_k_n_real.mData.data());
b_k_n_imag_device_buf.ToDevice(b_k_n_imag.mData.data());
}
auto a_element_op = AElementwiseOperation{};
auto b_element_op = BElementwiseOperation{};
...@@ -111,13 +151,13 @@ int run_cgemm_xdl(ck::index_t M,
// do GEMM
auto invoker = cgemm.MakeInvoker();
auto argument =
cgemm.MakeArgument(static_cast<KernelADataType*>(a_m_k_real_device_buf.GetDeviceBuffer()),
static_cast<KernelADataType*>(a_m_k_imag_device_buf.GetDeviceBuffer()),
static_cast<KernelBDataType*>(b_k_n_real_device_buf.GetDeviceBuffer()),
static_cast<KernelBDataType*>(b_k_n_imag_device_buf.GetDeviceBuffer()),
static_cast<KernelCDataType*>(c_m_n_real_device_buf.GetDeviceBuffer()),
static_cast<KernelCDataType*>(c_m_n_imag_device_buf.GetDeviceBuffer()),
static_cast<KernelCDataType*>(workspace_device_buf.GetDeviceBuffer()),
M,
N,
K,
...@@ -143,15 +183,11 @@ int run_cgemm_xdl(ck::index_t M,
(sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< cgemm.GetTypeString() << std::endl;
if(do_verification)
{
Tensor<CDataType> c_m_n_real_host_result(
...@@ -161,7 +197,6 @@ int run_cgemm_xdl(ck::index_t M,
auto ref_cgemm = ReferenceCGemmInstance{};
auto ref_invoker = ref_cgemm.MakeInvoker();
auto ref_argument = ref_cgemm.MakeArgument(a_m_k_real,
a_m_k_imag,
b_k_n_real,
...@@ -174,19 +209,45 @@ int run_cgemm_xdl(ck::index_t M,
ref_invoker.Run(ref_argument);
c_m_n_real_device_buf.FromDevice(c_m_n_real_device_result.mData.data());
c_m_n_imag_device_buf.FromDevice(c_m_n_imag_device_result.mData.data());
bool result = true;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
if constexpr(std::is_same_v<ADataType, ck::int4_t>)
{
const Tensor<CDataType> c_m_n_real_device_result_converted(c_m_n_real_device_result);
const Tensor<CDataType> c_m_n_imag_device_result_converted(c_m_n_imag_device_result);
result = ck::utils::check_err(c_m_n_real_device_result_converted.mData,
c_m_n_real_host_result.mData,
"Verification error: incorrect results in real part!",
1e-2f,
1e-1f);
result = result && ck::utils::check_err(
c_m_n_imag_device_result_converted.mData,
c_m_n_imag_host_result.mData,
"Verification error: incorrect results in imaginary part!",
1e-2f,
1e-1f);
}
else
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
{
result = ck::utils::check_err(c_m_n_real_device_result.mData,
c_m_n_real_host_result.mData,
"Verification error: incorrect results in real part!",
1e-2f,
1e-1f);
result = result && ck::utils::check_err(
c_m_n_imag_device_result.mData,
c_m_n_imag_host_result.mData,
"Verification error: incorrect results in imaginary part!",
1e-2f,
1e-1f);
}

return result;
}

return true;
}
...@@ -116,7 +116,7 @@ int main(int argc, char* argv[])
exit(0);
}

return !run_cgemm_xdl<ADataType,
BDataType,
CDataType,
ALayout,
......
...@@ -117,7 +117,7 @@ int main(int argc, char* argv[])
exit(0);
}

return !run_cgemm_xdl<ADataType,
BDataType,
CDataType,
ALayout,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "cgemm_xdl_common.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
using ADataType = INT4;
using BDataType = INT4;
using CDataType = INT4;
using AccDataType = INT32;
using CShuffleDataType = INT32;
using KernelADataType = INT8;
using KernelBDataType = INT8;
using KernelCDataType = INT8;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using ReferenceCGemmInstance = ck::tensor_operation::host::
ReferenceCGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
// clang-format off
using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_CShuffle
<ALayout, // typename ALayout
BLayout, // typename BLayout
CLayout, // typename CLayout
KernelADataType, // typename ADataType
KernelBDataType, // typename BDataType
KernelCDataType, // typename CDataType
AccDataType, // typename GemmAccDataType
CShuffleDataType, // typename CShuffleDataType
PassThrough, // typename AElementwiseOperation
PassThrough, // typename BElementwiseOperation
PassThrough, // typename CElementwiseOperation
GemmDefault, // GemmSpecialization GemmSpec
1, // index_t NumGemmKPrefetchStage
256, // index_t BlockSize
256, // index_t MPerBlock
128, // index_t NPerBlock
64, // index_t KPerBlock
16, // index_t AK1
16, // index_t BK1
32, // index_t MPerXDL
32, // index_t NPerXDL
4, // index_t MXdlPerWave
2, // index_t NXdlPerWave
S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder
2, // index_t ABlockTransferSrcVectorDim
16, // index_t ABlockTransferSrcScalarPerVector
16, // index_t ABlockTransferDstScalarPerVector_AK1
1, // index_t ABlockLdsExtraM
S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder
2, // index_t BBlockTransferSrcVectorDim
8, // index_t BBlockTransferSrcScalarPerVector
8, // index_t BBlockTransferDstScalarPerVector_BK1
1, // index_t BBlockLdsExtraN
1, // index_t CShuffleMXdlPerWavePerShuffle
1, // index_t CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
// CGEMM shape
ck::index_t M = 1024;
ck::index_t N = 1152;
ck::index_t K = 512;
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideC = N;
if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideC = std::stoi(argv[9]);
}
else
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=no, 1=yes)\n"
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"
<< std::endl;
exit(EXIT_SUCCESS);
}
return !run_cgemm_xdl<ADataType,
BDataType,
CDataType,
ALayout,
BLayout,
CLayout,
PassThrough,
PassThrough,
PassThrough,
DeviceCGemmInstance,
ReferenceCGemmInstance,
KernelADataType,
KernelBDataType,
KernelCDataType>(
M, N, K, StrideA, StrideB, StrideC, do_verification, init_method, time_kernel);
}
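For context, the complex GEMM computed by this example combines real and imaginary planes as C_real = A_real·B_real − A_imag·B_imag and C_imag = A_real·B_imag + A_imag·B_real; the "4Gemm" in the device instance name presumably refers to that decomposition into four real GEMMs. A naive host-side sketch of the arithmetic (illustrative only; row-major everywhere and float data for brevity, whereas the example above uses int4/int8 data and a column-major B) looks like:

// Illustrative reference only, not the CK ReferenceCGemm implementation.
#include <cstddef>
#include <vector>

void naive_cgemm(std::size_t M, std::size_t N, std::size_t K,
                 const std::vector<float>& a_real, const std::vector<float>& a_imag, // M x K
                 const std::vector<float>& b_real, const std::vector<float>& b_imag, // K x N
                 std::vector<float>& c_real, std::vector<float>& c_imag)             // M x N
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc_re = 0.f;
            float acc_im = 0.f;
            for(std::size_t k = 0; k < K; ++k)
            {
                const float ar = a_real[m * K + k], ai = a_imag[m * K + k];
                const float br = b_real[k * N + n], bi = b_imag[k * N + n];
                acc_re += ar * br - ai * bi; // real part
                acc_im += ar * bi + ai * br; // imaginary part
            }
            c_real[m * N + n] = acc_re;
            c_imag[m * N + n] = acc_im;
        }
}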
...@@ -117,7 +117,7 @@ int main(int argc, char* argv[])
exit(0);
}

return !run_cgemm_xdl<ADataType,
BDataType,
CDataType,
ALayout,
......
add_custom_target(example_batched_gemm_xdl)
add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp)
add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_xdl_bfp16 batched_gemm_xdl_bfp16.cpp)
add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp)
add_dependencies(example_batched_gemm_xdl
example_batched_gemm_xdl_fp32
example_batched_gemm_xdl_fp16
example_batched_gemm_xdl_bfp16
example_batched_gemm_xdl_int8)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d_xdl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = ck::int4_t;
using BDataType = ck::int4_t;
using AccDataType = int32_t;
using CShuffleDataType = int32_t;
using DsDataType = ck::Tuple<>;
using EDataType = ck::int4_t;
using KernelADataType = int8_t;
using KernelBDataType = int8_t;
using KernelEDataType = int8_t;
using ALayout = Row;
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl
// clang-format off
< ALayout, //ALayout
BLayout, //BLayout
DsLayout, //DsLayout
ELayout, //ELayout
KernelADataType, //ADataType
KernelBDataType, //BDataType
AccDataType, //AccDataType
CShuffleDataType, //CShuffleDataType
DsDataType, //DsDataType
KernelEDataType, //EDataType
AElementOp, //AElementwiseOperation
BElementOp, //BElementwiseOperation
CDEElementOp, //CDEElementwiseOperation
GemmDefault, //GEMMSpecialization
1, // NumGemmKPrefetchStage
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
64, // KPerBlock
16, // AK1
16, // BK1
32, // MPerXdl
32, // NPerXdl
4, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1
S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // ABlockTransfer SrcAccessOrder
2, // ABlockTransfer SrcVectorDim
16, // ABlockTransfer SrcScalarPerVector
16, // ABlockTransfer DstScalarPerVector_K1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1
S<1, 0, 2>, // BBlockTransfer ThreadCluster ArrangeOrder
S<1, 0, 2>, // BBlockTransfer SrcAccessOrder
2, // BBlockTransfer SrcVectorDim
16, // BBlockTransfer SrcScalarPerVector
16, // BBlockTransfer DstScalarPerVector_K1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl
16>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
#define BUILD_INT4_EXAMPLE
#include "run_batched_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
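The run_batched_gemm_example.inc changes below keep the usual batched layout: each matrix in the batch sits batch_stride elements after the previous one, with an ordinary row/column stride inside each matrix. A tiny addressing sketch (hypothetical helper, row-major case only, not a CK API) makes that layout explicit:

// Illustrative only: linear offset of element (g, row, col) in a batched,
// row-major tensor described by a leading-dimension stride and a per-batch
// batch_stride, matching the host tensor descriptor used in the example.
#include <cstddef>

constexpr std::size_t batched_offset(std::size_t g, std::size_t row, std::size_t col,
                                     std::size_t stride, std::size_t batch_stride)
{
    return g * batch_stride + row * stride + col;
}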
#include <random>
#pragma once

struct ProblemSize final
...@@ -28,7 +30,23 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
{
using namespace ck::literals;

#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
static_assert(sizeof(ADataType) == sizeof(KernelADataType));
static_assert(sizeof(BDataType) == sizeof(KernelBDataType));
static_assert(sizeof(EDataType) == sizeof(KernelEDataType));
#endif
auto& [M,
N,
K,
stride_A,
stride_B,
stride_C,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_count] = problem_size;
// GEMM shape
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
...@@ -53,9 +71,13 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, ALayout{}));
Tensor<BDataType> b_g_k_n(
f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{}));
#ifdef BUILD_INT4_EXAMPLE
Tensor<KernelEDataType> e_g_m_n_device_result(
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{}));
#else
Tensor<EDataType> e_g_m_n_device_result(
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{}));
#endif
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
...@@ -78,9 +100,16 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -78,9 +100,16 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(EDataType) * e_g_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem c_device_buf(sizeof(EDataType) * e_g_m_n_device_result.mDesc.GetElementSpaceSize());
#ifdef BUILD_INT4_EXAMPLE
const Tensor<KernelADataType> a_g_m_k_converted(a_g_m_k);
const Tensor<KernelBDataType> b_g_k_n_converted(b_g_k_n);
a_device_buf.ToDevice(a_g_m_k_converted.mData.data());
b_device_buf.ToDevice(b_g_k_n_converted.mData.data());
#else
a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data());
#endif
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
...@@ -116,28 +145,21 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
"not support this GEMM problem");
}
invoker.Run(argument, StreamConfig{nullptr, false});
bool pass = true;

if(config.do_verification)
{
c_device_buf.FromDevice(e_g_m_n_device_result.mData.data());
using ReferenceBatchedGemmInstance =
ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
BDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
auto ref_invoker = ref_batched_gemm.MakeInvoker();
...@@ -150,8 +172,29 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
ref_invoker.Run(ref_argument);
#ifdef BUILD_INT4_EXAMPLE
const Tensor<EDataType> e_device_result_converted(e_g_m_n_device_result);
pass &= ck::utils::check_err(e_device_result_converted.mData, e_g_m_n_host_result.mData);
#else
pass = ck::utils::check_err(
e_g_m_n_device_result.mData, e_g_m_n_host_result.mData, "Error: Incorrect results c");
#endif
}
if(config.time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = std::size_t(2) * batch_count * M * N * K;
std::size_t num_btype = sizeof(ADataType) * batch_count * M * K +
sizeof(BDataType) * batch_count * K * N +
sizeof(EDataType) * batch_count * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
return pass ? 0 : 1;
...@@ -162,9 +205,12 @@ bool run_batched_gemm_example(int argc, char* argv[])
ProblemSize problem_size;
ExecutionConfig config;

std::mt19937 gen(11939);
std::uniform_int_distribution<int> dis(0, 15);
problem_size.M = 256 * (dis(gen) + 1);
problem_size.N = 128 * (dis(gen) + 1);
problem_size.K = 64 * (dis(gen) + 2);
problem_size.stride_A = problem_size.K;
problem_size.stride_B = problem_size.K;
......
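One detail of the batched example worth calling out: the problem sizes above are now drawn from a seeded std::mt19937 instead of rand(), so every run exercises the same shapes, and K starts at 128 rather than 64. A standalone sketch of that selection logic (seed and ranges taken straight from the diff; the printout is added here for illustration):

// Sketch mirroring the size selection in run_batched_gemm_example.inc above.
#include <iostream>
#include <random>

int main()
{
    std::mt19937 gen(11939);                       // fixed seed -> reproducible shapes
    std::uniform_int_distribution<int> dis(0, 15); // replaces rand() % 16

    const int M = 256 * (dis(gen) + 1); // multiple of MPerBlock = 256
    const int N = 128 * (dis(gen) + 1); // multiple of NPerBlock = 128
    const int K = 64 * (dis(gen) + 2);  // multiple of KPerBlock = 64, at least 128
    std::cout << "M=" << M << " N=" << N << " K=" << K << '\n';
    return 0;
}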
add_custom_target(example_splitK_gemm_xdl)
add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp)
add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
add_example_executable(example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp)
add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp)
add_dependencies(example_splitK_gemm_xdl
example_splitK_gemm_xdl_fp32
example_splitK_gemm_xdl_fp16
example_splitK_gemm_xdl_bfp16
example_splitK_gemm_xdl_int8)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp)
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4)
endif()
...@@ -24,6 +24,12 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
{
using namespace ck::literals;
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
static_assert(sizeof(ADataType) == sizeof(KernelADataType));
static_assert(sizeof(BDataType) == sizeof(KernelBDataType));
#endif
auto& [M, N, K, StrideA, StrideB, StrideC, KBatch] = problem_size;

auto f_host_tensor_descriptor =
...@@ -42,12 +48,11 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;
switch(config.init_method)
{
...@@ -69,8 +74,16 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
#ifdef BUILD_INT4_EXAMPLE
const Tensor<KernelADataType> a_m_k_converted(a_m_k);
const Tensor<KernelBDataType> b_k_n_converted(b_k_n);
a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
#else
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
#endif
c_m_n_device_buf.SetZero();

auto a_element_op = AElementOp{};
...@@ -80,8 +93,14 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
#else
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
#endif
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
...@@ -101,23 +120,12 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
return 0;
}
invoker.Run(argument, StreamConfig{nullptr, false});

bool pass = true;

if(config.do_verification)
{
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
...@@ -129,6 +137,8 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
...@@ -136,7 +146,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
if(std::is_same<CDataType, ck::half_t>::value)
{
pass &= ck::utils::check_err(c_m_n_device_result.mData,
c_m_n_host_result.mData,
"fp16 incorrect result",
3e-3,
...@@ -144,11 +154,25 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
}
else
{
pass &= ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
}
}
if(config.time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}

return pass;
}
bool run_splitK_gemm_example(int argc, char* argv[])
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/literals.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = ck::int4_t;
using BDataType = ck::int4_t;
using AccDataType = int32_t;
using CDataType = int32_t;
using KernelADataType = int8_t;
using KernelBDataType = int8_t;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle
// clang-format off
<KernelADataType, //ADataType
KernelBDataType, //BDataType
CDataType, //EDataType
AccDataType, //AccDataType
ALayout, //ALayout
BLayout, //BLayout
CLayout, //ELayout
AElementOp, //AElementwiseOperation
BElementOp, //BElementwiseOperation
CElementOp, //CElementwiseOperation
GemmDefault, //GEMMSpecialization
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
4, // KPerBlock
16, // K1
32, // MPerXdl
32, // NPerXdl
4, // MXdlPerWave
2, // NXdlPerWave
S<1, 4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1
S<0, 2, 1, 3>, // ABlockTransfer ThreadCluster ArrangeOrder
S<0, 2, 1, 3>, // ABlockTransfer SrcAccessOrder
3, // ABlockTransfer SrcVectorDim
16, // ABlockTransfer SrcScalarPerVector
16, // ABlockTransfer DstScalarPerVector_K1
true, // ABlockLdsExtraM
S<1, 4, 64, 1>, // BBlockTransfer ThreadCluster Lengths_K0_N_K1
S<0, 1, 3, 2>, // BBlockTransfer ThreadCluster ArrangeOrder
S<0, 1, 3, 2>, // BBlockTransfer SrcAccessOrder
3, // BBlockTransfer SrcVectorDim
16, // BBlockTransfer SrcScalarPerVector
16, // BBlockTransfer DstScalarPerVector_K1
true, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CBlockTransferClusterLengths _MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
4>; // CBlockTransferScalarPerVector_NWaveNPerXdl
// clang-format on
#define BUILD_INT4_EXAMPLE
#include "run_splitK_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); }
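The split-K example zero-initialises its output buffer (c_m_n_device_buf.SetZero() in the .inc file above) because the K dimension is divided into KBatch chunks whose partial products are accumulated into C. A scalar host-side sketch of that idea, with illustrative types and the example's layouts (row-major A, column-major B), not the CK kernel itself:

// Illustrative host-side sketch of split-K accumulation.
#include <algorithm>
#include <cstddef>
#include <vector>

void splitk_gemm_reference(std::size_t M, std::size_t N, std::size_t K, std::size_t KBatch,
                           const std::vector<int>& a, // M x K, row-major (stride K)
                           const std::vector<int>& b, // K x N, column-major (stride K)
                           std::vector<int>& c)       // M x N, row-major, zero-initialised
{
    const std::size_t k_per_batch = (K + KBatch - 1) / KBatch;
    for(std::size_t kb = 0; kb < KBatch; ++kb) // on the GPU each chunk runs concurrently
    {
        const std::size_t k_begin = kb * k_per_batch;
        const std::size_t k_end   = std::min(K, k_begin + k_per_batch);
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                int partial = 0;
                for(std::size_t k = k_begin; k < k_end; ++k)
                    partial += a[m * K + k] * b[n * K + k];
                c[m * N + n] += partial; // accumulation is why C must start at zero
            }
    }
}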
#find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'
git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'