Commit 593c2909 authored by Jing Zhang, committed by root

add simple kernel arg

parent 6819fc4c
@@ -31,18 +31,20 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
-                                   const index_t group_count)
+    kernel_grouped_gemm_xdl_splitk(const void* gemm_desc_const,
+                                   const index_t group_count,
+                                   const index_t k_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
     __shared__ uint8_t p_shared[shared_size];
 
+    const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(gemm_desc_const);
     const index_t block_id = get_block_1d_id();
-    const auto gemm_desc_ptr =
-        reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
 
+#if 0
     index_t left     = 0;
     index_t right    = group_count;
     index_t group_id = index_t((left + right) / 2);
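Note: the descriptor pointer drops the `CK_CONSTANT_ADDRESS_SPACE` qualifier (so no `cast_pointer_to_generic_address_space` is needed), and `k_batch` is hoisted out of the per-group descriptor into a plain kernel argument, meaning one split-K factor now applies to every group in the launch. A hedged host-side sanity check one could pair with this (hypothetical helper, not part of this commit):

```cpp
#include <stdexcept>
#include <vector>

// Hypothetical guard: with k_batch hoisted out of the per-group karg_,
// all groups must agree on a single split-K factor.
template <typename GemmTransKernelArg>
void check_uniform_k_batch(const std::vector<GemmTransKernelArg>& args)
{
    for(const auto& a : args)
        if(a.karg_.k_batch != args.front().karg_.k_batch)
            throw std::runtime_error("grouped split-K expects a single k_batch");
}
```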
@@ -60,18 +62,35 @@ __global__ void
         }
         group_id = index_t((left + right) / 2);
     }
+#else
+    if(block_id >= gemm_desc_ptr[group_count - 1].block_end_)
+        return;
+
+    index_t group_id = 0;
+    for(; group_id < group_count; group_id++)
+    {
+        if(block_id >= gemm_desc_ptr[group_id].block_start_ &&
+           block_id < gemm_desc_ptr[group_id].block_end_)
+        {
+            break;
+        }
+    }
+#endif
 
-    const auto M       = gemm_desc_ptr[group_id].karg_.M;
-    const auto N       = gemm_desc_ptr[group_id].karg_.N;
-    const auto K       = gemm_desc_ptr[group_id].karg_.K;
-    const auto StrideA = gemm_desc_ptr[group_id].karg_.StrideA;
-    const auto StrideB = gemm_desc_ptr[group_id].karg_.StrideB;
-    const auto StrideC = gemm_desc_ptr[group_id].karg_.StrideC;
-    const auto MPadded = gemm_desc_ptr[group_id].karg_.MPadded;
-    const auto NPadded = gemm_desc_ptr[group_id].karg_.NPadded;
-    const auto KPadded = gemm_desc_ptr[group_id].karg_.KPadded;
-    const auto K0      = gemm_desc_ptr[group_id].karg_.K0;
-    const auto k_batch = gemm_desc_ptr[group_id].karg_.k_batch;
+    const auto p_a_grid = gemm_desc_ptr[group_id].p_a_grid;
+    const auto p_b_grid = gemm_desc_ptr[group_id].p_b_grid;
+    const auto p_c_grid = gemm_desc_ptr[group_id].p_c_grid;
+    const auto M        = gemm_desc_ptr[group_id].M;
+    const auto N        = gemm_desc_ptr[group_id].N;
+    const auto K        = gemm_desc_ptr[group_id].K;
+    const auto StrideA  = gemm_desc_ptr[group_id].StrideA;
+    const auto StrideB  = gemm_desc_ptr[group_id].StrideB;
+    const auto StrideC  = gemm_desc_ptr[group_id].StrideC;
+
+    const auto MPadded = GridwiseGemm::CalculateMPadded(M);
+    const auto NPadded = GridwiseGemm::CalculateNPadded(N);
+    const auto KPadded = GridwiseGemm::CalculateKPadded(K, k_batch);
+    const auto K0      = GridwiseGemm::CalculateK0(K, k_batch);
 
     static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
     static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
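Note: the `#if 0` retires the binary search over groups in favor of a linear scan of the contiguous `[block_start_, block_end_)` ranges, plus an early return for workgroups past the last group. A standalone host-side rendering of the same lookup (stand-in types, not the kernel code itself):

```cpp
#include <cstdint>
#include <cstdio>

// Stand-in for the per-group workgroup range carried in GemmDesc.
struct BlockRange { int32_t block_start_, block_end_; };

int32_t find_group(const BlockRange* descs, int32_t group_count, int32_t block_id)
{
    // Mirrors the kernel's early return for blocks beyond the last group.
    if(block_id >= descs[group_count - 1].block_end_)
        return -1;
    int32_t group_id = 0;
    for(; group_id < group_count; group_id++)
        if(block_id >= descs[group_id].block_start_ &&
           block_id < descs[group_id].block_end_)
            break;
    return group_id;
}

int main()
{
    const BlockRange descs[] = {{0, 4}, {4, 10}, {10, 12}};
    std::printf("%d %d %d\n",
                find_group(descs, 3, 0),   // -> 0
                find_group(descs, 3, 7),   // -> 1
                find_group(descs, 3, 20)); // -> -1: past the grid's work
}
```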
@@ -88,9 +107,9 @@ __global__ void
         GroupedGemmBlock2ETileMap(local_b2c_tile_map, gemm_desc_ptr[group_id].block_start_);
 
     GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
-        gemm_desc_ptr[group_id].karg_.p_a_grid,
-        gemm_desc_ptr[group_id].karg_.p_b_grid,
-        gemm_desc_ptr[group_id].karg_.p_c_grid,
+        p_a_grid,
+        p_b_grid,
+        p_c_grid,
         M,
         N,
         K,
@@ -277,20 +296,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
         gemm_kernel_args_.reserve(group_count_);
 
-        skipped_group_count_ = 0;
-
         for(std::size_t i = 0; i < gemm_descs.size(); ++i)
         {
             const index_t M = gemm_descs[i].M_;
             const index_t N = gemm_descs[i].N_;
             const index_t K = gemm_descs[i].K_;
 
-            if(M == 0)
-            {
-                skipped_group_count_++;
-                continue;
-            }
-
             const index_t stride_a = gemm_descs[i].stride_A_;
             const index_t stride_b = gemm_descs[i].stride_B_;
             const index_t stride_c = gemm_descs[i].stride_C_;
@@ -379,7 +390,6 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
         // private:
         index_t K_BATCH;
         index_t group_count_;
-        index_t skipped_group_count_;
         std::vector<GemmTransKernelArg> gemm_kernel_args_;
         index_t grid_size_;
@@ -388,8 +398,28 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
     // Invoker
     struct Invoker : public BaseInvoker
     {
+        struct SimpleGemmArgument
+        {
+            const ADataType* p_a_grid;
+            const BDataType* p_b_grid;
+            EDataType* p_c_grid;
+
+            index_t M;
+            index_t N;
+            index_t K;
+            index_t StrideA;
+            index_t StrideB;
+            index_t StrideC;
+
+            index_t block_start_;
+            index_t block_end_;
+        };
+
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
+            std::vector<SimpleGemmArgument> simple_gemm_kernel_args_;
+            simple_gemm_kernel_args_.reserve(arg.gemm_kernel_args_.size());
+
             index_t K0                  = arg.gemm_kernel_args_[0].karg_.K0;
             bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
             bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
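Note: `SimpleGemmArgument` keeps only what the kernel cannot derive; `MPadded`, `NPadded`, `KPadded`, and `K0` are dropped from the descriptor and recomputed on the device from M/N/K and the uniform `k_batch`. A hedged sketch of the likely round-up semantics (assumed here; the authoritative definitions are `GridwiseGemm::CalculateMPadded` and friends):

```cpp
#include <cstdint>
using index_t = int32_t;

// Assumed round-up-to-tile semantics; CK's GridwiseGemm::Calculate* members
// are the source of truth.
index_t integer_divide_ceil(index_t a, index_t b) { return (a + b - 1) / b; }

index_t calc_m_padded(index_t M, index_t MPerBlock)
{
    // Round M up to a whole number of M-tiles. KPadded/K0 additionally
    // fold in k_batch and the K1 packing factor.
    return integer_divide_ceil(M, MPerBlock) * MPerBlock;
}
```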
@@ -434,12 +464,26 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                         << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
                     throw std::runtime_error(err.str());
                 }
+
+                simple_gemm_kernel_args_.push_back({karg.p_a_grid,
+                                                    karg.p_b_grid,
+                                                    karg.p_c_grid,
+                                                    karg.M,
+                                                    karg.N,
+                                                    karg.K,
+                                                    karg.StrideA,
+                                                    karg.StrideB,
+                                                    karg.StrideC,
+                                                    arg.gemm_kernel_args_[i].block_start_,
+                                                    arg.gemm_kernel_args_[i].block_end_});
             }
 
+            using GemmArgumentType = SimpleGemmArgument;
+
             hip_check_error(
                 hipMemcpyWithStream(arg.p_workspace_,
-                                    arg.gemm_kernel_args_.data(),
-                                    arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
+                                    simple_gemm_kernel_args_.data(),
+                                    simple_gemm_kernel_args_.size() * sizeof(GemmArgumentType),
                                     hipMemcpyHostToDevice,
                                     stream_config.stream_id_));
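Note: the workspace now holds the flattened structs rather than full `GemmTransKernelArg` entries, so the copy shrinks to `group_count * sizeof(SimpleGemmArgument)` bytes and is ordered on the launch stream. A self-contained HIP sketch of this staging pattern (the `Arg` struct is a stand-in, not the CK type):

```cpp
#include <hip/hip_runtime.h>
#include <vector>

struct Arg { int M, N, K; }; // stand-in for SimpleGemmArgument

int main()
{
    std::vector<Arg> host_args = {{128, 64, 256}, {32, 32, 512}};
    hipStream_t stream;
    (void)hipStreamCreate(&stream);
    void* workspace = nullptr;
    (void)hipMalloc(&workspace, host_args.size() * sizeof(Arg));
    // Stream-ordered copy: a kernel launched next on the same stream
    // observes the fully written descriptors without extra synchronization.
    (void)hipMemcpyWithStream(workspace,
                              host_args.data(),
                              host_args.size() * sizeof(Arg),
                              hipMemcpyHostToDevice,
                              stream);
    (void)hipStreamSynchronize(stream);
    (void)hipFree(workspace);
    (void)hipStreamDestroy(stream);
    return 0;
}
```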
@@ -456,14 +500,14 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                 }
             }
 
-                ave_time =
-                    launch_and_time_kernel(stream_config,
-                                           kernel,
-                                           dim3(arg.grid_size_),
-                                           dim3(BlockSize),
-                                           0,
-                                           cast_pointer_to_constant_address_space(arg.p_workspace_),
-                                           arg.gemm_kernel_args_.size());
+                ave_time = launch_and_time_kernel(stream_config,
+                                                  kernel,
+                                                  dim3(arg.grid_size_),
+                                                  dim3(BlockSize),
+                                                  0,
+                                                  arg.p_workspace_,
+                                                  arg.gemm_kernel_args_.size(),
+                                                  arg.gemm_kernel_args_[0].karg_.k_batch);
             };
 
             if(all_have_main_k0_block_loop)
@@ -472,7 +516,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                 {
                     const auto kernel =
                         kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                       GemmTransKernelArg,
+                                                       GemmArgumentType,
                                                        true,
                                                        InMemoryDataOperationEnum::AtomicAdd>;
@@ -482,7 +526,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                 {
                     const auto kernel =
                         kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                       GemmTransKernelArg,
+                                                       GemmArgumentType,
                                                        true,
                                                        InMemoryDataOperationEnum::Set>;
@@ -495,7 +539,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                 {
                     const auto kernel =
                         kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                       GemmTransKernelArg,
+                                                       GemmArgumentType,
                                                        false,
                                                        InMemoryDataOperationEnum::AtomicAdd>;
@@ -505,7 +549,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                 {
                     const auto kernel =
                         kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                       GemmTransKernelArg,
+                                                       GemmArgumentType,
                                                        false,
                                                        InMemoryDataOperationEnum::Set>;
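Note: all four instantiations (with/without the main K0 block loop, AtomicAdd/Set) now take `GemmArgumentType` instead of `GemmTransKernelArg`. The AtomicAdd variants serve `k_batch > 1`, where each K-slice accumulates its partial result into the same C tile; a minimal HIP illustration of that accumulation pattern (not CK's implementation):

```cpp
#include <hip/hip_runtime.h>

// Each split-K partial result is added into C, which must be
// zero-initialized before the launch; that is why k_batch > 1 selects
// AtomicAdd rather than Set.
__global__ void accumulate_partial(float* c, const float* partial, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        atomicAdd(&c[i], partial[i]);
}
```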
@@ -532,13 +576,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
-            arg.skipped_group_count_) != arg.group_count_)
+        if(ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) != arg.group_count_)
         {
 #if DEBUG_LOG
-            std::cout << "The group count is not equal to sum of skipped groups "
-                         "and kernel args size!"
-                      << std::endl;
+            std::cout << "The group count is not equal to kernel args size!" << std::endl;
 #endif // DEBUG_LOG
             return false;
         }
...