Commit 1b462ab5 authored by Adam Osewski

Clean up debug code and reuse new neighbour count func.

parent e954c206
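The core of this change is that work_scheduler.WaitForNeighbours() now returns the neighbour count directly, which replaces the old manual flag read (GetFlagValue plus __builtin_amdgcn_readfirstlane), and the temporary printf/debug scaffolding is removed. Below is a minimal, host-side sketch of the resulting split-K handshake; WorkSchedulerSketch, TileMapSketch, GridwiseGemmSketch and finish_output_tile are illustrative stand-ins with trivial bodies, not the actual CK classes, and only the control flow mirrors the diff.

#include <cstdint>

using index_t = int32_t;

// Hypothetical stand-ins for the CK work scheduler / tile map interfaces; only the
// calls touched by this commit are modelled, with trivial bodies so the sketch compiles.
struct WorkSchedulerSketch
{
    void FlagFinished(index_t, index_t, index_t) {}
    // Now returns how many workgroups produced partials for this [M,N] tile.
    index_t WaitForNeighbours(index_t, index_t, index_t, index_t) { return 2; }
    void Reset(index_t, index_t, index_t) {}
    void WaitForReduction(index_t, index_t, index_t) {}
    bool HasTile() const { return true; }
};

struct TileMapSketch
{
    bool IsFirstKSplitBlock() const { return true; }
    index_t GetTileKIdx() const { return 0; }
};

struct GridwiseGemmSketch
{
    void AccumulatePartials(void*, index_t) {}
};

// Control flow after a workgroup has finished its K-slices for one output tile.
inline void finish_output_tile(WorkSchedulerSketch& ws,
                               TileMapSketch& map,
                               GridwiseGemmSketch& gemm,
                               void* p_workspace,
                               index_t k_batch,
                               index_t tile_idx,
                               index_t tile_off)
{
    ws.FlagFinished(k_batch, tile_idx, tile_off);

    if(map.IsFirstKSplitBlock())
    {
        // The returned neighbour count replaces the old manual GetFlagValue() read.
        const index_t n = ws.WaitForNeighbours(k_batch, map.GetTileKIdx(), tile_idx, tile_off);
        if(n > 1)
            gemm.AccumulatePartials(p_workspace, n); // reduce partial tiles from the GMEM workspace
        ws.Reset(k_batch, tile_idx, tile_off);       // let neighbours reuse their workspace
    }
    else if(ws.HasTile())
    {
        ws.WaitForReduction(k_batch, tile_idx, tile_off);
    }
}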
@@ -63,20 +63,19 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_grouped_gemm_xdl_splitk_v2(
-            const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
+        kernel_grouped_gemm_xdl_splitk_v2(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
             void* const __restrict__ p_workspace,
             const index_t tile_count,
             const index_t k_batch,
-            [[maybe_unused]] const AElementwiseOperation a_element_op,
-            [[maybe_unused]] const BElementwiseOperation b_element_op,
-            [[maybe_unused]] const CDEElementwiseOperation cde_element_op)
+            const AElementwiseOperation a_element_op,
+            const BElementwiseOperation b_element_op,
+            const CDEElementwiseOperation cde_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
-    [[maybe_unused]] __shared__ uint8_t p_shared[shared_size];
+    __shared__ uint8_t p_shared[shared_size];
 
     const auto gemm_desc_ptr =
         reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
@@ -105,12 +104,6 @@ __global__ void
     index_t gemm_tile_id_end = grid_size_grp;
 
     auto gridwise_gemm = GridwiseGemm();
 
-    [[maybe_unused]] auto is_thread_local_1d_id_idx = [](auto... Ids) -> bool
-    {
-        const auto tid = get_thread_local_1d_id();
-        return ((tid == Ids) || ... );
-    };
-
     do
     {
         // Find corresponding GEMM group for our tile
@@ -129,12 +122,12 @@ __global__ void
             gemm_tile_id_end = offset + grid_size_grp;
         }
 
-        [[maybe_unused]] const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
-        [[maybe_unused]] const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
-        [[maybe_unused]] const auto K = gemm_desc_ptr[group_id].K;
-        [[maybe_unused]] const auto StrideA = gemm_desc_ptr[group_id].StrideA;
-        [[maybe_unused]] const auto StrideB = gemm_desc_ptr[group_id].StrideB;
+        const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
+        const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
+        const auto K = gemm_desc_ptr[group_id].K;
+        const auto StrideA = gemm_desc_ptr[group_id].StrideA;
+        const auto StrideB = gemm_desc_ptr[group_id].StrideB;
 
         auto& results_buffer = gridwise_gemm.GetCThreadBuffer();
 
         b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);
@@ -143,32 +136,21 @@ __global__ void
         // Iterate over K dimension for this [M,N] tile
         // still in the same GEMM && the same [M,N] tile
         // TODO: change desc so that few K-tiles will be done in single GEMM.
-        // {
-        //     if (is_thread_local_1d_id_idx(0))
-        //     {
-        //         printf("bid: %d, group: %d, accumulate tile id (M,N,K): [%d, %d, %d] \n",
-        //                static_cast<index_t>(blockIdx.x),
-        //                group_id,
-        //                b2c_tile_map.GetTileMIdx(),
-        //                b2c_tile_map.GetTileNIdx(),
-        //                b2c_tile_map.GetTileKIdx());
-        //     }
-        // }
         do
         {
             // just accumulate results in registers!
-            // gridwise_gemm.template RunGEMM<HasMainKBlockLoop>(p_a_grid,
-            //                                                   p_b_grid,
-            //                                                   static_cast<void*>(p_shared),
-            //                                                   a_element_op,
-            //                                                   b_element_op,
-            //                                                   M,
-            //                                                   N,
-            //                                                   K,
-            //                                                   StrideA,
-            //                                                   StrideB,
-            //                                                   k_batch,
-            //                                                   b2c_tile_map);
+            gridwise_gemm.template RunGEMM<HasMainKBlockLoop>(p_a_grid,
+                                                              p_b_grid,
+                                                              static_cast<void*>(p_shared),
+                                                              a_element_op,
+                                                              b_element_op,
+                                                              M,
+                                                              N,
+                                                              K,
+                                                              StrideA,
+                                                              StrideB,
+                                                              k_batch,
+                                                              b2c_tile_map);
         } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
@@ -184,122 +166,47 @@ __global__ void
         work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);
 
-        // {
-        //     // const uint32_t flag_v2 = __builtin_amdgcn_readfirstlane(
-        //     //     work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
-        if (is_thread_local_1d_id_idx(0))
-        {
-            printf("bid: %d, group: %d, FlagFInished \n",
-                   static_cast<index_t>(blockIdx.x),
-                   group_id);
-            // printf("bid: %d, group: %d, FlagFInished flag_v[%u]: %u\n",
-            //        static_cast<index_t>(blockIdx.x),
-            //        group_id)
-            //        work_scheduler.GetWorkgroupFlagIdx(k_batch, output_tile_idx, output_tile_idx_offset),
-            //        flag_v2);
-        }
-        // }
-
         // The workgroup which processed first K tile accumulates results and stores to GMEM
         if(b2c_tile_map.IsFirstKSplitBlock())
         {
-            if (is_thread_local_1d_id_idx(0))
-            {
-                printf("bid: %d, group: %d, Will wait for neighbours... \n",
-                       static_cast<index_t>(blockIdx.x),
-                       group_id);
-            }
             // Wait untill all other blocks for this [M,N] tile store their results.
-            work_scheduler.WaitForNeighbours(k_batch, output_tile_idx, output_tile_idx_offset);
-
-            // Accumulate partial results. We can have different # of workgroups to reduce, thus we
-            // read actual flag value.
-            [[maybe_unused]] const uint32_t flag_v = __builtin_amdgcn_readfirstlane(
-                work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
-
-            // {
-            //     if (is_thread_local_1d_id_idx(0))
-            //     {
-            //         printf("bid: %d, group: %d, WaitForNeighbours flag_v[%u]: %u\n",
-            //                static_cast<index_t>(blockIdx.x),
-            //                group_id,
-            //                work_scheduler.GetWorkgroupFlagIdx(k_batch, output_tile_idx, output_tile_idx_offset),
-            //                static_cast<index_t>(blockIdx.x));
-            //         // flag_v);
-            //     }
-            // }
-
-            // using CThreadBuffer = remove_cvref_t<decltype(results_buffer)>;
-            // constexpr index_t n_v = CThreadBuffer::num_of_v_.value;
-            // constexpr index_t s_per_v = CThreadBuffer::s_per_v.value;
-            // static_for<0, n_v, 1>{}([&](auto v) {
-            //     static_for<0, s_per_v, 1>{}([&](auto s) {
-            //         // printf("bid: %d; tid: %d; [Partial results] c_thread_buff[%d, %d]:
-            //         // %f\n",
-            //         //        static_cast<index_t>(blockIdx.x),
-            //         //        static_cast<index_t>(threadIdx.x),
-            //         //        v.value,
-            //         //        s.value,
-            //         //        static_cast<float>(results_buffer[v * Number<s_per_v>{} + s])
-            //         //        );
-            //         results_buffer(v * Number<s_per_v>{} + s) = threadIdx.x * v + s;
-            //     });
-            // });
+            index_t neighbour_count = work_scheduler.WaitForNeighbours(
+                k_batch, b2c_tile_map.GetTileKIdx(), output_tile_idx, output_tile_idx_offset);
 
             // Accumulate only when there is at least two workgroups processing splitk data-tiles
             // across same MN-output tile.
-            // if(flag_v > 1)
-            //     gridwise_gemm.AccumulatePartials(p_workspace, flag_v);
-
-            if (is_thread_local_1d_id_idx(0))
-            {
-                printf("bid: %d, group: %d, Reset flag \n",
-                       static_cast<index_t>(blockIdx.x),
-                       group_id);
-            }
+            if(neighbour_count > 1)
+                gridwise_gemm.AccumulatePartials(p_workspace, neighbour_count);
 
             // Signal waiting blocks that they can start use their workspace.
             work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
 
-            // const auto p_e_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_e_grid);
-            // const auto stride_e = gemm_desc_ptr[group_id].StrideE;
-            // const auto stride_ds = gemm_desc_ptr[group_id].StrideDs;
-
-            // constexpr auto NumDTensor = DsDataType::Size();
-            // using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
-            // DsGridPointer p_ds_grid;
-
-            // static_for<0, NumDTensor, 1>{}([&](auto i) {
-            //     using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
-            //     p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
-            // });
-
-            // gridwise_gemm.template RunWrite(p_ds_grid,
-            //                                 p_e_grid,
-            //                                 static_cast<void*>(p_shared),
-            //                                 M,
-            //                                 N,
-            //                                 stride_ds,
-            //                                 stride_e,
-            //                                 cde_element_op,
-            //                                 b2c_tile_map);
+            const auto p_e_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_e_grid);
+            const auto stride_e = gemm_desc_ptr[group_id].StrideE;
+            const auto stride_ds = gemm_desc_ptr[group_id].StrideDs;
 
+            constexpr auto NumDTensor = DsDataType::Size();
+            using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
+            DsGridPointer p_ds_grid;
 
+            static_for<0, NumDTensor, 1>{}([&](auto i) {
+                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
+                p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
+            });
 
+            gridwise_gemm.template RunWrite(p_ds_grid,
+                                            p_e_grid,
+                                            static_cast<void*>(p_shared),
+                                            M,
+                                            N,
+                                            stride_ds,
+                                            stride_e,
+                                            cde_element_op,
+                                            b2c_tile_map);
         }
         else if(work_scheduler.HasTile())
         {
-            {
-                // const uint32_t flag_v2 = __builtin_amdgcn_readfirstlane(
-                const uint32_t flag_v2 = work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset);
-                if (is_thread_local_1d_id_idx(0))
-                {
-                    printf("bid: %d, group: %d, Waiting for Reduction flag_v[%u]: %u\n",
-                           static_cast<index_t>(blockIdx.x),
-                           group_id,
-                           work_scheduler.GetWorkgroupFlagIdx(k_batch, output_tile_idx, output_tile_idx_offset),
-                           // static_cast<index_t>(blockIdx.x));
-                           flag_v2);
-                }
-            }
             work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
         }
     } while(work_scheduler.HasTile());
@@ -839,8 +746,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
             void* p_flags = reinterpret_cast<char*>(dev_gemm_workspace) +
                             Block2ETileMapKSplit::GetAccWorkspaceSize(
                                 sizeof(typename GridwiseGemm::AccType), grid_size);
-            // std::size_t flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
-            std::size_t flag_count = arg.tile_count_ / arg.K_BATCH;
+            std::size_t flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
 
             if(stream_config.log_level_ > 0)
             {
@@ -1077,13 +983,11 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
         int grid_size = std::min(arg.tile_count_, occ_grid_size);
         int tiles_per_block = (arg.tile_count_ + grid_size - 1) / grid_size;
 
-        if(arg.tile_count_ > occ_grid_size &&
-           grid_size * tiles_per_block > arg.tile_count_)
+        if(arg.tile_count_ > occ_grid_size && grid_size * tiles_per_block > arg.tile_count_)
         {
             grid_size = (arg.tile_count_ + tiles_per_block - 1) / tiles_per_block;
         }
 
-        // int flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
-        int flag_count = arg.tile_count_ / arg.K_BATCH;
+        int flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
 
         // This would be the maximum needed workspace size. Since actual grid size, which determines
         // the amount of workspace bytes needed, may be less due to the number of available CUs in
...
@@ -106,13 +106,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     using GridwiseGemmPipe = remove_cvref_t<
         decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
 
-    template <index_t... Ids>
-    __device__ static bool is_thread_local_1d_id_idx()
-    {
-        const auto tid = get_thread_local_1d_id();
-        return ((tid == Ids) || ...);
-    }
-
     public:
     using AccType = AccDataType;
@@ -913,32 +906,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                        Sequence<6>{},
                        Sequence<7>{}));
 
-        // if (is_thread_local_1d_id_idx<0>())
-        // {
-        //     // printf("bid: %d; tid: %d; [Store Partials] c_block_desc:[%d, %d, %d, %d, %d, %d, %d, %d]\n",
-        //     //        static_cast<index_t>(blockIdx.x),
-        //     //        static_cast<index_t>(threadIdx.x),
-        //     //        M0.value,
-        //     //        N0.value,
-        //     //        M1.value,
-        //     //        N1.value,
-        //     //        M2.value,
-        //     //        M3.value,
-        //     //        M4.value,
-        //     //        N2.value);
-        //     printf("bid: %d; tid: %d; [Store Partials] wrkspace_desc:[%d, %d, %d, %d, %d, %d, %d, %d]\n",
-        //            static_cast<index_t>(blockIdx.x),
-        //            static_cast<index_t>(threadIdx.x),
-        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0),
-        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1),
-        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2).value,
-        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3).value,
-        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4).value,
-        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5).value,
-        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6).value,
-        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7).value);
-        // }
-
         auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
         auto w_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_workspace_grid, workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
@@ -996,33 +963,11 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                        n_thread_data_on_block_idx[I2]),
             ck::tensor_operation::element_wise::PassThrough{}};
 
-        // if (is_thread_local_1d_id_idx<0, 64, 223>())
-        // {
-        //     printf("[StorePartials] bid: %d, tid: %d: dst origin idx[%d, %d, %d, %d, %d, %d, %d, %d]\n",
-        //            static_cast<index_t>(blockIdx.x),
-        //            static_cast<index_t>(threadIdx.x),
-        //            (static_cast<index_t>(blockIdx.x)) * MXdlPerWave,
-        //            n_thread_data_on_block_idx[I0],
-        //            m_thread_data_on_block_idx[I1],
-        //            n_thread_data_on_block_idx[I1],
-        //            m_thread_data_on_block_idx[I2],
-        //            m_thread_data_on_block_idx[I3],
-        //            m_thread_data_on_block_idx[I4],
-        //            n_thread_data_on_block_idx[I2]);
-        // }
-
         c_thread_copy_vgpr_to_gmem.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                        make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                        c_thread_buf,
                                        workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                        w_grid_buf);
-
-        if (is_thread_local_1d_id_idx<0>())
-        {
-            printf("[StorePartials] done. bid: %d, tid: %d\n",
-                   static_cast<index_t>(blockIdx.x),
-                   static_cast<index_t>(threadIdx.x));
-        }
     }
 
     __device__ void AccumulatePartials(void* __restrict__ p_workspace, uint32_t reduce_count)
@@ -1158,7 +1103,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         // We do not need to read this workgroup partial results since they're
         // already in c_thread_buff
-        for(uint32_t i_t = 1; i_t < reduce_count; ++i_t)
+        for(uint32_t i_t = 1; i_t <= reduce_count; ++i_t)
         {
             acc_buf.Clear();
             acc_load.Run(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -16,96 +16,96 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
 
-// void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
-//                                                   Row,
-//                                                   Empty_Tuple,
-//                                                   Row,
-//                                                   F16,
-//                                                   F16,
-//                                                   Empty_Tuple,
-//                                                   F16,
-//                                                   PassThrough,
-//                                                   PassThrough,
-//                                                   PassThrough>>>& instances);
+void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+                                                  Row,
+                                                  Empty_Tuple,
+                                                  Row,
+                                                  F16,
+                                                  F16,
+                                                  Empty_Tuple,
+                                                  F16,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances);
 
-// void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
-//                                                   Col,
-//                                                   Empty_Tuple,
-//                                                   Row,
-//                                                   F16,
-//                                                   F16,
-//                                                   Empty_Tuple,
-//                                                   F16,
-//                                                   PassThrough,
-//                                                   PassThrough,
-//                                                   PassThrough>>>& instances);
+void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+                                                  Col,
+                                                  Empty_Tuple,
+                                                  Row,
+                                                  F16,
+                                                  F16,
+                                                  Empty_Tuple,
+                                                  F16,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances);
 
-// void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
-//                                                   Row,
-//                                                   Empty_Tuple,
-//                                                   Row,
-//                                                   F16,
-//                                                   F16,
-//                                                   Empty_Tuple,
-//                                                   F16,
-//                                                   PassThrough,
-//                                                   PassThrough,
-//                                                   PassThrough>>>& instances);
+void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
+                                                  Row,
+                                                  Empty_Tuple,
+                                                  Row,
+                                                  F16,
+                                                  F16,
+                                                  Empty_Tuple,
+                                                  F16,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances);
 
-// void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
-//                                                   Col,
-//                                                   Empty_Tuple,
-//                                                   Row,
-//                                                   F16,
-//                                                   F16,
-//                                                   Empty_Tuple,
-//                                                   F16,
-//                                                   PassThrough,
-//                                                   PassThrough,
-//                                                   PassThrough>>>& instances);
+void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
+                                                  Col,
+                                                  Empty_Tuple,
+                                                  Row,
+                                                  F16,
+                                                  F16,
+                                                  Empty_Tuple,
+                                                  F16,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances);
 
-// void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
-//                                                   Col,
-//                                                   Empty_Tuple,
-//                                                   Row,
-//                                                   F16,
-//                                                   F16,
-//                                                   Empty_Tuple,
-//                                                   F16,
-//                                                   PassThrough,
-//                                                   PassThrough,
-//                                                   PassThrough>>>& instances);
+void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+                                                  Col,
+                                                  Empty_Tuple,
+                                                  Row,
+                                                  F16,
+                                                  F16,
+                                                  Empty_Tuple,
+                                                  F16,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances);
 
-// void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
-//                                                   Row,
-//                                                   Empty_Tuple,
-//                                                   Row,
-//                                                   F16,
-//                                                   F16,
-//                                                   Empty_Tuple,
-//                                                   F16,
-//                                                   PassThrough,
-//                                                   PassThrough,
-//                                                   PassThrough>>>& instances);
+void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+                                                  Row,
+                                                  Empty_Tuple,
+                                                  Row,
+                                                  F16,
+                                                  F16,
+                                                  Empty_Tuple,
+                                                  F16,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances);
 
-// void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
-//                                                   Col,
-//                                                   Empty_Tuple,
-//                                                   Row,
-//                                                   F16,
-//                                                   F16,
-//                                                   Empty_Tuple,
-//                                                   F16,
-//                                                   PassThrough,
-//                                                   PassThrough,
-//                                                   PassThrough>>>& instances);
+void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+                                                  Col,
+                                                  Empty_Tuple,
+                                                  Row,
+                                                  F16,
+                                                  F16,
+                                                  Empty_Tuple,
+                                                  F16,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances);
 
 void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
@@ -120,31 +120,31 @@ void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances
                                                   PassThrough,
                                                   PassThrough>>>& instances);
 
-// void add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
-//                                                   Row,
-//                                                   Empty_Tuple,
-//                                                   Row,
-//                                                   F16,
-//                                                   F8,
-//                                                   Empty_Tuple,
-//                                                   F16,
-//                                                   PassThrough,
-//                                                   PassThrough,
-//                                                   PassThrough>>>& instances);
+void add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+                                                  Row,
+                                                  Empty_Tuple,
+                                                  Row,
+                                                  F16,
+                                                  F8,
+                                                  Empty_Tuple,
+                                                  F16,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances);
 
-// void add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
-//                                                   Row,
-//                                                   Empty_Tuple,
-//                                                   Row,
-//                                                   F8,
-//                                                   F16,
-//                                                   Empty_Tuple,
-//                                                   F16,
-//                                                   PassThrough,
-//                                                   PassThrough,
-//                                                   PassThrough>>>& instances);
+void add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+                                                  Row,
+                                                  Empty_Tuple,
+                                                  Row,
+                                                  F8,
+                                                  F16,
+                                                  Empty_Tuple,
+                                                  F16,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances);
 
 template <typename ALayout,
           typename BLayout,
@@ -186,48 +186,48 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
         if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                      is_same_v<ELayout, Row>)
         {
-            // add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
-            // add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
+            add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
+            add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
             add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
                 op_ptrs);
         }
         else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                           is_same_v<ELayout, Row>)
         {
-            // add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
-            // add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
-            // add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
-            //     op_ptrs);
-        }
-        // else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
-        //                   is_same_v<ELayout, Row>)
-        // {
-        //     add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
-        // }
-        // else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
-        //                   is_same_v<ELayout, Row>)
-        // {
-        //     add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
-        // }
+            add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
+            add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
+            add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
+                op_ptrs);
+        }
+        else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
+                          is_same_v<ELayout, Row>)
+        {
+            add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
+        }
+        else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
+                          is_same_v<ELayout, Row>)
+        {
+            add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
+        }
     }
-    // else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
-    //                   is_same_v<EDataType, half_t>)
-    // {
-    //     if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
-    //                  is_same_v<ELayout, Row>)
-    //     {
-    //         add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(op_ptrs);
-    //     }
-    // }
-    // else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
-    //                   is_same_v<EDataType, half_t>)
-    // {
-    //     if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
-    //                  is_same_v<ELayout, Row>)
-    //     {
-    //         add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
-    //     }
-    // }
+    else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
+                      is_same_v<EDataType, half_t>)
+    {
+        if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
+                     is_same_v<ELayout, Row>)
+        {
+            add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(op_ptrs);
+        }
+    }
+    else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
+                      is_same_v<EDataType, half_t>)
+    {
+        if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
+                     is_same_v<ELayout, Row>)
+        {
+            add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
+        }
+    }
 
     return op_ptrs;
 }
 };
...
@@ -17,18 +17,18 @@ namespace device {
 namespace instance {
 
 // MultiD version
-// void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
-//                                                   Col,
-//                                                   Empty_Tuple,
-//                                                   Row,
-//                                                   F16,
-//                                                   F16,
-//                                                   Empty_Tuple,
-//                                                   F16,
-//                                                   PassThrough,
-//                                                   PassThrough,
-//                                                   PassThrough>>>& instances);
+void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+                                                  Col,
+                                                  Empty_Tuple,
+                                                  Row,
+                                                  F16,
+                                                  F16,
+                                                  Empty_Tuple,
+                                                  F16,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances);
 
 void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances(
     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
@@ -93,8 +93,8 @@ struct DeviceOperationInstanceFactory<
         else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                           is_same_v<ELayout, Row>)
         {
-            // add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
-            //     op_ptrs);
+            add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
+                op_ptrs);
         }
     }
     return op_ptrs;
...
 add_instance_library(device_grouped_gemm_instance
-    # device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
-    # device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
-    # device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
-    # device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
-    # device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
-    # device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
+    device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
+    device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
+    device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
+    device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
+    device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
+    device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
     device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
-    # device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
-    # device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
-    # device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
+    device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
+    device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
+    device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
 )

 add_instance_library(device_grouped_gemm_multiple_d_instance
-    # device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
+    device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
     device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
 )
@@ -39,7 +39,9 @@ bool profile_ggemm_multid_splitk(int do_verification,
                                  const std::vector<int>& StrideAs,
                                  const std::vector<int>& StrideBs,
                                  const std::vector<int>& StrideCs,
-                                 int kbatch = 1)
+                                 int kbatch      = 1,
+                                 int warmup_iter = 1,
+                                 int kernel_iter = 10)
 {
     bool pass = true;
@@ -250,23 +252,18 @@ bool profile_ggemm_multid_splitk(int do_verification,
     for(std::size_t j = 0; j < kbatch_list.size(); j++)
     {
         auto kbatch_curr = kbatch_list[j];
-        // std::cout << ">>> kbatch: " << kbatch_curr << std::endl;
         gptr->SetKBatchSize(argument_ptr.get(), kbatch_curr);
 
         DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get()));
         gemm_ptr->SetWorkSpacePointer(argument_ptr.get(),
                                       gemm_desc_workspace.GetDeviceBuffer());
-        // std::cout << "WorkspacePointer set!" << std::endl;
 
         if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
         {
             for(std::size_t i = 0; i < gemm_descs.size(); i++)
                 c_device_buf[i]->SetZero();
 
-            // invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 1});
-            // std::cout << ">>>>>GPU Run end!" << std::endl;
+            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
 
             if(do_verification)
             {
@@ -313,13 +310,12 @@ bool profile_ggemm_multid_splitk(int do_verification,
                 std::cout << ">>>>>CPU verification end!" << std::endl;
             }
 
             if(time_kernel)
             {
                 std::cout << ">>>>>GPU time profiling start!" << std::endl;
 
                 float avg_time = invoker_ptr->Run(
-                    // argument_ptr.get(), StreamConfig{nullptr, time_kernel, 1, 5, 30});
-                    argument_ptr.get(), StreamConfig{nullptr, time_kernel, 1, 0, 1});
+                    argument_ptr.get(),
+                    StreamConfig{nullptr, time_kernel, 0, warmup_iter, kernel_iter});
 
                 std::size_t flop = 0, num_btype = 0;
                 for(std::size_t i = 0; i < gemm_descs.size(); i++)
                 {
...
@@ -70,6 +70,8 @@ int profile_grouped_gemm_multiple_d_splitk(int argc, char* argv[])
             << "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
                "64,64 64,64 128,128)\n"
             << "arg15: kbatch value (default 4)\n"
+            << "arg16: warm-up iterations (default 1)\n"
+            << "arg17: kernel repeat iterations (default 10)\n"
             << std::endl;

         exit(1);
@@ -90,6 +92,8 @@ int profile_grouped_gemm_multiple_d_splitk(int argc, char* argv[])
     const auto StrideBs = argToIntArray(argv[12]);
     const auto StrideCs = argToIntArray(argv[13]);
     const int kbatch    = argc == 15 ? std::stoi(argv[14]) : 1;
+    const int warmup_iter = argc == 16 ? std::stoi(argv[15]) : 1;
+    const int kernel_iter = argc == 17 ? std::stoi(argv[16]) : 10;
 
 #ifdef CK_ENABLE_FP16
     if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
     {
@@ -110,7 +114,9 @@ int profile_grouped_gemm_multiple_d_splitk(int argc, char* argv[])
             StrideAs,
             StrideBs,
             StrideCs,
-            kbatch);
+            kbatch,
+            warmup_iter,
+            kernel_iter);
     }
     else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
     {
@@ -131,7 +137,9 @@ int profile_grouped_gemm_multiple_d_splitk(int argc, char* argv[])
             StrideAs,
             StrideBs,
             StrideCs,
-            kbatch);
+            kbatch,
+            warmup_iter,
+            kernel_iter);
     }
     else
     {
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #include <memory>
 #include <vector>
@@ -150,18 +150,16 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
         if(b2c_tile_map.IsFirstKSplitBlock())
         {
             // Wait untill all other blocks for this [M,N] tile store their results.
-            work_scheduler.WaitForNeighbours(k_batch, output_tile_idx, output_tile_idx_offset);
+            index_t neighbour_count = work_scheduler.WaitForNeighbours(
+                k_batch, b2c_tile_map.GetTileKIdx(), output_tile_idx, output_tile_idx_offset);
 
             // Accumulate partial results. We can have different # of workgroups to reduce, thus we
             // read actual flag value.
-            const uint32_t flag_v = __builtin_amdgcn_readfirstlane(
-                work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
-            for(uint32_t i = 1; i < flag_v; ++i)
+            for(index_t i = 1; i <= neighbour_count; ++i)
             {
                 partial_result += p_workspace[(get_block_1d_id()) * MPerBlock * NPerBlock +
                                               i * MPerBlock * NPerBlock + get_thread_local_1d_id()];
             }
 
             // Signal waiting blocks that they can start use their workspace.
             work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
@@ -284,11 +282,10 @@ struct GroupedGemmStridedTileLoopReduce
         DeviceMem gemm_workspace, gemm_flags;
 
-        // const index_t tiles_per_block = (tile_count + grid_size - 1) / grid_size;
+        const index_t tiles_per_block = (tile_count + grid_size - 1) / grid_size;
         // This is the number of MN-output tiles which we cover with workgroups.
         // We launch k_batch / tiles_per_block workgroups for each output tile.
-        // const index_t flag_count = (grid_size * tiles_per_block + k_batch - 1) / k_batch;
-        const index_t flag_count = tile_count / k_batch;
+        const index_t flag_count = (grid_size * tiles_per_block + k_batch - 1) / k_batch;
 
         gemm_workspace.Realloc(grid_size * MPerBlock * NPerBlock * sizeof(float));
         gemm_flags.Realloc(flag_count * sizeof(uint32_t));
...
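For reference, the restored flag_count expression is a ceiling division: the launched grid covers grid_size * tiles_per_block data tiles, each group of k_batch K-slices maps to one [M,N] output tile, and one synchronization flag is allocated per output tile. A small standalone example, using illustrative numbers rather than values taken from any particular run:

#include <cassert>

// Worked example of the flag-count formula from this commit.
int main()
{
    const int grid_size       = 304; // workgroups launched (illustrative)
    const int tiles_per_block = 2;   // tiles each workgroup loops over (illustrative)
    const int k_batch         = 4;   // split-K factor (illustrative)

    // ceil((grid_size * tiles_per_block) / k_batch)
    const int flag_count = (grid_size * tiles_per_block + k_batch - 1) / k_batch;
    assert(flag_count == 152); // ceil(608 / 4)
    return 0;
}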