Commit 3eee1b9b authored by Harisankar Sadasivan's avatar Harisankar Sadasivan
Browse files

adding tall and skinny gemm

parent 67adf1b4
......@@ -2,7 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemv_splitk.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
......@@ -25,7 +25,7 @@ static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpeciali
#define B 64 // block-size:64
// clang-format off
using DeviceGemvInstance = ck::tensor_operation::device::deviceGemvDl/*
using DeviceGemvInstance = ck::tensor_operation::device::deviceTsmmDl/*
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
......
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_tall_and_skinny_gemm_splitk)
add_example_executable(example_tall_and_skinny_gemm_splitk_fp16 tall_and_skinny_gemm_splitk_fp16.cpp)
# set_source_files_properties(splitK_gemv_fp16.cpp PROPERTIES COMPILE_OPTIONS "--save-temps;-Wno-gnu-line-marker;-gline-tables-only")
add_dependencies(example_tall_and_skinny_gemm_splitk
example_tall_and_skinny_gemm_splitk_fp16)
set(target 1)
endif()
endforeach()
\ No newline at end of file
# Instructions for ```example_gemv_splitk```
## Run ```example_gemv_splitk```
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: run kernel # of times (>1)
#arg4: number of splitk batches
./bin/example_tall_and_skinny_gemm_splitk_fp* 0 1 5 151
```
Result (MI250 @ 800Mhz, 181.05TFlops peak FP16)
```
a_m_k: dim 2, lengths {16, 1024}, strides {1024, 1}
b_k_n: dim 2, lengths {1024, 16}, strides {16, 1}
c_m_n: dim 2, lengths {16, 16}, strides {16, 1}
Perf: 0.0684798 ms, 0.0076561 TFlops, 0.964489 GB/s, deviceGemvDl<64, 16, 128, 4, 2, 16, 2, 1>
```
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <numeric>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
struct ProblemSize final // Default GEMV problem size
{
ck::index_t M = 16;
ck::index_t N = 16;
ck::index_t K = 1024;
// ck::index_t M = 2;
// ck::index_t N = 256;
// ck::index_t K = 256;
ck::index_t stride_A = K;
ck::index_t stride_B = N;//K;
ck::index_t stride_C = N;
ck::index_t k_batch = 1;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
inline bool
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
{
if(argc == 1)
{
// use default case
}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
}
else if(argc == 11)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
problem_size.M = std::stoi(argv[5]);
problem_size.N = std::stoi(argv[6]);
problem_size.K = std::stoi(argv[7]);
problem_size.stride_A = std::stoi(argv[8]);
problem_size.stride_B = std::stoi(argv[9]);
problem_size.stride_C = std::stoi(argv[10]);
}
else
{
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl;
return false;
}
return true;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
bool run_tall_and_skinny_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC, k_batch] = problem_size; // //
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
#ifdef BUILD_INT4_EXAMPLE
DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) *
c_m_n_device_result.mDesc.GetElementSpaceSize());
const Tensor<KernelADataType> a_m_k_converted(a_m_k);
const Tensor<KernelBDataType> b_k_n_converted(b_k_n);
a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
#else
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
#endif
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto tsmm = DeviceTSMMInstance{};
auto invoker = tsmm.MakeInvoker();
auto argument = tsmm.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#else
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
k_batch); // //
// //
if(!tsmm.IsSupportedArgument(argument))
{
std::cerr << tsmm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
c_m_n_device_buf.SetZero();
if(config.do_verification)
{
invoker.Run(argument, StreamConfig{nullptr, false}); // Run prior to verification
auto ref_tsmm = ReferenceGemmInstance{};
auto ref_invoker = ref_tsmm.MakeInvoker();
auto ref_argument = ref_tsmm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
#ifdef BUILD_INT4_EXAMPLE
Tensor<CDataType> c_m_n_device_result_converted(c_m_n_host_result.mDesc);
c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data());
c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
#else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
#endif
}
float ave_time = invoker.Run(
argument, StreamConfig{nullptr, config.time_kernel}); // Run to measure performance
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< tsmm.GetTypeString() << std::endl;
#ifdef BUILD_INT4_EXAMPLE
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#endif
}
bool run_tall_and_skinny_gemm_example(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
if(argc == 1)
{
// use default case
}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
}
else if(argc == 11)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
problem_size.M = std::stoi(argv[5]);
problem_size.N = std::stoi(argv[6]);
problem_size.K = std::stoi(argv[7]);
problem_size.stride_A = std::stoi(argv[8]);
problem_size.stride_B = std::stoi(argv[9]);
problem_size.stride_C = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4: splitk\n");
printf("arg5 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}
return run_tall_and_skinny_gemm(problem_size, config);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Row;
using BLayout = Row;//Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
#define K1 2
#define K0 4
#define N1 2
#define B 64 // block-size:64
#define M1 16
// clang-format off
using DeviceTSMMInstance = ck::tensor_operation::device::deviceTsmmDl/*
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 1, 64, 32, 2, 1, 1, 1, S<1, 1, 1, 2>, S<32, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 1>;*/
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, B, M1, B*N1, K0, K1, M1, N1, 1, S<1,1, 1, 1, K1>, S<1,K0, 1,M1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, 3, N1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_tall_and_skinny_gemm_splitk_example.inc"
int main(int argc, char* argv[]) { return !run_tall_and_skinny_gemm_example(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment