"vscode:/vscode.git/clone" did not exist on "952f02ff0f09251c0882542e574b793bb634616e"
Commit 1b462ab5 authored by Adam Osewski

Clean up debug code and reuse the new neighbour-count function.

parent e954c206
@@ -63,20 +63,19 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdl_splitk_v2(
const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
void* const __restrict__ p_workspace,
const index_t tile_count,
const index_t k_batch,
[[maybe_unused]] const AElementwiseOperation a_element_op,
[[maybe_unused]] const BElementwiseOperation b_element_op,
[[maybe_unused]] const CDEElementwiseOperation cde_element_op)
kernel_grouped_gemm_xdl_splitk_v2(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
void* const __restrict__ p_workspace,
const index_t tile_count,
const index_t k_batch,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
[[maybe_unused]] __shared__ uint8_t p_shared[shared_size];
__shared__ uint8_t p_shared[shared_size];
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
@@ -105,12 +104,6 @@ __global__ void
index_t gemm_tile_id_end = grid_size_grp;
auto gridwise_gemm = GridwiseGemm();
[[maybe_unused]] auto is_thread_local_1d_id_idx = [](auto... Ids) -> bool
{
const auto tid = get_thread_local_1d_id();
return ((tid == Ids) || ... );
};
do
{
// Find corresponding GEMM group for our tile
@@ -129,12 +122,12 @@ __global__ void
gemm_tile_id_end = offset + grid_size_grp;
}
[[maybe_unused]] const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
[[maybe_unused]] const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
[[maybe_unused]] const auto K = gemm_desc_ptr[group_id].K;
[[maybe_unused]] const auto StrideA = gemm_desc_ptr[group_id].StrideA;
[[maybe_unused]] const auto StrideB = gemm_desc_ptr[group_id].StrideB;
const auto K = gemm_desc_ptr[group_id].K;
const auto StrideA = gemm_desc_ptr[group_id].StrideA;
const auto StrideB = gemm_desc_ptr[group_id].StrideB;
auto& results_buffer = gridwise_gemm.GetCThreadBuffer();
b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);
@@ -143,32 +136,21 @@ __global__ void
// Iterate over K dimension for this [M,N] tile
// still in the same GEMM && the same [M,N] tile
// TODO: change desc so that a few K-tiles will be done in a single GEMM.
// {
// if (is_thread_local_1d_id_idx(0))
// {
// printf("bid: %d, group: %d, accumulate tile id (M,N,K): [%d, %d, %d] \n",
// static_cast<index_t>(blockIdx.x),
// group_id,
// b2c_tile_map.GetTileMIdx(),
// b2c_tile_map.GetTileNIdx(),
// b2c_tile_map.GetTileKIdx());
// }
// }
do
{
// just accumulate results in registers!
// gridwise_gemm.template RunGEMM<HasMainKBlockLoop>(p_a_grid,
// p_b_grid,
// static_cast<void*>(p_shared),
// a_element_op,
// b_element_op,
// M,
// N,
// K,
// StrideA,
// StrideB,
// k_batch,
// b2c_tile_map);
gridwise_gemm.template RunGEMM<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
static_cast<void*>(p_shared),
a_element_op,
b_element_op,
M,
N,
K,
StrideA,
StrideB,
k_batch,
b2c_tile_map);
} while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
@@ -184,122 +166,47 @@ __global__ void
work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);
// {
// // const uint32_t flag_v2 = __builtin_amdgcn_readfirstlane(
// // work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
if (is_thread_local_1d_id_idx(0))
{
printf("bid: %d, group: %d, FlagFInished \n",
static_cast<index_t>(blockIdx.x),
group_id);
// printf("bid: %d, group: %d, FlagFInished flag_v[%u]: %u\n",
// static_cast<index_t>(blockIdx.x),
// group_id)
// work_scheduler.GetWorkgroupFlagIdx(k_batch, output_tile_idx, output_tile_idx_offset),
// flag_v2);
}
// }
// The workgroup which processed the first K tile accumulates results and stores to GMEM
if(b2c_tile_map.IsFirstKSplitBlock())
{
if (is_thread_local_1d_id_idx(0))
{
printf("bid: %d, group: %d, Will wait for neighbours... \n",
static_cast<index_t>(blockIdx.x),
group_id);
}
// Wait until all other blocks for this [M,N] tile store their results.
work_scheduler.WaitForNeighbours(k_batch, output_tile_idx, output_tile_idx_offset);
// Accumulate partial results. We can have different # of workgroups to reduce, thus we
// read actual flag value.
[[maybe_unused]] const uint32_t flag_v = __builtin_amdgcn_readfirstlane(
work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
// {
// if (is_thread_local_1d_id_idx(0))
// {
// printf("bid: %d, group: %d, WaitForNeighbours flag_v[%u]: %u\n",
// static_cast<index_t>(blockIdx.x),
// group_id,
// work_scheduler.GetWorkgroupFlagIdx(k_batch, output_tile_idx, output_tile_idx_offset),
// static_cast<index_t>(blockIdx.x));
// // flag_v);
// }
// }
// using CThreadBuffer = remove_cvref_t<decltype(results_buffer)>;
// constexpr index_t n_v = CThreadBuffer::num_of_v_.value;
// constexpr index_t s_per_v = CThreadBuffer::s_per_v.value;
// static_for<0, n_v, 1>{}([&](auto v) {
// static_for<0, s_per_v, 1>{}([&](auto s) {
// // printf("bid: %d; tid: %d; [Partial results] c_thread_buff[%d, %d]:
// // %f\n",
// // static_cast<index_t>(blockIdx.x),
// // static_cast<index_t>(threadIdx.x),
// // v.value,
// // s.value,
// // static_cast<float>(results_buffer[v * Number<s_per_v>{} + s])
// // );
// results_buffer(v * Number<s_per_v>{} + s) = threadIdx.x * v + s;
// });
// });
index_t neighbour_count = work_scheduler.WaitForNeighbours(
k_batch, b2c_tile_map.GetTileKIdx(), output_tile_idx, output_tile_idx_offset);
// Accumulate only when there are at least two workgroups processing split-K data-tiles
// across the same MN-output tile.
// if(flag_v > 1)
// gridwise_gemm.AccumulatePartials(p_workspace, flag_v);
if (is_thread_local_1d_id_idx(0))
{
printf("bid: %d, group: %d, Reset flag \n",
static_cast<index_t>(blockIdx.x),
group_id);
}
if(neighbour_count > 1)
gridwise_gemm.AccumulatePartials(p_workspace, neighbour_count);
// Signal waiting blocks that they can start using their workspace.
work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
// const auto p_e_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_e_grid);
// const auto stride_e = gemm_desc_ptr[group_id].StrideE;
// const auto stride_ds = gemm_desc_ptr[group_id].StrideDs;
// constexpr auto NumDTensor = DsDataType::Size();
// using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
// DsGridPointer p_ds_grid;
// static_for<0, NumDTensor, 1>{}([&](auto i) {
// using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
// });
// gridwise_gemm.template RunWrite(p_ds_grid,
// p_e_grid,
// static_cast<void*>(p_shared),
// M,
// N,
// stride_ds,
// stride_e,
// cde_element_op,
// b2c_tile_map);
const auto p_e_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_e_grid);
const auto stride_e = gemm_desc_ptr[group_id].StrideE;
const auto stride_ds = gemm_desc_ptr[group_id].StrideDs;
constexpr auto NumDTensor = DsDataType::Size();
using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
DsGridPointer p_ds_grid;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
});
gridwise_gemm.template RunWrite(p_ds_grid,
p_e_grid,
static_cast<void*>(p_shared),
M,
N,
stride_ds,
stride_e,
cde_element_op,
b2c_tile_map);
}
else if(work_scheduler.HasTile())
{
{
// const uint32_t flag_v2 = __builtin_amdgcn_readfirstlane(
const uint32_t flag_v2 = work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset);
if (is_thread_local_1d_id_idx(0))
{
printf("bid: %d, group: %d, Waiting for Reduction flag_v[%u]: %u\n",
static_cast<index_t>(blockIdx.x),
group_id,
work_scheduler.GetWorkgroupFlagIdx(k_batch, output_tile_idx, output_tile_idx_offset),
// static_cast<index_t>(blockIdx.x));
flag_v2);
}
}
work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
}
} while(work_scheduler.HasTile());
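This hunk is the core of the commit: each workgroup accumulates its K-slices purely in registers, flags its [M,N] tile as finished, and only the workgroup owning the first K-split of that tile reduces the neighbours' partials and writes the output, with the reduction size now taken from the value returned by `WaitForNeighbours` instead of a separately read flag. Below is a hedged host-side analogue of that finished-flag / wait / reduce / reset handshake, written with `std::thread` and `std::atomic` purely for illustration; the tile size, partial values and flag layout are made up and this is not the device code.

```cpp
// Host-side analogue of the split-K handshake above (illustration only).
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main()
{
    constexpr int k_batch  = 4; // split-K workgroups sharing one [M,N] output tile (hypothetical)
    constexpr int tile_len = 8; // flattened output-tile size (hypothetical)

    std::vector<float> workspace(k_batch * tile_len); // one partial-result slot per workgroup
    std::vector<float> output(tile_len, 0.f);
    std::atomic<unsigned> finished{0};                // plays the role of the per-tile flag

    auto worker = [&](int k_idx) {
        // "RunGEMM": accumulate this K-slice; here just a dummy partial of value k_idx + 1.
        for(int i = 0; i < tile_len; ++i)
            workspace[k_idx * tile_len + i] = static_cast<float>(k_idx + 1);

        finished.fetch_add(1, std::memory_order_release); // FlagFinished

        if(k_idx == 0) // IsFirstKSplitBlock: the reducing workgroup
        {
            // WaitForNeighbours: spin until every split-K workgroup has stored its partial.
            while(finished.load(std::memory_order_acquire) < static_cast<unsigned>(k_batch)) {}

            // AccumulatePartials: sum the stored partials into the output tile.
            // (Here the reducer simply re-reads all slots; in the kernel its own
            // partial stays in registers and only the neighbours' slots are read.)
            for(int i = 0; i < tile_len; ++i)
                for(int k = 0; k < k_batch; ++k)
                    output[i] += workspace[k * tile_len + i];

            finished.store(0, std::memory_order_release); // Reset: release the workspace
            // RunWrite would follow here in the kernel (D/E tensors to global memory).
        }
        else
        {
            // WaitForReduction: do not reuse the workspace slot before the reducer is done.
            while(finished.load(std::memory_order_acquire) != 0) {}
        }
    };

    std::vector<std::thread> pool;
    for(int k = 0; k < k_batch; ++k)
        pool.emplace_back(worker, k);
    for(auto& t : pool)
        t.join();

    std::printf("output[0] = %.1f (expected 10.0 = 1 + 2 + 3 + 4)\n", output[0]);
    return 0;
}
```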
@@ -839,8 +746,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
void* p_flags = reinterpret_cast<char*>(dev_gemm_workspace) +
Block2ETileMapKSplit::GetAccWorkspaceSize(
sizeof(typename GridwiseGemm::AccType), grid_size);
// std::size_t flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
std::size_t flag_count = arg.tile_count_ / arg.K_BATCH;
std::size_t flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
if(stream_config.log_level_ > 0)
{
@@ -1077,13 +983,11 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
int grid_size = std::min(arg.tile_count_, occ_grid_size);
int tiles_per_block = (arg.tile_count_ + grid_size - 1) / grid_size;
if(arg.tile_count_ > occ_grid_size &&
grid_size * tiles_per_block > arg.tile_count_)
if(arg.tile_count_ > occ_grid_size && grid_size * tiles_per_block > arg.tile_count_)
{
grid_size = (arg.tile_count_ + tiles_per_block - 1) / tiles_per_block;
}
// int flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
int flag_count = arg.tile_count_ / arg.K_BATCH;
int flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
// This would be the maximum needed workspace size. Since the actual grid size, which determines
// the number of workspace bytes needed, may be less due to the number of available CUs in
......
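For reference, here is a worked run of the occupancy-capped sizing restored above, with `flag_count` again derived from the launched grid rather than from `tile_count_ / K_BATCH`. All concrete numbers are hypothetical; only the formulas come from the hunk.

```cpp
// Worked example of the grid / flag sizing arithmetic (hypothetical inputs).
#include <algorithm>
#include <cstdio>

int main()
{
    const int tile_count    = 612; // total work tiles across all GEMM groups (hypothetical)
    const int occ_grid_size = 240; // occupancy-limited grid size (hypothetical)
    const int k_batch       = 4;   // split-K factor (hypothetical)

    int grid_size       = std::min(tile_count, occ_grid_size);            // 240
    int tiles_per_block = (tile_count + grid_size - 1) / grid_size;       // ceil(612 / 240) = 3
    if(tile_count > occ_grid_size && grid_size * tiles_per_block > tile_count)
        grid_size = (tile_count + tiles_per_block - 1) / tiles_per_block; // ceil(612 / 3) = 204

    // Roughly one finished-flag per [M,N] output tile covered by the launch,
    // since each output tile is split into k_batch K-tiles.
    const int flag_count = (grid_size * tiles_per_block + k_batch - 1) / k_batch; // ceil(612 / 4) = 153

    std::printf("grid_size=%d tiles_per_block=%d flag_count=%d\n",
                grid_size, tiles_per_block, flag_count);
    return 0;
}
```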
@@ -106,13 +106,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
template <index_t... Ids>
__device__ static bool is_thread_local_1d_id_idx()
{
const auto tid = get_thread_local_1d_id();
return ((tid == Ids) || ...);
}
public:
using AccType = AccDataType;
@@ -913,32 +906,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
Sequence<6>{},
Sequence<7>{}));
// if (is_thread_local_1d_id_idx<0>())
// {
// // printf("bid: %d; tid: %d; [Store Partials] c_block_desc:[%d, %d, %d, %d, %d, %d, %d, %d]\n",
// // static_cast<index_t>(blockIdx.x),
// // static_cast<index_t>(threadIdx.x),
// // M0.value,
// // N0.value,
// // M1.value,
// // N1.value,
// // M2.value,
// // M3.value,
// // M4.value,
// // N2.value);
// printf("bid: %d; tid: %d; [Store Partials] wrkspace_desc:[%d, %d, %d, %d, %d, %d, %d, %d]\n",
// static_cast<index_t>(blockIdx.x),
// static_cast<index_t>(threadIdx.x),
// workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0),
// workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1),
// workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2).value,
// workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3).value,
// workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4).value,
// workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5).value,
// workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6).value,
// workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7).value);
// }
auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
auto w_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_workspace_grid, workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
@@ -996,33 +963,11 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// if (is_thread_local_1d_id_idx<0, 64, 223>())
// {
// printf("[StorePartials] bid: %d, tid: %d: dst origin idx[%d, %d, %d, %d, %d, %d, %d, %d]\n",
// static_cast<index_t>(blockIdx.x),
// static_cast<index_t>(threadIdx.x),
// (static_cast<index_t>(blockIdx.x)) * MXdlPerWave,
// n_thread_data_on_block_idx[I0],
// m_thread_data_on_block_idx[I1],
// n_thread_data_on_block_idx[I1],
// m_thread_data_on_block_idx[I2],
// m_thread_data_on_block_idx[I3],
// m_thread_data_on_block_idx[I4],
// n_thread_data_on_block_idx[I2]);
// }
c_thread_copy_vgpr_to_gmem.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf,
workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
w_grid_buf);
if (is_thread_local_1d_id_idx<0>())
{
printf("[StorePartials] done. bid: %d, tid: %d\n",
static_cast<index_t>(blockIdx.x),
static_cast<index_t>(threadIdx.x));
}
}
__device__ void AccumulatePartials(void* __restrict__ p_workspace, uint32_t reduce_count)
@@ -1158,7 +1103,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
// We do not need to read this workgroup's partial results since they're
// already in c_thread_buff
for(uint32_t i_t = 1; i_t < reduce_count; ++i_t)
for(uint32_t i_t = 1; i_t <= reduce_count; ++i_t)
{
acc_buf.Clear();
acc_load.Run(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
......
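The last hunk in this file widens the reduction loop from `i_t < reduce_count` to `i_t <= reduce_count`: the reducing workgroup's own partial already sits in `c_thread_buff`, so the loop starts at workspace slot 1 and, with the inclusive bound, visits `reduce_count` neighbouring slots. A minimal illustration of that indexing follows; the values and the flat slot layout are hypothetical and only show which slots the bound touches.

```cpp
// Inclusive-bound indexing into the split-K workspace (hypothetical layout).
#include <cstdio>
#include <vector>

int main()
{
    const unsigned reduce_count = 3;                      // count passed to AccumulatePartials (hypothetical)
    std::vector<float> workspace = {99.f, 1.f, 2.f, 3.f}; // slot 0: the reducer's own slot, not re-read

    float acc = 10.f;                                     // reducer's partial, already "in registers"
    for(unsigned i_t = 1; i_t <= reduce_count; ++i_t)     // visits slots 1, 2 and 3
        acc += workspace[i_t];

    std::printf("acc = %.1f\n", acc); // 10 + 1 + 2 + 3 = 16
    return 0;
}
```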
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -16,96 +16,96 @@ namespace tensor_operation {
namespace device {
namespace instance {
// void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
// Row,
// Empty_Tuple,
// Row,
// F16,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
// void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
// Col,
// Empty_Tuple,
// Row,
// F16,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
// void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
// Row,
// Empty_Tuple,
// Row,
// F16,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
// void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
// Col,
// Empty_Tuple,
// Row,
// F16,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
// void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
// Col,
// Empty_Tuple,
// Row,
// F16,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
// void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
// Row,
// Empty_Tuple,
// Row,
// F16,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
// void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
// Col,
// Empty_Tuple,
// Row,
// F16,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
@@ -120,31 +120,31 @@ void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances
PassThrough,
PassThrough>>>& instances);
// void add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
// Row,
// Empty_Tuple,
// Row,
// F16,
// F8,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
// void add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
// Row,
// Empty_Tuple,
// Row,
// F8,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F8,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F8,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <typename ALayout,
typename BLayout,
@@ -186,48 +186,48 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
// add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
// add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
// add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
// add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
// add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
// op_ptrs);
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
}
// else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
// is_same_v<ELayout, Row>)
// {
// add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
// }
// else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
// is_same_v<ELayout, Row>)
// {
// add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
// }
}
// else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
// is_same_v<EDataType, half_t>)
// {
// if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
// is_same_v<ELayout, Row>)
// {
// add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(op_ptrs);
// }
// }
// else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
// is_same_v<EDataType, half_t>)
// {
// if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
// is_same_v<ELayout, Row>)
// {
// add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
// }
// }
return op_ptrs;
}
};
......
@@ -17,18 +17,18 @@ namespace device {
namespace instance {
// MultiD version
// void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
// Col,
// Empty_Tuple,
// Row,
// F16,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
@@ -93,8 +93,8 @@ struct DeviceOperationInstanceFactory<
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
// add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
// op_ptrs);
add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
op_ptrs);
}
}
return op_ptrs;
......
add_instance_library(device_grouped_gemm_instance
# device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
# device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
# device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
# device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
# device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
# device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
# device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
# device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
# device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
)
add_instance_library(device_grouped_gemm_multiple_d_instance
# device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
)
@@ -39,7 +39,9 @@ bool profile_ggemm_multid_splitk(int do_verification,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1)
int kbatch = 1,
int warmup_iter = 1,
int kernel_iter = 10)
{
bool pass = true;
@@ -250,23 +252,18 @@ bool profile_ggemm_multid_splitk(int do_verification,
for(std::size_t j = 0; j < kbatch_list.size(); j++)
{
auto kbatch_curr = kbatch_list[j];
// std::cout << ">>> kbatch: " << kbatch_curr << std::endl;
gptr->SetKBatchSize(argument_ptr.get(), kbatch_curr);
DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get()));
gemm_ptr->SetWorkSpacePointer(argument_ptr.get(),
gemm_desc_workspace.GetDeviceBuffer());
// std::cout << "WorkspacePointer set!" << std::endl;
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
for(std::size_t i = 0; i < gemm_descs.size(); i++)
c_device_buf[i]->SetZero();
// invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 1});
// std::cout << ">>>>>GPU Run end!" << std::endl;
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
if(do_verification)
{
@@ -313,13 +310,12 @@ bool profile_ggemm_multid_splitk(int do_verification,
std::cout << ">>>>>CPU verification end!" << std::endl;
}
if(time_kernel)
{
std::cout << ">>>>>GPU time profiling start!" << std::endl;
float avg_time = invoker_ptr->Run(
// argument_ptr.get(), StreamConfig{nullptr, time_kernel, 1, 5, 30});
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 1, 0, 1});
argument_ptr.get(),
StreamConfig{nullptr, time_kernel, 0, warmup_iter, kernel_iter});
std::size_t flop = 0, num_btype = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
......
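The timing call above now forwards the two new knobs instead of hard-coded values. In this codebase the positional `StreamConfig` fields used here appear to be stream, time_kernel, log level, warm-up (cold) iterations and timed iterations; the wrapper below is an illustrative sketch built on that assumption and is not part of the commit (the header path is likewise assumed).

```cpp
// Illustrative helper, not part of this diff. Assumes CK's StreamConfig field
// order: {stream, time_kernel, log_level, cold_niters, nrepeat}.
#include "ck/stream_config.hpp" // assumed location of StreamConfig

template <typename Invoker, typename Argument>
float time_grouped_gemm(Invoker& invoker, Argument* arg, int warmup_iter, int kernel_iter)
{
    // Average kernel time (ms) over `kernel_iter` timed runs, after
    // `warmup_iter` untimed warm-up runs, on the default (null) stream.
    return invoker.Run(arg,
                       StreamConfig{nullptr, /*time_kernel=*/true,
                                    /*log_level=*/0, warmup_iter, kernel_iter});
}
```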
@@ -70,6 +70,8 @@ int profile_grouped_gemm_multiple_d_splitk(int argc, char* argv[])
<< "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)\n"
<< "arg15: kbatch value (default 4)\n"
<< "arg16: warm-up iterations (default 1)\n"
<< "arg17: kernel repeat iterations (default 10)\n"
<< std::endl;
exit(1);
@@ -86,10 +88,12 @@ int profile_grouped_gemm_multiple_d_splitk(int argc, char* argv[])
const auto Ns = argToIntArray(argv[9]);
const auto Ks = argToIntArray(argv[10]);
const auto StrideAs = argToIntArray(argv[11]);
const auto StrideBs = argToIntArray(argv[12]);
const auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1;
const auto StrideAs = argToIntArray(argv[11]);
const auto StrideBs = argToIntArray(argv[12]);
const auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1;
const int warmup_iter = argc == 16 ? std::stoi(argv[15]) : 1;
const int kernel_iter = argc == 17 ? std::stoi(argv[16]) : 10;
#ifdef CK_ENABLE_FP16
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
@@ -110,7 +114,9 @@ int profile_grouped_gemm_multiple_d_splitk(int argc, char* argv[])
StrideAs,
StrideBs,
StrideCs,
kbatch);
kbatch,
warmup_iter,
kernel_iter);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
@@ -131,7 +137,9 @@ int profile_grouped_gemm_multiple_d_splitk(int argc, char* argv[])
StrideAs,
StrideBs,
StrideCs,
kbatch);
kbatch,
warmup_iter,
kernel_iter);
}
else
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <memory>
#include <vector>
@@ -150,18 +150,16 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
if(b2c_tile_map.IsFirstKSplitBlock())
{
// Wait until all other blocks for this [M,N] tile store their results.
work_scheduler.WaitForNeighbours(k_batch, output_tile_idx, output_tile_idx_offset);
index_t neighbour_count = work_scheduler.WaitForNeighbours(
k_batch, b2c_tile_map.GetTileKIdx(), output_tile_idx, output_tile_idx_offset);
// Accumulate partial results. We can have different # of workgroups to reduce, thus we
// read actual flag value.
const uint32_t flag_v = __builtin_amdgcn_readfirstlane(
work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
for(uint32_t i = 1; i < flag_v; ++i)
for(index_t i = 1; i <= neighbour_count; ++i)
{
partial_result += p_workspace[(get_block_1d_id()) * MPerBlock * NPerBlock +
i * MPerBlock * NPerBlock + get_thread_local_1d_id()];
}
// Signal waiting blocks that they can start using their workspace.
work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
@@ -284,11 +282,10 @@ struct GroupedGemmStridedTileLoopReduce
DeviceMem gemm_workspace, gemm_flags;
// const index_t tiles_per_block = (tile_count + grid_size - 1) / grid_size;
const index_t tiles_per_block = (tile_count + grid_size - 1) / grid_size;
// This is the number of MN-output tiles which we cover with workgroups.
// We launch k_batch / tiles_per_block workgroups for each output tile.
// const index_t flag_count = (grid_size * tiles_per_block + k_batch - 1) / k_batch;
const index_t flag_count = tile_count / k_batch;
const index_t flag_count = (grid_size * tiles_per_block + k_batch - 1) / k_batch;
gemm_workspace.Realloc(grid_size * MPerBlock * NPerBlock * sizeof(float));
gemm_flags.Realloc(flag_count * sizeof(uint32_t));
......