Commit 7e71ea99 authored by Adam Osewski's avatar Adam Osewski
Browse files

Commit debug WIP for sharing.

parent 734df790
......@@ -213,7 +213,7 @@
#define CK_WORKAROUND_SWDEV_388832 1
// flag to enable (1) or disable (0) the debugging output in some kernels
#define DEBUG_LOG 0
#define DEBUG_LOG 1
// denorm test fix, required to work around dissue
#ifndef CK_WORKAROUND_DENORM_FIX
......
......@@ -103,14 +103,17 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
block_dim.y,
block_dim.z);
printf("Warm up 1 time\n");
printf("Warm up %d times\n", stream_config.cold_niters_);
#endif
// warm up
preprocess();
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
for(int i = 0; i < stream_config.cold_niters_; ++i)
{
preprocess();
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
}
const int nrepeat = 10;
const int nrepeat = stream_config.nrepeat_;
#if DEBUG_LOG
printf("Start running %d times...\n", nrepeat);
#endif
......
......@@ -68,15 +68,15 @@ __global__ void
void* const __restrict__ p_workspace,
const index_t tile_count,
const index_t k_batch,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
[[maybe_unused]] const AElementwiseOperation a_element_op,
[[maybe_unused]] const BElementwiseOperation b_element_op,
[[maybe_unused]] const CDEElementwiseOperation cde_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
[[maybe_unused]] __shared__ uint8_t p_shared[shared_size];
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
......@@ -105,6 +105,12 @@ __global__ void
index_t gemm_tile_id_end = grid_size_grp;
auto gridwise_gemm = GridwiseGemm();
[[maybe_unused]] auto is_thread_local_1d_id_idx = [](auto... Ids) -> bool
{
const auto tid = get_thread_local_1d_id();
return ((tid == Ids) || ... );
};
do
{
// Find corresponding GEMM group for our tile
......@@ -123,12 +129,12 @@ __global__ void
gemm_tile_id_end = offset + grid_size_grp;
}
const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
[[maybe_unused]] const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
[[maybe_unused]] const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
const auto K = gemm_desc_ptr[group_id].K;
const auto StrideA = gemm_desc_ptr[group_id].StrideA;
const auto StrideB = gemm_desc_ptr[group_id].StrideB;
[[maybe_unused]] const auto K = gemm_desc_ptr[group_id].K;
[[maybe_unused]] const auto StrideA = gemm_desc_ptr[group_id].StrideA;
[[maybe_unused]] const auto StrideB = gemm_desc_ptr[group_id].StrideB;
auto& results_buffer = gridwise_gemm.GetCThreadBuffer();
b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);
......@@ -137,21 +143,32 @@ __global__ void
// Iterate over K dimension for this [M,N] tile
// still in the same GEMM && the same [M,N] tile
// TODO: change desc so that few K-tiles will be done in single GEMM.
// {
// if (is_thread_local_1d_id_idx(0))
// {
// printf("bid: %d, group: %d, accumulate tile id (M,N,K): [%d, %d, %d] \n",
// static_cast<index_t>(blockIdx.x),
// group_id,
// b2c_tile_map.GetTileMIdx(),
// b2c_tile_map.GetTileNIdx(),
// b2c_tile_map.GetTileKIdx());
// }
// }
do
{
// just accumulate results in registers!
gridwise_gemm.template RunGEMM<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
static_cast<void*>(p_shared),
a_element_op,
b_element_op,
M,
N,
K,
StrideA,
StrideB,
k_batch,
b2c_tile_map);
// gridwise_gemm.template RunGEMM<HasMainKBlockLoop>(p_a_grid,
// p_b_grid,
// static_cast<void*>(p_shared),
// a_element_op,
// b_element_op,
// M,
// N,
// K,
// StrideA,
// StrideB,
// k_batch,
// b2c_tile_map);
} while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
......@@ -167,51 +184,122 @@ __global__ void
work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);
// {
// // const uint32_t flag_v2 = __builtin_amdgcn_readfirstlane(
// // work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
if (is_thread_local_1d_id_idx(0))
{
printf("bid: %d, group: %d, FlagFInished \n",
static_cast<index_t>(blockIdx.x),
group_id);
// printf("bid: %d, group: %d, FlagFInished flag_v[%u]: %u\n",
// static_cast<index_t>(blockIdx.x),
// group_id)
// work_scheduler.GetWorkgroupFlagIdx(k_batch, output_tile_idx, output_tile_idx_offset),
// flag_v2);
}
// }
// The workgroup which processed first K tile accumulates results and stores to GMEM
if(b2c_tile_map.IsFirstKSplitBlock())
{
if (is_thread_local_1d_id_idx(0))
{
printf("bid: %d, group: %d, Will wait for neighbours... \n",
static_cast<index_t>(blockIdx.x),
group_id);
}
// Wait untill all other blocks for this [M,N] tile store their results.
work_scheduler.WaitForNeighbours(k_batch, output_tile_idx, output_tile_idx_offset);
// Accumulate partial results. We can have different # of workgroups to reduce, thus we
// read actual flag value.
const uint32_t flag_v = __builtin_amdgcn_readfirstlane(
[[maybe_unused]] const uint32_t flag_v = __builtin_amdgcn_readfirstlane(
work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
// {
// if (is_thread_local_1d_id_idx(0))
// {
// printf("bid: %d, group: %d, WaitForNeighbours flag_v[%u]: %u\n",
// static_cast<index_t>(blockIdx.x),
// group_id,
// work_scheduler.GetWorkgroupFlagIdx(k_batch, output_tile_idx, output_tile_idx_offset),
// static_cast<index_t>(blockIdx.x));
// // flag_v);
// }
// }
// using CThreadBuffer = remove_cvref_t<decltype(results_buffer)>;
// constexpr index_t n_v = CThreadBuffer::num_of_v_.value;
// constexpr index_t s_per_v = CThreadBuffer::s_per_v.value;
// static_for<0, n_v, 1>{}([&](auto v) {
// static_for<0, s_per_v, 1>{}([&](auto s) {
// // printf("bid: %d; tid: %d; [Partial results] c_thread_buff[%d, %d]:
// // %f\n",
// // static_cast<index_t>(blockIdx.x),
// // static_cast<index_t>(threadIdx.x),
// // v.value,
// // s.value,
// // static_cast<float>(results_buffer[v * Number<s_per_v>{} + s])
// // );
// results_buffer(v * Number<s_per_v>{} + s) = threadIdx.x * v + s;
// });
// });
// Accumulate only when there is at least two workgroups processing splitk data-tiles
// across same MN-output tile.
if(flag_v > 1)
gridwise_gemm.AccumulatePartials(p_workspace, flag_v);
// if(flag_v > 1)
// gridwise_gemm.AccumulatePartials(p_workspace, flag_v);
if (is_thread_local_1d_id_idx(0))
{
printf("bid: %d, group: %d, Reset flag \n",
static_cast<index_t>(blockIdx.x),
group_id);
}
// Signal waiting blocks that they can start use their workspace.
work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
const auto p_e_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_e_grid);
const auto stride_e = gemm_desc_ptr[group_id].StrideE;
const auto stride_ds = gemm_desc_ptr[group_id].StrideDs;
constexpr auto NumDTensor = DsDataType::Size();
using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
DsGridPointer p_ds_grid;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
});
gridwise_gemm.template RunWrite(p_ds_grid,
p_e_grid,
static_cast<void*>(p_shared),
M,
N,
stride_ds,
stride_e,
cde_element_op,
b2c_tile_map);
// const auto p_e_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_e_grid);
// const auto stride_e = gemm_desc_ptr[group_id].StrideE;
// const auto stride_ds = gemm_desc_ptr[group_id].StrideDs;
// constexpr auto NumDTensor = DsDataType::Size();
// using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
// DsGridPointer p_ds_grid;
// static_for<0, NumDTensor, 1>{}([&](auto i) {
// using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
// });
// gridwise_gemm.template RunWrite(p_ds_grid,
// p_e_grid,
// static_cast<void*>(p_shared),
// M,
// N,
// stride_ds,
// stride_e,
// cde_element_op,
// b2c_tile_map);
}
else if(work_scheduler.HasTile())
{
{
// const uint32_t flag_v2 = __builtin_amdgcn_readfirstlane(
const uint32_t flag_v2 = work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset);
if (is_thread_local_1d_id_idx(0))
{
printf("bid: %d, group: %d, Waiting for Reduction flag_v[%u]: %u\n",
static_cast<index_t>(blockIdx.x),
group_id,
work_scheduler.GetWorkgroupFlagIdx(k_batch, output_tile_idx, output_tile_idx_offset),
// static_cast<index_t>(blockIdx.x));
flag_v2);
}
}
work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
}
} while(work_scheduler.HasTile());
......@@ -751,7 +839,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
void* p_flags = reinterpret_cast<char*>(dev_gemm_workspace) +
Block2ETileMapKSplit::GetAccWorkspaceSize(
sizeof(typename GridwiseGemm::AccType), grid_size);
std::size_t flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
// std::size_t flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
std::size_t flag_count = arg.tile_count_ / arg.K_BATCH;
if(stream_config.log_level_ > 0)
{
......@@ -987,7 +1076,14 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
arg.gpu_cu_count_ * std::min(arg.occupancy_num_blocks_, KernelConfig::CU_BLOCKS);
int grid_size = std::min(arg.tile_count_, occ_grid_size);
int tiles_per_block = (arg.tile_count_ + grid_size - 1) / grid_size;
int flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
if(arg.tile_count_ > occ_grid_size &&
grid_size * tiles_per_block > arg.tile_count_)
{
grid_size = (arg.tile_count_ + tiles_per_block - 1) / tiles_per_block;
}
// int flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
int flag_count = arg.tile_count_ / arg.K_BATCH;
// This would be the maximum needed workspace size. Since actual grid size, which determines
// the amount of workspace bytes needed, may be less due to the number of available CUs in
......
......@@ -106,6 +106,13 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
template <index_t... Ids>
__device__ static bool is_thread_local_1d_id_idx()
{
const auto tid = get_thread_local_1d_id();
return ((tid == Ids) || ...);
}
public:
using AccType = AccDataType;
......@@ -906,6 +913,32 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
Sequence<6>{},
Sequence<7>{}));
// if (is_thread_local_1d_id_idx<0>())
// {
// // printf("bid: %d; tid: %d; [Store Partials] c_block_desc:[%d, %d, %d, %d, %d, %d, %d, %d]\n",
// // static_cast<index_t>(blockIdx.x),
// // static_cast<index_t>(threadIdx.x),
// // M0.value,
// // N0.value,
// // M1.value,
// // N1.value,
// // M2.value,
// // M3.value,
// // M4.value,
// // N2.value);
// printf("bid: %d; tid: %d; [Store Partials] wrkspace_desc:[%d, %d, %d, %d, %d, %d, %d, %d]\n",
// static_cast<index_t>(blockIdx.x),
// static_cast<index_t>(threadIdx.x),
// workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0),
// workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1),
// workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2).value,
// workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3).value,
// workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4).value,
// workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5).value,
// workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6).value,
// workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7).value);
// }
auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
auto w_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_workspace_grid, workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
......@@ -963,11 +996,33 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// if (is_thread_local_1d_id_idx<0, 64, 223>())
// {
// printf("[StorePartials] bid: %d, tid: %d: dst origin idx[%d, %d, %d, %d, %d, %d, %d, %d]\n",
// static_cast<index_t>(blockIdx.x),
// static_cast<index_t>(threadIdx.x),
// (static_cast<index_t>(blockIdx.x)) * MXdlPerWave,
// n_thread_data_on_block_idx[I0],
// m_thread_data_on_block_idx[I1],
// n_thread_data_on_block_idx[I1],
// m_thread_data_on_block_idx[I2],
// m_thread_data_on_block_idx[I3],
// m_thread_data_on_block_idx[I4],
// n_thread_data_on_block_idx[I2]);
// }
c_thread_copy_vgpr_to_gmem.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf,
workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
w_grid_buf);
if (is_thread_local_1d_id_idx<0>())
{
printf("[StorePartials] done. bid: %d, tid: %d\n",
static_cast<index_t>(blockIdx.x),
static_cast<index_t>(threadIdx.x));
}
}
__device__ void AccumulatePartials(void* __restrict__ p_workspace, uint32_t reduce_count)
......
......@@ -51,7 +51,7 @@ class StridedReductionTileLoop
{
tile_id_++;
block_tile_idx_++;
return tile_id_ < tile_count_ && block_tile_idx_ < tiles_per_block_;
return HasTile();
}
__device__ index_t GetFlagCount(index_t k_tiles) const
......@@ -75,11 +75,12 @@ class StridedReductionTileLoop
///
/// @return The workgroup flag index.
///
__device__ uint32_t GetWorkgroupFlagIdx(index_t k_tiles,
__device__ uint32_t GetWorkgroupFlagIdx([[maybe_unused]] index_t k_tiles,
index_t output_tile_idx,
index_t output_tile_idx_offset) const
{
return (output_tile_idx + output_tile_idx_offset) % GetFlagCount(k_tiles);
// return (output_tile_idx + output_tile_idx_offset) % GetFlagCount(k_tiles);
return output_tile_idx + output_tile_idx_offset;
}
///
......@@ -92,7 +93,7 @@ class StridedReductionTileLoop
__device__ void
FlagFinished(index_t k_tiles, index_t output_tile_idx, index_t output_tile_idx_offset)
{
const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
/* [[maybe_unused]] */const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
finished_block_flags_.inc(fidx);
}
......@@ -111,8 +112,10 @@ class StridedReductionTileLoop
// We use < because for some cases we may have +1 more workgroups per dim.
// Ie when k_tiles = 5, tiles_per_block = 3.
finished_block_flags_.wait_lt(
GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset),
GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset),
workgroups_per_dim);
// [[maybe_unused]] const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
}
///
......@@ -128,6 +131,8 @@ class StridedReductionTileLoop
// Wait untill the counter has been reset.
finished_block_flags_.wait_eq(
GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset), 0);
// [[maybe_unused]] const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
}
///
......@@ -141,6 +146,8 @@ class StridedReductionTileLoop
{
finished_block_flags_.reset(
GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset));
// [[maybe_unused]] const auto fidx = GetWorkgroupFlagIdx(k_tiles, output_tile_idx, output_tile_idx_offset);
}
///
......
......@@ -16,96 +16,96 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
// Row,
// Empty_Tuple,
// Row,
// F16,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
// void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
// Col,
// Empty_Tuple,
// Row,
// F16,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
// void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
// Row,
// Empty_Tuple,
// Row,
// F16,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
// void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
// Col,
// Empty_Tuple,
// Row,
// F16,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
// void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
// Col,
// Empty_Tuple,
// Row,
// F16,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
// void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
// Row,
// Empty_Tuple,
// Row,
// F16,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
// void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
// Col,
// Empty_Tuple,
// Row,
// F16,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
......@@ -120,31 +120,31 @@ void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F8,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F8,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// void add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
// Row,
// Empty_Tuple,
// Row,
// F16,
// F8,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
// void add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
// Row,
// Empty_Tuple,
// Row,
// F8,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
template <typename ALayout,
typename BLayout,
......@@ -186,48 +186,48 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
// add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
// add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
// add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
// add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
// add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
// op_ptrs);
}
// else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
// is_same_v<ELayout, Row>)
// {
// add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
// }
// else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
// is_same_v<ELayout, Row>)
// {
// add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
// }
}
// else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
// is_same_v<EDataType, half_t>)
// {
// if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
// is_same_v<ELayout, Row>)
// {
// add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(op_ptrs);
// }
// }
// else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
// is_same_v<EDataType, half_t>)
// {
// if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
// is_same_v<ELayout, Row>)
// {
// add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
// }
// }
return op_ptrs;
}
};
......
......@@ -17,18 +17,18 @@ namespace device {
namespace instance {
// MultiD version
void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
// std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
// Col,
// Empty_Tuple,
// Row,
// F16,
// F16,
// Empty_Tuple,
// F16,
// PassThrough,
// PassThrough,
// PassThrough>>>& instances);
void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
......@@ -93,8 +93,8 @@ struct DeviceOperationInstanceFactory<
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
op_ptrs);
// add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
// op_ptrs);
}
}
return op_ptrs;
......
add_instance_library(device_grouped_gemm_instance
device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
# device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
# device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
# device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
# device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
# device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
# device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
# device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
# device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
# device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
)
add_instance_library(device_grouped_gemm_multiple_d_instance
device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
# device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
)
......@@ -219,6 +219,8 @@ bool profile_ggemm_multid_splitk(int do_verification,
// profile device GEMM instances
for(auto& gemm_ptr : op_ptrs)
{
std::cout << "Running instance: " << gemm_ptr->GetTypeString() << std::endl;
auto gptr = dynamic_cast<DeviceOp*>(gemm_ptr.get());
auto argument_ptr = gemm_ptr->MakeArgumentPointer(
......@@ -247,20 +249,24 @@ bool profile_ggemm_multid_splitk(int do_verification,
for(std::size_t j = 0; j < kbatch_list.size(); j++)
{
auto kbatch_curr = kbatch_list[j];
// std::cout << ">>> kbatch: " << kbatch_curr << std::endl;
gptr->SetKBatchSize(argument_ptr.get(), kbatch_curr);
DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get()));
gemm_ptr->SetWorkSpacePointer(argument_ptr.get(),
gemm_desc_workspace.GetDeviceBuffer());
// std::cout << "WorkspacePointer set!" << std::endl;
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
for(std::size_t i = 0; i < gemm_descs.size(); i++)
c_device_buf[i]->SetZero();
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
// invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 1});
// std::cout << ">>>>>GPU Run end!" << std::endl;
if(do_verification)
{
......@@ -304,12 +310,16 @@ bool profile_ggemm_multid_splitk(int do_verification,
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
pass = pass && instance_pass;
std::cout << ">>>>>CPU verification end!" << std::endl;
}
if(time_kernel)
{
float avg_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::cout << ">>>>>GPU time profiling start!" << std::endl;
float avg_time = invoker_ptr->Run(
// argument_ptr.get(), StreamConfig{nullptr, time_kernel, 1, 5, 30});
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 1, 0, 1});
std::size_t flop = 0, num_btype = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
......@@ -335,6 +345,7 @@ bool profile_ggemm_multid_splitk(int do_verification,
best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
}
// std::cout << ">>>>>GPU time profiling end!" << std::endl;
}
}
else
......
# ckProfiler
set(PROFILER_SOURCES
profiler.cpp
profile_gemm.cpp
profile_gemm_splitk.cpp
profile_gemm_bias_add_reduce.cpp
profile_gemm_add_multiply.cpp
profile_gemm_multiply_add.cpp
profile_gemm_reduce.cpp
profile_batched_gemm.cpp
profile_batched_gemm_reduce.cpp
profile_conv_fwd.cpp
profile_conv_fwd_bias_relu.cpp
profile_conv_fwd_bias_relu_add.cpp
profile_conv_bwd_data.cpp
profile_grouped_conv_fwd.cpp
profile_grouped_conv_bwd_weight.cpp
profile_reduce.cpp
profile_groupnorm_bwd_data.cpp
profile_groupnorm_fwd.cpp
profile_layernorm_bwd_data.cpp
profile_layernorm_fwd.cpp
profile_max_pool3d_fwd.cpp
profile_avg_pool3d_bwd.cpp
profile_max_pool3d_bwd.cpp
profile_softmax.cpp
profile_batchnorm_fwd.cpp
profile_batchnorm_bwd.cpp
profile_batchnorm_infer.cpp
profile_grouped_conv_bwd_data.cpp
profile_conv_tensor_rearrange.cpp
# profile_gemm.cpp
# profile_gemm_splitk.cpp
# profile_gemm_bias_add_reduce.cpp
# profile_gemm_add_multiply.cpp
# profile_gemm_multiply_add.cpp
# profile_gemm_reduce.cpp
# profile_batched_gemm.cpp
# profile_batched_gemm_reduce.cpp
# profile_conv_fwd.cpp
# profile_conv_fwd_bias_relu.cpp
# profile_conv_fwd_bias_relu_add.cpp
# profile_conv_bwd_data.cpp
# profile_grouped_conv_fwd.cpp
# profile_grouped_conv_bwd_weight.cpp
# profile_reduce.cpp
# profile_groupnorm_bwd_data.cpp
# profile_groupnorm_fwd.cpp
# profile_layernorm_bwd_data.cpp
# profile_layernorm_fwd.cpp
# profile_max_pool3d_fwd.cpp
# profile_avg_pool3d_bwd.cpp
# profile_max_pool3d_bwd.cpp
# profile_softmax.cpp
# profile_batchnorm_fwd.cpp
# profile_batchnorm_bwd.cpp
# profile_batchnorm_infer.cpp
# profile_grouped_conv_bwd_data.cpp
# profile_conv_tensor_rearrange.cpp
)
if(DL_KERNELS)
......@@ -36,21 +36,22 @@ if(DL_KERNELS)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp)
list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp)
# list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp)
# list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
# list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiple_d_splitk.cpp)
endif()
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
# list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
# list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
endif()
set(PROFILER_EXECUTABLE ckProfiler)
......@@ -59,42 +60,42 @@ add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES})
target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance)
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
endif()
......@@ -104,16 +105,17 @@ if(DL_KERNELS)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_multiple_d_instance)
endif()
rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler)
......@@ -154,9 +154,9 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
// Accumulate partial results. We can have different # of workgroups to reduce, thus we
// read actual flag value.
const index_t flag_v = __builtin_amdgcn_readfirstlane(
const uint32_t flag_v = __builtin_amdgcn_readfirstlane(
work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
for(index_t i = 1; i < flag_v; ++i)
for(uint32_t i = 1; i < flag_v; ++i)
{
partial_result += p_workspace[(get_block_1d_id()) * MPerBlock * NPerBlock +
i * MPerBlock * NPerBlock + get_thread_local_1d_id()];
......@@ -174,7 +174,7 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
p_C[(C_m_tile_offset + C_thread_tile_m_idx) * stride_c + C_n_tile_offset +
C_thread_tile_n_idx] = partial_result;
}
else
else if(work_scheduler.HasTile())
{
work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
}
......@@ -284,10 +284,11 @@ struct GroupedGemmStridedTileLoopReduce
DeviceMem gemm_workspace, gemm_flags;
const index_t tiles_per_block = (tile_count + grid_size - 1) / grid_size;
// const index_t tiles_per_block = (tile_count + grid_size - 1) / grid_size;
// This is the number of MN-output tiles which we cover with workgroups.
// We launch k_batch / tiles_per_block workgroups for each output tile.
const index_t flag_count = (grid_size * tiles_per_block + k_batch - 1) / k_batch;
// const index_t flag_count = (grid_size * tiles_per_block + k_batch - 1) / k_batch;
const index_t flag_count = tile_count / k_batch;
gemm_workspace.Realloc(grid_size * MPerBlock * NPerBlock * sizeof(float));
gemm_flags.Realloc(flag_count * sizeof(uint32_t));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment