"sims/mem/rules.mk" did not exist on "56775bc1685edb3c30c488bf1bd95db03ef38c72"
Commit e7cde218 authored by Harisankar Sadasivan's avatar Harisankar Sadasivan
Browse files

changes suggested in PR review are made- removing comments and correcting copyright

parent 57a38a1d
File mode changed from 100644 to 100755
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -117,7 +117,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
auto f_get_default_stride =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(stride == 0)
if(stride == -1)
{
// give a chance if stride is zero, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
......@@ -162,18 +162,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
}
#if 0
printf("B matrix:\n");
for (int in = 0; in < N; in++)
{
for (int ik = 0; ik < K; ik++)
{
printf("%02x ", *(reinterpret_cast<uint8_t*>(&b_k_n(ik,in))));
if(ik%8==7) printf("|");
}
printf("\n");
}
#endif
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -147,10 +147,8 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_)); // HS
hipGetErrorString(hipMemsetAsync(
arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_));
const auto Run = [&](const auto& kernel) {
dim3 grid_dim;
if(arg.Grid_size < 0)
......@@ -193,25 +191,13 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
grid_dim,
// dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg_);
stream_config, run_flush_cache, kernel, grid_dim, dim3(BlockSize), 0, arg_);
}
else
{
ave_time = launch_and_time_kernel(stream_config,
kernel,
// dim3(gdx, gdy, gdz),
grid_dim,
dim3(BlockSize),
0,
arg);
ave_time = launch_and_time_kernel(
stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg);
}
};
......@@ -477,7 +463,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
BElementwiseOperation,
CElementwiseOperation)
{
// return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
return Argument{
p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size}; // HS
}
......
......@@ -1461,31 +1461,27 @@ struct BlockToCTileMap_GemmStreamK_v2
// check if there's enough work for DP+ stream-k
bool bigEnough = num_tiles > grid_size;
// select between 1 tile and 2 tile sk
// select between stream-k strategies
uint32_t sk_tiles = 0;
if(streamk_sel == 1)
if(streamk_sel == 1) // 1 tile stream-k
{
sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles;
}
else if(streamk_sel == 2)
else if(streamk_sel == 2) // 2-tile stream-k
{
sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles;
}
else if(streamk_sel == 3)
else if(streamk_sel == 3) // 3-tile stream-k
{
sk_tiles = (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size)
: num_tiles;
}
else if(streamk_sel == 4)
else if(streamk_sel == 4) // 4-tile stream-k
{
sk_tiles = (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size)
: num_tiles;
}
sk_num_blocks = sk_tiles;
// if(sk_tiles < sk_num_blocks)
// {
// sk_num_blocks = sk_tiles;
// }
// remaining tiles are DP tiles
dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0;
......@@ -1508,7 +1504,6 @@ struct BlockToCTileMap_GemmStreamK_v2
dp_num_blocks = dp_tiles;
dp_start_block_idx = sk_num_blocks;
// dp_start_block_idx = ((sk_num_blocks + grid_size - 1) / grid_size) * grid_size;
}
n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock));
......@@ -1523,7 +1518,8 @@ struct BlockToCTileMap_GemmStreamK_v2
equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get());
}
#if 0
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
printf("streamk_sel=%0d,grid_size=%0d, num_tiles:%d, dp_tiles:%d, sk_tiles:%u, "
"sk_num_blocks:%d,dp_num_blocks:%d,sk_num_big_blocks:%d, "
"sk_total_iters:%d, dp_start_block_idx:%d, "
......@@ -1531,8 +1527,6 @@ struct BlockToCTileMap_GemmStreamK_v2
" workspace(acc float):%u\n",
streamk_sel,
grid_size,
// occupancy,
// get_grid_dims(num_cu, occupancy).x,
num_tiles,
dp_tiles,
get_sk_tiles(),
......@@ -1546,7 +1540,7 @@ struct BlockToCTileMap_GemmStreamK_v2
k_iters_per_big_block,
reduction_start_block_idx,
get_workspace_size(sizeof(float)));
#endif
}
}
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
......@@ -1656,90 +1650,6 @@ struct BlockToCTileMap_GemmStreamK_v2
m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m,
n_tile_idx_with_adapt);
// adding gfx94x optimized
// index_t block_1d_id = tile_idx;
// const index_t N0 = n_tiles_value;
// const index_t M0 = math::integer_divide_ceil(n * m / m, MPerBlock);
// // index_t GroupNum = 8;
// // index_t M01_ = 4;
// if(M0 == 1)
// {
// return make_tuple(0, block_1d_id);
// }
// else if(N0 == 1)
// {
// return make_tuple(block_1d_id, 0);
// }
// // block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
// else
// {
// const auto group_size = math::integer_divide_ceil(M0 * N0, GroupNum);
// const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
// auto group_id_x = block_1d_id % GroupNum;
// auto group_id_y = block_1d_id / GroupNum;
// auto remap_block_1d_id =
// group_id_x <= big_group_num
// ? group_id_x * group_size + group_id_y
// : group_id_x * group_size + big_group_num - group_id_x + group_id_y;
// index_t idx_N0 = remap_block_1d_id % N0;
// index_t idx_M0 = remap_block_1d_id / N0;
// const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
// index_t idx_M00 = idx_M0 / M01_;
// index_t idx_M01 = idx_M0 % M01_;
// index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
// /**
// * idxN0
// *
// * |< mtx N >|
// *
// * NPerBlock NPerBlock NPerBlock NPerBlock
// * N_0 N_1 N_2 N_3
// * - |-----------|-----------|-----------|-----|-----|-
// * ^ | - - 0 |/----> 2 | | | |
// * | | | / | | | | | M_0 MPerBlock
// * | M | /| | | | | |
// * |-0---|---/-|-----|-----|-----------|-----|-----|-
// * | 1 | / | | | blockid | | |
// * idxM0 | | | / | V | 5 | | | M_1 MPerBlock
// * | - V 1 | - 3 | | | |
// * |-----------|-----------|-----------|-----|-----|-
// * mtx M | | | | | |
// * | | | | | | M_2 MPerBlock
// * | | | | | |
// * |-----------|-----------|-----------|-----|-----|-
// * | | | | | |
// * | | | | | | M_3 MPerBlock
// * | | | | | |
// * |-----------|-----------|-----------|-----|-----|-
// * V | | | | | |
// * - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
// * | | | | | |
// * |-----------|-----------|-----------|-----|-----|-
// * Example:
// * assume:
// * M0 = 5
// * N0 = 4
// * block_1d_id = 5
// * M01 = 2
// *
// * idx_N0 = 1
// * idx_M0 = 1
// * M01_adapt = 2
// * idx_M00 = 0
// * idx_M01 = 1
// * idx_N0_M01_local = 5
// * output {1, 2}
// */
// return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
// idx_N0_M01_local / M01_adapt);
//}
}
__host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const
......
......@@ -32,22 +32,13 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
// karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
// karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
p_shared,
karg);
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
......@@ -62,7 +53,6 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
......@@ -71,17 +61,8 @@ __global__ void
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
// karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
// karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
p_shared_0,
p_shared_1,
karg);
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
......@@ -155,15 +136,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// __host__ static auto CalculateGridSize(index_t M, index_t N) //, index_t KBatch)
// {
// // return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
// // return ((Block2CTileMap::CalculateGridSize(M, N)) * KBatch);
// // return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
// return Block2CTileMap::CalculateGridSize(M, N);
// }
__host__ static auto CalculateMPadded(index_t M)
{
return math::integer_least_multiple(M, MPerBlock);
......@@ -995,10 +967,8 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
else
{
// constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
// auto K_t = KReadVec;
// auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
if(karg.K <= 0) // HS
if(karg.K <= 0)
{
return false;
}
......@@ -1103,10 +1073,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
// if(karg.KBatch > 1)
// {
// return false;
// }
}
// check gridwise gemm pipeline
......@@ -1152,16 +1118,12 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
// return block_id to C matrix tile idx (m0, n0) mapping
// if arch = gfx942
// using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
using Block2CTileMap_streamk = BlockToCTileMap_GemmStreamK_v2<MPerBlock,
NPerBlock,
KPerBlock,
StreamKReductionStrategy::Atomic,
8,
4>; // HS
4>;
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
......@@ -1177,43 +1139,39 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
// Provide a value for TileSwizzleSubM_
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
problem.N,
AK0Number * problem.KPadded,
problem.Grid_size,
problem.Streamk_sel); // HS
uint32_t iter_start, iter_end; // HS
bool is_sk_block, is_dp_block; //, is_padding_block; //, is_reduction_block; // HS
index_t num_k_block_main_loop; // HS
problem.Streamk_sel);
uint32_t iter_start, iter_end;
bool is_sk_block, is_dp_block;
index_t num_k_block_main_loop;
for(auto block_idx = get_block_1d_id();
block_idx < block_2_ctile_map_streamk.get_grid_dims();
block_idx += gridDim.x)
{
// for(unsigned int kbatch_id = 0; kbatch_id < static_cast<unsigned
// int>(problem.KBatch);
// kbatch_id++)
is_sk_block =
static_cast<uint32_t>(block_idx) < block_2_ctile_map_streamk.sk_num_blocks;
is_dp_block =
static_cast<uint32_t>(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx &&
static_cast<uint32_t>(block_idx) <
block_2_ctile_map_streamk.reduction_start_block_idx; // HS
block_2_ctile_map_streamk.reduction_start_block_idx;
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); // HS
num_k_block_main_loop = iter_end - iter_start; // HS
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
num_k_block_main_loop = iter_end - iter_start;
while(true)
{
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
block_2_ctile_map_streamk.get_current_iter_length(
iter_start, iter_end, num_k_block_main_loop)); // HS
uint32_t tile_idx, iter_offset; // HS
iter_start, iter_end, num_k_block_main_loop));
uint32_t tile_idx, iter_offset;
block_2_ctile_map_streamk.get_tile_idx_with_offset(
iter_end - 1, tile_idx, iter_offset); // HS
iter_offset =
__builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); // HS
iter_end - 1, tile_idx, iter_offset);
iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M,
problem.MPadded,
......@@ -1237,17 +1195,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid /*+ splitk_batch_offset.a_k_split_offset*/,
a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid /*+ splitk_batch_offset.b_k_split_offset*/,
b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
// const auto block_work_idx =
// block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_idx));
auto block_work_idx =
block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N); // HS
block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
......@@ -1260,7 +1214,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
__builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
const index_t k0_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(iter_offset * AK0Number); // HS
__builtin_amdgcn_readfirstlane(iter_offset * AK0Number);
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
......@@ -1298,7 +1252,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0), // HS
make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
......@@ -1361,7 +1315,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock); // HS:AK0*KPadded/KPerBlock
KPerBlock); :AK0*KPadded/KPerBlock
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
......@@ -1607,7 +1561,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
if(iter_end <= iter_start)
break;
// make sure next loop LDS is ready for use
block_sync_lds(); // HS
block_sync_lds();
}
}
}
......@@ -1627,13 +1581,11 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M,
problem.N,
AK0Number * problem.KPadded,
problem.Grid_size); // HS
uint32_t iter_start, iter_end; // HS
bool is_sk_block, is_dp_block; //, is_padding_block; //, is_reduction_block; // HS
index_t num_k_block_main_loop; // HS
Block2CTileMap_streamk block_2_ctile_map_streamk(
problem.M, problem.N, AK0Number * problem.KPadded, problem.Grid_size);
uint32_t iter_start, iter_end;
bool is_sk_block, is_dp_block; //, is_padding_block; //, is_reduction_block;
index_t num_k_block_main_loop;
for(auto block_idx = get_block_1d_id();
block_idx < block_2_ctile_map_streamk.get_grid_dims();
......@@ -1644,21 +1596,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
is_dp_block =
static_cast<uint32_t>(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx &&
static_cast<uint32_t>(block_idx) <
block_2_ctile_map_streamk.reduction_start_block_idx; // HS
block_2_ctile_map_streamk.reduction_start_block_idx;
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); // HS
num_k_block_main_loop = iter_end - iter_start; // HS
block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end);
num_k_block_main_loop = iter_end - iter_start;
{
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
block_2_ctile_map_streamk.get_current_iter_length(
iter_start, iter_end, num_k_block_main_loop)); // HS
uint32_t tile_idx, iter_offset; // HS
iter_start, iter_end, num_k_block_main_loop));
uint32_t tile_idx, iter_offset;
block_2_ctile_map_streamk.get_tile_idx_with_offset(
iter_end - 1, tile_idx, iter_offset); // HS
iter_offset =
__builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); // HS
iter_end - 1, tile_idx, iter_offset);
iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(problem.M,
problem.MPadded,
......@@ -1683,16 +1634,12 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid /*+ splitk_batch_offset.a_k_split_offset*/,
a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid /*+ splitk_batch_offset.b_k_split_offset*/,
b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
// const auto block_work_idx =
// block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_idx));
auto block_work_idx =
block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N); // HS
block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N);
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
......@@ -1704,7 +1651,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
const index_t k0_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(iter_offset * AK0Number); // HS
__builtin_amdgcn_readfirstlane(iter_offset * AK0Number);
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
......@@ -1742,7 +1689,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0), // HS
make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
......@@ -1773,7 +1720,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0), // HS
make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -237,306 +237,6 @@ void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpaddin
PassThrough,
PassThrough>>>& instances);
#endif
// #if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// #endif
// #ifdef CK_ENABLE_FP16
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
// std::vector<std::unique_ptr<
// DeviceGemm_Streamk_V2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough,
// PassThrough>>>& instances);
// #endif
template <typename ADataType,
typename BDataType,
typename CDataType,
......@@ -626,158 +326,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemm_S
}
}
#endif
// #if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
// if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
// is_same_v<CDataType, half_t>)
// {
// if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
// is_same_v<CLayout, Row>)
// {
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
// op_ptrs);
// }
// else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
// is_same_v<CLayout, Row>)
// {
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances(op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
// op_ptrs);
// }
// }
// else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
// is_same_v<CDataType, half_t>)
// {
// if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
// is_same_v<CLayout, Row>)
// {
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
// op_ptrs);
// }
// else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
// is_same_v<CLayout, Row>)
// {
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances(op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instances(op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
// op_ptrs);
// }
// }
// #endif
// #ifdef CK_ENABLE_FP16
// if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
// is_same_v<CDataType, bhalf_t>)
// {
// if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
// is_same_v<CLayout, Row>)
// {
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances(
// op_ptrs);
// }
// else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
// is_same_v<CLayout, Row>)
// {
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(
// op_ptrs);
// add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
// op_ptrs);
// }
// }
// #endif
return op_ptrs;
}
};
......
......@@ -21,70 +21,6 @@ list(APPEND GEMM_UNIVERSAL_STREAMK_INSTANCES
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
# device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
# device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
# device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
)
device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp)
add_instance_library(device_gemm_universal_streamk_instance ${GEMM_UNIVERSAL_STREAMK_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment