Unverified Commit 75e622f0 authored by Harisankar Sadasivan's avatar Harisankar Sadasivan Committed by GitHub
Browse files

Universal streamk with atomics (#1360)

* universal streamk with atomics with ckprofiler support. grid_size and streamk strategy are tunable. grid_size of -1 leads to #WGs = maximum occupancy X num_CUs. implementation supports many different streamk policies: 1-tile, 2-tile, 3-tile and 4-tile. streamk strategy of -1 leads to default streamk policy (4-tile). 

* Update README.md

* fixing clang-format issues

* removed conflicts in struct members between streamk and universal streamk

* corrected arg parsing for streamk and universal streamk

* added stream-k policies for 3 tile and 4 tile

* fixed argument type issue with parsing cmd args

* changes suggested in PR review are made- removing comments and correcting copyright

* file permissions updated

* added default value support for grid_size and streamk-policy selection set to -1

* print messages for arguments

* print messages for arguments

* print messages for arguments1
parent eaa870a1
File mode changed from 100644 to 100755
......@@ -22,6 +22,8 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
add_example_executable(example_gemm_xdl_fp16_v2 gemm_xdl_fp16_v2.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2)
add_example_executable(example_gemm_xdl_fp16_streamk_v3 gemm_xdl_fp16_streamk_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_streamk_v3)
add_example_executable(example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3)
add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp)
......
......@@ -7,3 +7,21 @@
#arg3: run kernel # of times (>1)
./bin/example_gemm_xdl 0 1 5
```
# Instructions for ```example_gemm_xdl_fp16_streamk_v3```
## Run ```example_gemm_xdl_fp16_streamk_v3```
```bash
arg1: verification (0=no, 1=yes)
arg2: initialization (0=no init, 1=integer value, 2=decimal value)
arg3: time kernel (0=no, 1=yes)
arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC
arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK)
arg11: Grid_size(-1 for max occupancy)
bin/example_gemm_xdl_fp16_streamk_v3 1 2 1 3840 4096 4096 4096 4096 4096 1 -1
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1}
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
problem {M:3840, N:4096, K:4096, SA:4096, SB:4096, SC:4096, MP:4032, NP:4096, KRead:4096, KP:4096, AK0:512, BK0:2048, MBlock: 18, NBlock: 16, Stream-K Selection:1, Grid size:-1}
Perf: 0.292022 ms, 441.23 TFlops, 330.348 GB/s, DeviceGemmXdlUniversal<MNPadding, RRR> BlkSize: 256, BlkTile: 224x256x64, WaveTile: 16x16, WaveMap: 7x8, VmemReadVec: 8x8, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3, BlkGemmPipelinePrefetchStages: 2
```
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -45,6 +45,19 @@ struct ProblemSizeStreamK final
ck::index_t NumSKBlocks = -1;
};
struct ProblemSizeStreamK_universal final
{
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
ck::index_t Grid_size = -1; // defaults to max occupancy
ck::index_t Streamk_sel = 1; // defaults to 1-tile SK
};
struct ProblemSizeSplitK final
{
......@@ -123,6 +136,57 @@ bool parse_cmd_args<ProblemSize>(int argc,
return true;
}
template <>
bool parse_cmd_args<ProblemSizeStreamK_universal>(int argc,
char* argv[],
ProblemSizeStreamK_universal& problem_size,
ExecutionConfig& config)
{
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else if(argc >= 10)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.M = std::stoi(argv[4]);
problem_size.N = std::stoi(argv[5]);
problem_size.K = std::stoi(argv[6]);
problem_size.StrideA = std::stoi(argv[7]);
problem_size.StrideB = std::stoi(argv[8]);
problem_size.StrideC = std::stoi(argv[9]);
if(argc >= 11)
{
problem_size.Streamk_sel = std::stoi(argv[10]);
problem_size.Grid_size = std::stoi(argv[11]);
}
}
else
{
std::cerr
<< "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
<< "arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK)"
<< "\narg11: Grid_size(-1 for max occupancy)" << std::endl;
return false;
}
return true;
}
template <>
bool parse_cmd_args<ProblemSizeStreamK>(int argc,
char* argv[],
......@@ -165,7 +229,8 @@ bool parse_cmd_args<ProblemSizeStreamK>(int argc,
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
<< "arg10: NumSKBlocks(optional)" << std::endl;
<< "arg10: stream-k select (0: all DP, 1: 1-tile SK, 2: 2-tile SK)"
<< "\narg11: Grid_size(-1 for max occupancy)" << std::endl;
return false;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = ck::half_t;
using CDataType = ck::half_t;
using ALayout = Row;
using BLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
// clang-format off
using DeviceGemmV2_Streamk_Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
256,
224, 256,
64, 8, 2,
16, 16,
7, 8,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 8, 2, 0,
1, 2, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemm_example_streamk_v2.inc"
int main(int argc, char* argv[]) { return !run_gemm_universal_streamk_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 1e-1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 1.5e-1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 16.1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 8192.1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
using namespace ck::literals;
auto M = problem_size.M;
auto N = problem_size.N;
auto K = problem_size.K;
auto StrideA = problem_size.StrideA;
auto StrideB = problem_size.StrideB;
auto StrideC = problem_size.StrideC;
auto Grid_size = problem_size.Grid_size;
auto Streamk_sel = problem_size.Streamk_sel;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<std::size_t>(col);
}
else
{
return static_cast<std::size_t>(row);
}
}
else
return static_cast<std::size_t>(stride);
};
auto f_get_default_streamk_policy = [](ck::index_t streamk_sel) {
if(streamk_sel == -1)
{
return static_cast<std::size_t>(4);
}
else
return static_cast<std::size_t>(streamk_sel);
};
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
Streamk_sel = f_get_default_streamk_policy(Streamk_sel);
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
break;
case 3:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
#ifdef BUILD_INT4_EXAMPLE
DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) *
c_m_n_device_result.mDesc.GetElementSpaceSize());
const Tensor<KernelADataType> a_m_k_converted(a_m_k);
const Tensor<KernelBDataType> b_k_n_converted(b_k_n);
a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
#else
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
#endif
DeviceMem workspace;
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmV2_Streamk_Instance{};
auto invoker = gemm.MakeInvoker();
float ave_time = 0;
auto argument = gemm.MakeArgument(
#ifdef BUILD_INT4_EXAMPLE
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#else
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
#endif
M,
N,
K,
StrideA,
StrideB,
StrideC,
Streamk_sel,
Grid_size,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
bool pass = true;
if(config.do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1});
#ifdef BUILD_INT4_EXAMPLE
Tensor<CDataType> c_m_n_device_result_converted(c_m_n_host_result.mDesc);
c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data());
c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
#else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass &= ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result,
"Error: Incorrect results!",
get_rtol<CDataType>(),
get_atol<CDataType>());
#endif
}
if(config.time_kernel)
{
ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
return pass;
}
bool run_gemm_universal_streamk_example(int argc, char* argv[])
{
ProblemSizeStreamK_universal problem_size;
ExecutionConfig config;
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemm_Streamk_V2 : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t Streamk_sel,
ck::index_t Grid_size,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -1404,4 +1404,326 @@ struct BlockToCTileMap_GemmStreamK
}
};
template <uint32_t MPerBlock_,
uint32_t NPerBlock_,
uint32_t KPerBlock_,
StreamKReductionStrategy ReductionStrategy_ = StreamKReductionStrategy::Atomic,
uint32_t TileSwizzleSubM_ = 8,
index_t GroupNum = 8,
index_t M01_ = 4>
struct BlockToCTileMap_GemmStreamK_v2
{
static constexpr uint32_t min_k_iters_per_sk_block = 2;
static constexpr uint32_t MPerBlock = MPerBlock_;
static constexpr uint32_t NPerBlock = NPerBlock_;
static constexpr uint32_t KPerBlock = KPerBlock_;
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_;
static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_;
//--------------------------------------
// pass to device
mutable uint32_t sk_num_blocks;
uint32_t sk_num_big_blocks;
uint32_t dp_start_block_idx;
uint32_t reduction_start_block_idx;
uint32_t k_iters_per_big_block;
MDiv2 n_tiles;
MDiv k_iters_per_tile;
MDiv equiv_tiles_big; // for reduction
MDiv equiv_tiles_little; // for reduction
// prefer construct on host
__host__ __device__ BlockToCTileMap_GemmStreamK_v2(
uint32_t m, uint32_t n, uint32_t k, uint32_t grid_size = 1, uint32_t streamk_sel = 1)
{
// total output tiles
uint32_t num_tiles =
math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock);
k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock));
uint32_t dp_tiles, dp_num_blocks, sk_total_iters;
// default to regular DP GEMM if sk blocks == 0
if(streamk_sel == 0)
{
sk_num_blocks = 0;
dp_tiles = num_tiles;
sk_num_big_blocks = 0;
k_iters_per_big_block = 0;
dp_num_blocks = num_tiles; // all tile to be dp block
dp_start_block_idx = 0;
sk_total_iters = 0; // clear this tiles
}
// 2-tile sk + DP GEMM
else
{
// check if there's enough work for DP+ stream-k
bool bigEnough = num_tiles > grid_size;
// select between stream-k strategies
uint32_t sk_tiles = 0;
if(streamk_sel == 1) // 1 tile stream-k
{
sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles;
}
else if(streamk_sel == 2) // 2-tile stream-k
{
sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles;
}
else if(streamk_sel == 3) // 3-tile stream-k
{
sk_tiles = (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size)
: num_tiles;
}
else if(streamk_sel == 4) // 4-tile stream-k
{
sk_tiles = (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size)
: num_tiles;
}
sk_num_blocks = sk_tiles;
// remaining tiles are DP tiles
dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0;
sk_total_iters = k_iters_per_tile.get() * sk_tiles;
// k_iters_per_sk_block is the floor of avg each ck block loop over tiles.
// we need to decide how many iters for each sk block
// let m = k_iters_per_sk_block
// some of the sk block (little) will cover m iters, some (big) will cover m+1
// we have
// 1) l + b = sk_blocks
// 2) l * m + b * (m + 1) = sk_total_iters
// => (l + b) * m + b = sk_total_iters
// => sk_blocks * m + b = sk_total_iters
// => b = sk_total_iters - m * sk_blocks
// NOTE: big could be zero
uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
k_iters_per_big_block = k_iters_per_sk_block + 1;
dp_num_blocks = dp_tiles;
dp_start_block_idx = sk_num_blocks;
}
n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
// using multiple blocks for parallel reduction
reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
{
uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get());
uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
equiv_tiles_big = MDiv(upper_big / k_iters_per_tile.get());
equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
}
}
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
return M0 * N0;
}
__host__ __device__ uint32_t get_sk_total_iters() const
{
uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block +
(sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);
return sk_total_iters;
}
__host__ __device__ uint32_t get_sk_tiles() const
{
// tiles for sk
uint32_t sk_total_iters = get_sk_total_iters();
return k_iters_per_tile.div(sk_total_iters);
}
__host__ __device__ index_t get_grid_dims() const
{
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
{
// return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1);
return reduction_start_block_idx + get_sk_tiles();
}
else
return reduction_start_block_idx;
}
__device__ uint32_t get_block_idx() const
{
// TODO: swizzle block index for better locality
return __builtin_amdgcn_readfirstlane(blockIdx.x);
}
__device__ void
get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const
{
if(block_idx < sk_num_big_blocks)
{
iter_start = block_idx * k_iters_per_big_block;
iter_end = iter_start + k_iters_per_big_block;
}
else if(block_idx < sk_num_blocks)
{
iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
(block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
iter_end = iter_start + (k_iters_per_big_block - 1);
}
else if(block_idx >= dp_start_block_idx)
{
uint32_t sk_total_iters = get_sk_total_iters();
uint32_t dp_iters_per_block = k_iters_per_tile.get();
iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
iter_end = iter_start + dp_iters_per_block;
}
}
__device__ uint32_t get_current_iter_length(uint32_t iter_start,
uint32_t iter_end,
uint32_t total_iter_length) const
{
uint32_t iter_length_mod, iter_length_quo /*unused*/;
k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
uint32_t current_iter_length = math::min(
iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length);
return current_iter_length;
}
__device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); }
__device__ void
get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const
{
k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
}
__device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const
{
uint32_t m_tile_idx, n_tile_idx;
uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock);
n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx);
// // swizzle tile
uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock);
uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m;
const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem))
? tile_swizzle_sub_m
: tile_swizzle_sub_m_rem;
uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m;
m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m;
uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value;
uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;
n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
n_tile_idx_with_adapt);
}
__host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
{
static constexpr uint32_t alignment = 128;
uint32_t acc_buffer_bytes =
MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes;
return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
}
__host__ __device__ uint32_t get_workspace_size_for_semaphore() const
{
return get_sk_tiles() * sizeof(uint32_t);
}
__host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const
{
return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore();
}
__host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_,
const MDiv& equiv_tiles_) const
{
uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1);
uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1;
uint32_t quo_, rem_;
equiv_tiles_.divmod(tile_idx_, quo_, rem_);
return quo_ * max_equiv_tiles_ + rem_;
}
__host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_,
uint32_t iters_per_sk_block_) const
{
return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
1);
}
__host__ __device__ uint32_t get_total_acc_buffers() const
{
uint32_t tiles_cover_big_blocks =
get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
uint32_t tiles_cover_little_blocks =
get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);
uint32_t total_intersec_big =
get_tile_intersections(tiles_cover_big_blocks, equiv_tiles_big);
uint32_t total_intersec_little =
get_tile_intersections(tiles_cover_little_blocks, equiv_tiles_little);
return sk_num_blocks + total_intersec_big + total_intersec_little;
}
__device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const
{
// TODO: from big to little
uint32_t tiles_cover_big_blocks =
get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block);
if(tile_idx_ < tiles_cover_big_blocks)
{
uint32_t touched_sk_blocks =
(tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
k_iters_per_big_block;
uint32_t current_intersec = get_tile_intersections(tile_idx_, equiv_tiles_big);
return touched_sk_blocks + current_intersec;
}
else
{
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
uint32_t tile_idx_little_reverse = get_sk_tiles() - tile_idx_;
uint32_t touched_sk_blocks =
(tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
iters_per_little_sk_block;
uint32_t current_intersec =
get_tile_intersections(tile_idx_little_reverse, equiv_tiles_little);
return get_total_acc_buffers() - (touched_sk_blocks + current_intersec);
}
}
__device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const
{
uint32_t iters_per_big_sk_block = k_iters_per_big_block;
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
if(block_idx_ < sk_num_big_blocks)
{
uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
k_iters_per_tile.get() - 1);
uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_big);
return block_idx_ + current_intersec;
}
else
{
uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
uint32_t touched_tiles = k_iters_per_tile.div(
block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_little);
return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec);
}
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Col,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Col,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Col,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Col,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Col,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Col,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Col,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Col,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Col,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Col,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemm_Streamk_V2<
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceGemm_Streamk_V2<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances(
op_ptrs);
add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
op_ptrs);
}
}
#endif
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
# ONLY XDL_KERNELS
set(GEMM_UNIVERSAL_STREAMK_INSTANCES)
list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp)
add_instance_library(device_gemm_universal_streamk_instance ${GEMM_UNIVERSAL_STREAMK_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances<GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances<GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances<GemmMNPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances<Intrawave,
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances<Intrawave,
GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGemm_Streamk_V2<Row,
Row,
Row,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances<Intrawave,
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment