Commit 220f40c9 authored by rtmadduri's avatar rtmadduri
Browse files

fix gridsize calculations

parent 1c1da090
...@@ -200,8 +200,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -200,8 +200,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
ComputeTypeA, ComputeTypeA,
ComputeTypeB>; ComputeTypeB>;
using Block2ETileMap = typename GridwiseGemm::Block2CTileMap; using Block2ETileMap =
BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>; using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
using KernelArgument = typename GridwiseGemm::Argument; using KernelArgument = typename GridwiseGemm::Argument;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
...@@ -209,16 +210,16 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -209,16 +210,16 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
struct GemmTransKernelArg struct GemmTransKernelArg
{ {
KernelArgument karg_; KernelArgument karg_;
GroupedGemmBlock2ETileMap block_2_ctile_map_; // GroupedGemmBlock2ETileMap block_2_ctile_map_;
index_t block_start_, block_end_; index_t block_start_, block_end_;
GemmTransKernelArg() = default; GemmTransKernelArg() = default;
GemmTransKernelArg(KernelArgument&& karg, GemmTransKernelArg(KernelArgument&& karg,
GroupedGemmBlock2ETileMap&& b2c_map, // GroupedGemmBlock2ETileMap&& b2c_map,
index_t block_start, index_t block_start,
index_t block_end) index_t block_end)
: karg_{karg}, : karg_{karg},
block_2_ctile_map_{b2c_map}, // block_2_ctile_map_{b2c_map},
block_start_{block_start}, block_start_{block_start},
block_end_{block_end} block_end_{block_end}
{ {
...@@ -277,15 +278,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -277,15 +278,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const index_t stride_b = gemm_descs[i].stride_B_; const index_t stride_b = gemm_descs[i].stride_B_;
const index_t stride_c = gemm_descs[i].stride_C_; const index_t stride_c = gemm_descs[i].stride_C_;
// const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
// const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
index_t gdx, gdy, gdz; index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(M, N, K_BATCH); std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(M, N, K_BATCH);
const auto local_b2c_tile_map = Block2ETileMap{gdx, gdy, gdz}; const auto local_b2c_tile_map = Block2ETileMap{gdx, gdy, gdz};
// const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(M, N);
const index_t grid_size_grp = gdx * gdy * gdz; // const index_t grid_size_grp = gdx * gdy * gdz;
const index_t block_start = grid_size_; const index_t block_start = grid_size_;
const index_t block_end = grid_size_ + grid_size_grp; const index_t block_end = grid_size_ + grid_size_grp;
...@@ -293,8 +291,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -293,8 +291,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
// block-to-e-tile map // block-to-e-tile map
auto grouped_block_2_ctile_map = // auto grouped_block_2_ctile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); // GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
KernelArgument karg{type_convert<const ADataType*>(p_a_grid[i]), KernelArgument karg{type_convert<const ADataType*>(p_a_grid[i]),
type_convert<const BDataType*>(p_b_grid[i]), type_convert<const BDataType*>(p_b_grid[i]),
...@@ -307,8 +305,13 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -307,8 +305,13 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
stride_c, stride_c,
K_BATCH}; K_BATCH};
// gemm_kernel_args_.emplace_back(
// std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
gemm_kernel_args_.emplace_back( gemm_kernel_args_.emplace_back(
std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); std::move(karg), block_start, block_end);
} }
} }
...@@ -334,19 +337,19 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -334,19 +337,19 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg.M, karg.N, karg.KBatch); std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg.M, karg.N, karg.KBatch);
const auto local_b2c_tile_map = Block2ETileMap{gdx, gdy, gdz}; const auto local_b2c_tile_map = Block2ETileMap{gdx, gdy, gdz};
// const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(karg.M, karg.N);
const index_t grid_size_grp = gdx * gdy * gdz; // const index_t grid_size_grp = gdx * gdy * gdz;
const index_t block_start = grid_size_; const index_t block_start = grid_size_;
const index_t block_end = grid_size_ + grid_size_grp; const index_t block_end = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
auto grouped_block_2_ctile_map = // auto grouped_block_2_ctile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); // GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
karg.KBatch = K_BATCH; karg.KBatch = K_BATCH;
gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map; // gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
gemm_kernel_args_[i].block_start_ = block_start; gemm_kernel_args_[i].block_start_ = block_start;
gemm_kernel_args_[i].block_end_ = block_end; gemm_kernel_args_[i].block_end_ = block_end;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment