"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "c33793103d658b5e0bd5902eb0f679526ad40155"
Commit 49facb91 authored by Harisankar Sadasivan's avatar Harisankar Sadasivan
Browse files

files for gemv and tall and skinny gemm examples and corresponding entries to ckprofiler

parent 98fd41f5
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_gemv_splitk)
add_example_executable(example_gemv_splitk_fp16 gemv_splitk_fp16.cpp)
add_dependencies(example_gemv_splitk
example_gemv_splitk_fp16)
set(target 1)
endif()
endforeach()
# Instructions for ```example_gemv_splitk```
## Run ```example_gemv_splitk```
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: run kernel # of times (>1)
#arg4: number of splitk batches
bin/example_gemv_splitk_fp16 1 2 1 231
```
Result (MI250 @ 800Mhz, 181.05TFlops peak FP16)
```
a_m_k: dim 2, lengths {1, 4608}, strides {4608, 1}
b_k_n: dim 2, lengths {4608, 1104}, strides {1104, 1}
c_m_n: dim 2, lengths {1, 1104}, strides {1104, 1}
Perf: 0.0111038 ms, 0.916305 TFlops, 917.334 GB/s, deviceTsmmDl<64, 1, 128, 3, 4, 1, 2, 1>
```
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <numeric>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
struct ProblemSize final // Default GEMV problem size
{
ck::index_t M = 1;
ck::index_t N = 1104;
ck::index_t K = 4608;
ck::index_t stride_A = K;
ck::index_t stride_B = N; // K;
ck::index_t stride_C = N;
ck::index_t k_batch = 1;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
inline bool
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
{
if(argc == 1)
{
// use default case
}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
}
else if(argc == 11)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
problem_size.M = std::stoi(argv[5]);
problem_size.N = std::stoi(argv[6]);
problem_size.K = std::stoi(argv[7]);
problem_size.stride_A = std::stoi(argv[8]);
problem_size.stride_B = std::stoi(argv[9]);
problem_size.stride_C = std::stoi(argv[10]);
}
else
{
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl;
return false;
}
return true;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Row;
using BLayout = Row; // Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
#define K1 4
#define K0 3
#define N1 2
#define B 64 // block-size:64
// clang-format off
using DeviceGemvInstance = ck::tensor_operation::device::deviceTsmmDl/*
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 1, 64, 32, 2, 1, 1, 1, S<1, 1, 1, 2>, S<32, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 1>;*/
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, B, 1, B*N1, K0, K1, 1, N1, 1, S<1,1, 1, 1, K1>, S<1,K0, 1, 1, 1>,S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, S<1,1, 1, 1, 2>, S<0,1,2,3,4>, 3, N1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemv_splitk_example.inc"
int main(int argc, char* argv[]) { return !run_gemv_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
{
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC, k_batch] = problem_size; // //
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
#ifdef BUILD_INT4_EXAMPLE
DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) *
c_m_n_device_result.mDesc.GetElementSpaceSize());
const Tensor<KernelADataType> a_m_k_converted(a_m_k);
const Tensor<KernelBDataType> b_k_n_converted(b_k_n);
a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
#else
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
#endif
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemv = DeviceGemvInstance{};
auto invoker = gemv.MakeInvoker();
auto argument = gemv.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#else
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
k_batch); // //
// //
if(!gemv.IsSupportedArgument(argument))
{
std::cerr << gemv.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
c_m_n_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false}); // Run prior to verification
if(config.do_verification)
{
auto ref_gemv = ReferenceGemmInstance{};
auto ref_invoker = ref_gemv.MakeInvoker();
auto ref_argument = ref_gemv.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
#ifdef BUILD_INT4_EXAMPLE
Tensor<CDataType> c_m_n_device_result_converted(c_m_n_host_result.mDesc);
c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data());
c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
#else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
#endif
}
float ave_time = invoker.Run(
argument, StreamConfig{nullptr, config.time_kernel}); // Run to measure performance
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemv.GetTypeString() << std::endl;
#ifdef BUILD_INT4_EXAMPLE
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#endif
}
bool run_gemv_example(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
if(argc == 1)
{
// use default case
}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
}
else if(argc == 11)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
problem_size.M = std::stoi(argv[5]);
problem_size.N = std::stoi(argv[6]);
problem_size.K = std::stoi(argv[7]);
problem_size.stride_A = std::stoi(argv[8]);
problem_size.stride_B = std::stoi(argv[9]);
problem_size.stride_C = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4: splitk\n");
printf("arg5 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}
return run_gemv(problem_size, config);
}
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_tall_and_skinny_gemm_splitk)
add_example_executable(example_tall_and_skinny_gemm_splitk_fp16 tall_and_skinny_gemm_splitk_fp16.cpp)
add_dependencies(example_tall_and_skinny_gemm_splitk
example_tall_and_skinny_gemm_splitk_fp16)
set(target 1)
endif()
endforeach()
\ No newline at end of file
# Instructions for ```example_gemv_splitk```
## Run ```example_gemv_splitk```
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: run kernel # of times (>1)
#arg4: number of splitk batches
bin/example_tall_and_skinny_gemm_splitk_fp16 1 2 1 231
```
Result (MI250 @ 800Mhz, 181.05TFlops peak FP16)
```
a_m_k: dim 2, lengths {16, 1024}, strides {1024, 1}
b_k_n: dim 2, lengths {1024, 16}, strides {16, 1}
c_m_n: dim 2, lengths {16, 16}, strides {16, 1}
Perf: 0.0065438 ms, 0.0801198 TFlops, 10.0932 GB/s, deviceTsmmDl<64, 16, 128, 4, 2, 16, 2, 1>
```
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <numeric>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
struct ProblemSize final // Default GEMV problem size
{
ck::index_t M = 16;
ck::index_t N = 16;
ck::index_t K = 1024;
ck::index_t stride_A = K;
ck::index_t stride_B = N; // K;
ck::index_t stride_C = N;
ck::index_t k_batch = 1;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
inline bool
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
{
if(argc == 1)
{
// use default case
}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
}
else if(argc == 11)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
problem_size.M = std::stoi(argv[5]);
problem_size.N = std::stoi(argv[6]);
problem_size.K = std::stoi(argv[7]);
problem_size.stride_A = std::stoi(argv[8]);
problem_size.stride_B = std::stoi(argv[9]);
problem_size.stride_C = std::stoi(argv[10]);
}
else
{
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl;
return false;
}
return true;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
bool run_tall_and_skinny_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
using namespace ck::literals;
auto& [M, N, K, StrideA, StrideB, StrideC, k_batch] = problem_size; // //
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
case 0: break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
#ifdef BUILD_INT4_EXAMPLE
DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) *
c_m_n_device_result.mDesc.GetElementSpaceSize());
const Tensor<KernelADataType> a_m_k_converted(a_m_k);
const Tensor<KernelBDataType> b_k_n_converted(b_k_n);
a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
#else
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
#endif
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto tsmm = DeviceTSMMInstance{};
auto invoker = tsmm.MakeInvoker();
auto argument = tsmm.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#else
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op,
k_batch); // //
// //
if(!tsmm.IsSupportedArgument(argument))
{
std::cerr << tsmm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
c_m_n_device_buf.SetZero();
if(config.do_verification)
{
invoker.Run(argument, StreamConfig{nullptr, false}); // Run prior to verification
auto ref_tsmm = ReferenceGemmInstance{};
auto ref_invoker = ref_tsmm.MakeInvoker();
auto ref_argument = ref_tsmm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
#ifdef BUILD_INT4_EXAMPLE
Tensor<CDataType> c_m_n_device_result_converted(c_m_n_host_result.mDesc);
c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data());
c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
#else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
#endif
}
float ave_time = invoker.Run(
argument, StreamConfig{nullptr, config.time_kernel}); // Run to measure performance
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< tsmm.GetTypeString() << std::endl;
#ifdef BUILD_INT4_EXAMPLE
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
#endif
}
bool run_tall_and_skinny_gemm_example(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
if(argc == 1)
{
// use default case
}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
}
else if(argc == 11)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.k_batch = std::stoi(argv[4]);
problem_size.M = std::stoi(argv[5]);
problem_size.N = std::stoi(argv[6]);
problem_size.K = std::stoi(argv[7]);
problem_size.stride_A = std::stoi(argv[8]);
problem_size.stride_B = std::stoi(argv[9]);
problem_size.stride_C = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4: splitk\n");
printf("arg5 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n");
exit(0);
}
return run_tall_and_skinny_gemm(problem_size, config);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_tall_and_skinny_gemm_splitk.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Row;
using BLayout = Row; // Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
#define K1 2
#define K0 4
#define N1 2
#define B 64 // block-size:64
#define M1 16
// clang-format off
using DeviceTSMMInstance = ck::tensor_operation::device::deviceTsmmDl/*
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer | ABlockTransfer| ABlockTransfer | BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|SrcVectorTensorLengths| SrcVectorTensor|DstVectorTensorLengths| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | KBatch_K0_M0_M1_K1| KBatch_K0_M0_M1_K1| ArrangeOrder| Order| KBatch_K0_M0_M1_K1 | ContiguousDimOrder| KBatch_K0_M0_M1_K1 | Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 1, 64, 32, 2, 1, 1, 1, S<1, 1, 1, 2>, S<32, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 1>;*/
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, B, M1, B*N1, K0, K1, M1, N1, 1, S<1,1, 1, 1, K1>, S<1,K0, 1,M1, 1>, S<0,1,2,3,4>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, S<1,1, 1, 1, K1>, S<0,1,2,3,4>, 3, N1, S<0, 1, 2, 3, 4, 5>, 5, N1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_tall_and_skinny_gemm_splitk_example.inc"
int main(int argc, char* argv[]) { return !run_tall_and_skinny_gemm_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tall_and_skinny_gemm.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename ABlockDesc_K0_M_K1,
typename BThreadDesc_K0_N_K1,
index_t MPerThread,
index_t NPerBlock,
index_t K0PerLoop>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
using CIndex = MultiIndex<4>;
static constexpr auto K0 = ABlockDesc_K0_M_K1{}.GetLength(I0);
static constexpr auto M = ABlockDesc_K0_M_K1{}.GetLength(I1);
static constexpr auto K1 = ABlockDesc_K0_M_K1{}.GetLength(I2);
static constexpr auto NPerThread = BThreadDesc_K0_N_K1{}.GetLength(I1);
static constexpr auto M0 = M / MPerThread;
static constexpr auto M1 = MPerThread;
static constexpr auto N = NPerBlock;
static constexpr auto N0 = N / NPerThread;
static constexpr auto N1 = NPerThread;
static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<K0PerLoop>{}, Number<MPerThread>{}, Number<K1>{}));
static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<K0PerLoop>{}, Number<NPerThread>{}, Number<K1>{}));
static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<I1>{}, Number<M1>{}, Number<I1>{}, Number<N1>{}));
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
: c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * MPerThread, 0)}
{
static_assert(ABlockDesc_K0_M_K1::IsKnownAtCompileTime() &&
BThreadDesc_K0_N_K1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ABlockDesc_K0_M_K1{}.GetLength(I0) == BThreadDesc_K0_N_K1{}.GetLength(I0) &&
ABlockDesc_K0_M_K1{}.GetLength(I2) == BThreadDesc_K0_N_K1{}.GetLength(I2),
"wrong! E dimension not consistent\n");
static_assert(K0 % K0PerLoop == 0, "");
static_assert(M % MPerThread == 0 && N % NPerThread == 0,
"wrong! Cannot evenly divide work among\n");
static_assert(BlockSize == M0 * N0, "wrong! wrong blocksize\n");
}
__device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1()
{
return Sequence<I1, M1, I1, N1>{};
}
__device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
{
constexpr auto c_threadid_to_m0_m1_n0_n1_thread_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, I1, N0, I1))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto c_m0_m1_n0_n1_thread_cluster_idx =
c_threadid_to_m0_m1_n0_n1_thread_cluster_adaptor.CalculateBottomIndex(
make_multi_index(thread_id));
return c_m0_m1_n0_n1_thread_cluster_idx;
}
template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BThreadBuffer& b_thread_buf,
CThreadBuffer& c_thread_buf) const
{
static_assert(
is_same<remove_cvref_t<typename ABlockBuffer::type>, remove_cvref_t<FloatA>>::value &&
is_same<remove_cvref_t<typename BThreadBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! inconsistent type");
constexpr auto a_block_mtx = ABlockDesc_K0_M_K1{};
// thread A buffer for GEMM
StaticBuffer<AddressSpaceEnum::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
a_thread_buf;
constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
FloatB,
FloatC,
decltype(a_thread_mtx_),
decltype(b_thread_mtx_),
decltype(c_thread_mtx_)>{};
static_for<0, K0, K0PerLoop>{}([&](auto k0_begin) {
a_thread_copy_.Run(a_block_mtx,
make_tuple(k0_begin, I0, I0),
a_block_buf,
a_thread_mtx_,
make_tuple(I0, I0, I0),
a_thread_buf);
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(k0_begin, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
});
}
private:
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
ABlockDesc_K0_M_K1,
decltype(a_thread_mtx_),
Sequence<K0PerLoop, MPerThread, K1>,
Sequence<0, 1, 2>,
2,
K1,
K1>;
CIndex c_thread_origin_data_idx_;
AThreadCopy a_thread_copy_;
};
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceTsmm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t KBatch = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_tall_and_skinny_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_tall_and_skinny_gemm_splitk.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
typename ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1,
typename BThreadTransferSrcDstAccessOrder,
index_t BThreadTransferSrcVectorDim,
index_t BThreadTransferSrcScalarPerVector,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct deviceTsmmDl : public DeviceTsmm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
// GridwiseTsmm
using GridwiseTsmm =
GridwiseTsmmDl_km_kn_mn<BlockSize,
ADataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
GemmSpec,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
MPerThread,
NPerThread,
KPerThread,
ABlockTransferThreadSliceLengths_KBatch_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_KBatch_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_KBatch_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_KBatch_K0_M0_M1_K1,
BThreadTransferSrcDstAccessOrder,
BThreadTransferSrcVectorDim,
BThreadTransferSrcScalarPerVector,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
using DefaultBlock2CTileMap = typename GridwiseTsmm::DefaultBlock2CTileMap;
using Argument = typename GridwiseTsmm::Argument;
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
{
const index_t grid_size = GridwiseTsmm::CalculateGridSize(karg.M, karg.N, karg.k_batch);
// const auto b2c_map = DefaultBlock2CTileMap{};
const auto K0 = karg.K0;
const bool has_main_k_block_loop = GridwiseTsmm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop =
GridwiseTsmm::CalculateHasDoubleTailKBlockLoop(K0);
float ave_time = 0;
if(karg.k_batch > 1)
hipGetErrorString(hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
if(karg.k_batch == 1)
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::Set,
true,
true,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
else
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
true,
true,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
if(karg.k_batch == 1)
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::Set,
true,
false,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
else
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
true,
false,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
if(karg.k_batch == 1)
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::Set,
false,
true,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
else
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
false,
true,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
}
else
{
if(karg.k_batch == 1)
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::Set,
false,
false,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
else
{
const auto kernel = kernel_tsmm_dl_v1r3<GridwiseTsmm,
ADataType,
CDataType,
InMemoryDataOperationEnum::AtomicAdd,
false,
false,
DefaultBlock2CTileMap>; // //
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, karg);
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
// //
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102")
{
return GridwiseTsmm::CheckValidity(arg);
}
else
{
return false;
}
}
// //
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
index_t KBatch) // //
{
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
// GridwiseTsmm::CalculateMPadded(M),
// GridwiseTsmm::CalculateNPadded(N),
// GridwiseTsmm::CalculateKPadded(K, KBatch),
GridwiseTsmm::CalculateK0(K, KBatch),
KBatch}; // //
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ck::index_t KBatch = 1) override // //
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC,
// GridwiseTsmm::CalculateMPadded(M),
// GridwiseTsmm::CalculateNPadded(N),
// GridwiseTsmm::CalculateKPadded(K, KBatch),
GridwiseTsmm::CalculateK0(K, KBatch),
KBatch); // //
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "deviceTsmmDl"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< MPerThread << ", "
<< NPerThread << ", "
<< KPerThread
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP
#define CK_THREADWISE_GEMM_DLOPS_V3_HPP
#include "ck/utility/common_header.hpp"
namespace ck {
// C[M, N] += transpose(A[M, M]) * B[M, N]
// Element of matrix can be vectorized data
template <typename FloatA,
typename FloatB,
typename FloatC,
typename AThreadDesc_K0_M_K1,
typename BThreadDesc_K0_N_K1,
typename CThreadDesc_M_N,
typename enable_if<AThreadDesc_K0_M_K1::IsKnownAtCompileTime() &&
BThreadDesc_K0_N_K1::IsKnownAtCompileTime() &&
CThreadDesc_M_N::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseGemmDlops_km_kn_mn_v3
{
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{
static_assert(AThreadDesc_K0_M_K1::IsKnownAtCompileTime() &&
BThreadDesc_K0_N_K1::IsKnownAtCompileTime() &&
CThreadDesc_M_N::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert(
is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto K0 = AThreadDesc_K0_M_K1{}.GetLength(I0);
constexpr auto M = AThreadDesc_K0_M_K1{}.GetLength(I1);
constexpr auto K1 = AThreadDesc_K0_M_K1{}.GetLength(I2);
constexpr auto N = BThreadDesc_K0_N_K1{}.GetLength(I1);
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, M, 1>{}([&](auto m) {
static_for<0, N, 1>{}([&](auto n) {
static_for<0, K0, 1>{}([&](auto k0) {
static_for<0, K1, 1>{}([&](auto k1) {
constexpr index_t a_offset = AThreadDesc_K0_M_K1{}.CalculateOffset(
a_origin_idx + make_tuple(k0, m, k1));
constexpr index_t b_offset = BThreadDesc_K0_N_K1{}.CalculateOffset(
b_origin_idx + make_tuple(k0, n, k1));
constexpr index_t c_offset = CThreadDesc_M_N{}.CalculateOffset(
c_origin_idx + make_tuple(0, m, 0, n));
inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset>{}],
c_buf(Number<c_offset>{}));
});
});
});
});
} // namespace ck
};
} // namespace ck
#endif
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
set(GEMV_SPLITK_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND GEMV_SPLITK_INSTANCES device_gemv_splitk_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMV_SPLITK_INSTANCES device_gemv_splitk_f16_f16_f16_mk_nk_mn_instance.cpp)
endif()
add_instance_library(device_gemv_splitk_instance ${GEMV_SPLITK_INSTANCES})
set(target 1)
endif()
endforeach()
\ No newline at end of file
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
set(TALL_AND_SKINNY_GEMM_SPLITK_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND TALL_AND_SKINNY_GEMM_SPLITK_INSTANCES device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND TALL_AND_SKINNY_GEMM_SPLITK_INSTANCES device_tall_and_skinny_gemm_splitk_f16_f16_f16_mk_nk_mn_instance.cpp)
endif()
add_instance_library(device_tall_and_skinny_gemm_splitk_instance ${TALL_AND_SKINNY_GEMM_SPLITK_INSTANCES})
set(target 1)
endif()
endforeach()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment