Commit 0ce9cacf authored by Adam Osewski's avatar Adam Osewski
Browse files

Get back to use constant memory for gemm descriptors.

parent 3644f0ec
...@@ -53,7 +53,7 @@ __global__ void ...@@ -53,7 +53,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_gemm_xdl_splitk(const void* gemm_desc, kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t tile_count, const index_t tile_count,
const index_t k_batch) const index_t k_batch)
{ {
...@@ -63,9 +63,10 @@ __global__ void ...@@ -63,9 +63,10 @@ __global__ void
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size]; __shared__ uint8_t p_shared[shared_size];
index_t tile_id = get_block_1d_id(); index_t tile_id = get_block_1d_id();
const index_t grid_size = get_grid_size(); const index_t grid_size = get_grid_size();
const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(gemm_desc); const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock(); static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock(); static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
...@@ -144,7 +145,7 @@ __global__ void ...@@ -144,7 +145,7 @@ __global__ void
} }
#else #else
ignore = gemm_desc; ignore = gemm_descs_const;
ignore = tile_count; ignore = tile_count;
ignore = k_batch; ignore = k_batch;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
...@@ -502,7 +503,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -502,7 +503,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const auto& gemm_arg = arg.gemm_kernel_args_[i]; const auto& gemm_arg = arg.gemm_kernel_args_[i];
if(stream_config.log_level_ > 0) if(stream_config.log_level_ > 0)
{ {
gemm_arg.Print(); // gemm_arg.Print();
} }
// Currently all groups use same kbatch value. // Currently all groups use same kbatch value.
...@@ -613,7 +614,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -613,7 +614,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
dim3(ck::math::min(arg.grid_size_, max_occupancy_grid_size)), dim3(ck::math::min(arg.grid_size_, max_occupancy_grid_size)),
dim3(BlockSize), dim3(BlockSize),
0, 0,
dev_gemm_args, cast_pointer_to_constant_address_space(dev_gemm_args),
arg.grid_size_, arg.grid_size_,
arg.K_BATCH); arg.K_BATCH);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment