Commit bf73d297 authored by Adam Osewski

Fixes

parent a5e9069f
@@ -28,18 +28,14 @@ template <typename GridwiseGemm,
typename GemmDesc,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
typename BElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
typename CDEElementwiseOperation = ck::tensor_operation::element_wise::PassThrough>
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation c_element_op)
const index_t group_count)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
@@ -68,13 +64,19 @@ __global__ void
group_id = index_t((left + right) / 2);
}
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
gemm_desc_ptr[group_id].karg_,
static_cast<void*>(p_shared),
gemm_desc_ptr[group_id].block_2_ctile_map_,
a_element_op,
b_element_op,
c_element_op);
const auto karg = gemm_desc_ptr[group_id].karg_;
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op,
gemm_desc_ptr[group_id].block_2_ctile_map_);
#else
ignore = gemm_descs_const;
ignore = group_count;
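For reference, the group lookup in the kernel above is a plain interval bisection over each group's [block_start_, block_end_) block range. A minimal host-side sketch of the same search, assuming a hypothetical GroupDesc in place of the kernel's device-side GemmTransKernelArg array:

// Host-side sketch of the kernel's block -> group binary search.
// GroupDesc is illustrative only.
#include <cstdint>
#include <vector>

struct GroupDesc
{
    int32_t block_start_;
    int32_t block_end_; // one past the last block owned by this group
};

int32_t FindGroupId(const std::vector<GroupDesc>& groups, int32_t block_id)
{
    int32_t left     = 0;
    int32_t right    = static_cast<int32_t>(groups.size());
    int32_t group_id = (left + right) / 2;

    // Groups cover contiguous, disjoint block ranges, so bisection
    // terminates after O(log(group_count)) steps.
    while(!(block_id >= groups[group_id].block_start_ &&
            block_id < groups[group_id].block_end_))
    {
        if(block_id < groups[group_id].block_start_)
            right = group_id;
        else
            left = group_id + 1;
        group_id = (left + right) / 2;
    }
    return group_id;
}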
@@ -131,19 +133,24 @@ template <typename ALayout,
typename ComputeTypeA = EDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
bool PermuteB = false,
// MultipleD not supported for now.
enable_if_t<is_same_v<DsLayout, ck::Tuple<>> && is_same_v<DsDataType, ck::Tuple<>>,
bool> = false>
struct DeviceGroupedGemmXdlSplitKCShuffle
: public DeviceGroupedGemmSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
@@ -198,7 +205,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
ComputeTypeB,
PermuteA,
PermuteB>;
using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
@@ -209,16 +218,16 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
struct GemmTransKernelArg
{
KernelArgument karg_;
// GroupedGemmBlock2ETileMap block_2_ctile_map_;
GroupedGemmBlock2ETileMap block_2_ctile_map_;
index_t block_start_, block_end_;
GemmTransKernelArg() = default;
GemmTransKernelArg(KernelArgument&& karg,
// GroupedGemmBlock2ETileMap&& b2c_map,
GroupedGemmBlock2ETileMap&& b2c_map,
index_t block_start,
index_t block_end)
: karg_{karg},
// block_2_ctile_map_{b2c_map},
block_2_ctile_map_{b2c_map},
block_start_{block_start},
block_end_{block_end}
{
@@ -234,8 +243,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
Argument(std::vector<const void*>& p_a_grid,
std::vector<const void*>& p_b_grid,
std::vector<void*>& p_c_grid,
std::vector<GemmDesc>& gemm_descs)
: Argument(p_a_grid, p_b_grid, p_c_grid, gemm_descs, DefaultKBatch)
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: Argument(p_a_grid, p_b_grid, p_c_grid, gemm_descs, DefaultKBatch, a_element_op, b_element_op, cde_element_op)
{
// TODO: use occupancy api to calculate appropriate batch size.
}
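The TODO above points at the HIP occupancy API; a minimal sketch of what such a heuristic could query (not part of this commit; the helper name and its use as a k-batch bound are assumptions):

// Sketch: ask how many blocks of this kernel the device can keep
// resident at once; a k-batch heuristic could grow KBatch until the
// total grid size approaches this bound.
#include <hip/hip_runtime.h>

inline int MaxResidentBlocks(const void* kernel_func, int block_size)
{
    int blocks_per_cu = 0;
    hipOccupancyMaxActiveBlocksPerMultiprocessor(
        &blocks_per_cu, kernel_func, block_size, 0 /*dynamic LDS bytes*/);

    hipDeviceProp_t props{};
    hipGetDeviceProperties(&props, 0);
    return blocks_per_cu * props.multiProcessorCount;
}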
@@ -244,7 +256,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
std::vector<const void*>& p_b_grid,
std::vector<void*>& p_c_grid,
std::vector<GemmDesc>& gemm_descs,
index_t kbatch)
index_t kbatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: K_BATCH{kbatch}
{
grid_size_ = 0;
@@ -267,7 +282,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const index_t N = gemm_descs[i].N_;
const index_t K = gemm_descs[i].K_;
if(M == 0)
if(M * N * K == 0)
{
skipped_group_count_++;
continue;
@@ -277,12 +292,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const index_t stride_b = gemm_descs[i].stride_B_;
const index_t stride_c = gemm_descs[i].stride_C_;
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(M, N, K_BATCH);
const auto local_b2c_tile_map = Block2ETileMap{gdx, gdy, gdz};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(M, N);
// const index_t grid_size_grp = gdx * gdy * gdz;
const auto local_b2c_tile_map = Block2ETileMap{M, N, 4};
index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(M, N);
grid_size_grp *= K_BATCH;
const index_t block_start = grid_size_;
const index_t block_end = grid_size_ + grid_size_grp;
@@ -290,24 +302,27 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
grid_size_ += grid_size_grp;
// block-to-e-tile map
// auto grouped_block_2_ctile_map =
// GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
auto grouped_block_2_ctile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
KernelArgument karg{type_convert<const ADataType*>(p_a_grid[i]),
type_convert<const BDataType*>(p_b_grid[i]),
{}, // p_ds_grid
type_convert<EDataType*>(p_c_grid[i]),
M,
N,
K,
stride_a,
stride_b,
{}, // StrideDs_
stride_c,
K_BATCH};
K_BATCH,
a_element_op,
b_element_op,
cde_element_op};
// gemm_kernel_args_.emplace_back(
// std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
gemm_kernel_args_.emplace_back(
std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
}
}
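A worked example of the bookkeeping above, assuming CalculateGridSize returns ceil(M / MPerBlock) * ceil(N / NPerBlock): with MPerBlock = NPerBlock = 128, a group with M = 512 and N = 256 covers 4 * 2 = 8 output tiles; multiplying by K_BATCH = 4 gives grid_size_grp = 32, so if grid_size_ was 100 before this group, its block range is [100, 132) and grid_size_ advances to 132 for the next group.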
@@ -326,28 +341,22 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
auto& karg = gemm_kernel_args_[i].karg_;
// const index_t m_padded = GridwiseGemm::CalculateMPadded(karg.M);
// const index_t n_padded = GridwiseGemm::CalculateNPadded(karg.N);
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) =
GridwiseGemm::CalculateGridSize(karg.M, karg.N, karg.KBatch);
const auto local_b2c_tile_map = Block2ETileMap{gdx, gdy, gdz};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(karg.M, karg.N);
// const index_t grid_size_grp = gdx * gdy * gdz;
const auto local_b2c_tile_map = Block2ETileMap{karg.M, karg.N, 4};
index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(karg.M, karg.N);
grid_size_grp *= K_BATCH;
const index_t block_start = grid_size_;
const index_t block_end = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp;
// auto grouped_block_2_ctile_map =
// GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
// block-to-e-tile map
auto grouped_block_2_ctile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
karg.KBatch = K_BATCH;
// gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
gemm_kernel_args_[i].block_start_ = block_start;
gemm_kernel_args_[i].block_end_ = block_end;
karg.KBatch = K_BATCH;
gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
gemm_kernel_args_[i].block_start_ = block_start;
gemm_kernel_args_[i].block_end_ = block_end;
}
}
@@ -365,45 +374,53 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
bool all_have_main_k_block_loop{true};
bool all_have_kbatch_gt_one;
const auto& karg0 = arg.gemm_kernel_args_[0].karg_;
index_t k_grain0 = karg0.KBatch * KPerBlock;
index_t K_split0 = (karg0.K + k_grain0 - 1) / k_grain0 * KPerBlock;
bool all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split0);
const auto tail_num = GridwiseGemm::CalculateKBlockLoopTailNum(K_split0);
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
const auto& karg = arg.gemm_kernel_args_[i].karg_;
all_have_kbatch_gt_one = karg.KBatch > 1;
index_t k_grain = arg.gemm_kernel_args_[i].karg_.KBatch * KPerBlock;
index_t K_split =
(arg.gemm_kernel_args_[i].karg_.K + k_grain - 1) / k_grain * KPerBlock;
all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
if(stream_config.log_level_ > 0)
{
karg.Print();
}
auto kbatch = karg.KBatch;
const auto& karg = arg.gemm_kernel_args_[i].karg_;
if(!GridwiseGemm::CheckValidity(karg))
index_t k_grain = karg.KBatch * KPerBlock;
index_t K_split = (karg.K + k_grain - 1) / k_grain * KPerBlock;
bool not_all_have_main_k0_block_loop_same =
all_have_main_k_block_loop xor
GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
bool not_all_have_tail_num_same =
(tail_num != GridwiseGemm::CalculateKBlockLoopTailNum(K_split));
if(not_all_have_main_k0_block_loop_same)
{
std::ostringstream err;
err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
if(not_all_have_tail_num_same)
{
std::ostringstream err;
err << "Not all gemms have same TailNumber value! in " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
if(not_all_have_kbatch_value_same)
if(!GridwiseGemm::CheckValidity(karg))
{
std::ostringstream err;
err << "Not all gemms have same kbatch value (=1 or >1)! "
<< "group [" << i << "], kbatch: " << kbatch
<< ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.KBatch
<< " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
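For the split-K arithmetic in the loop above: k_grain = KBatch * KPerBlock is the K consumed per outer iteration across all splits, and K_split = ceil(K / k_grain) * KPerBlock is the per-split K extent rounded to whole KPerBlock tiles. For example, with K = 1000, KBatch = 4 and KPerBlock = 64: k_grain = 256, ceil(1000 / 256) = 4, so K_split = 256. Both CalculateHasMainKBlockLoop and CalculateKBlockLoopTailNum are evaluated on this per-split extent, which is why every group must agree on both values before a single kernel instantiation can serve the whole fused grid.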
@@ -418,64 +435,71 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
float ave_time = 0;
const auto Run = [&](const auto& kernel) {
if(all_have_kbatch_gt_one)
if(stream_config.flush_cache)
{
for(const auto& trans_arg : arg.gemm_kernel_args_)
{
const auto& karg = trans_arg.karg_;
hip_check_error(hipMemsetAsync(karg.p_c_grid,
0,
karg.M * karg.N * sizeof(EDataType),
stream_config.stream_id_));
}
const auto& arg_ = arg.gemm_kernel_args_[0].karg_;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
auto size_a_buffer =
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
auto size_b_buffer =
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
ck::utility::RotatingMemWrapper<KernelArgument> rotating_mem(
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
// TODO: should be loop here through all groups
if(arg_.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(EDataType),
stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_kernel_args_.size());
}
for(const auto& trans_arg : arg.gemm_kernel_args_)
else
{
const auto& karg = trans_arg.karg_;
ave_time += launch_and_time_kernel(
stream_config, kernel, dim3(arg.grid_size_), dim3(BlockSize), 0, karg);
// TODO: should be loop here through all groups
const auto& karg0 = arg.gemm_kernel_args_[0].karg_;
if(karg0.KBatch > 1)
hipGetErrorString(hipMemsetAsync(karg0.p_c_grid,
0,
karg0.M * karg0.N * sizeof(EDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_kernel_args_.size());
}
};
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
// Calculate TailNumber for one
auto calculate_tail_number = [&]() {
index_t k_grain = arg.gemm_kernel_args_[0].karg_.KBatch * KPerBlock;
index_t K_split =
(arg.gemm_kernel_args_[0].karg_.K + k_grain - 1) / k_grain * KPerBlock;
return GridwiseGemm::CalculateKBlockLoopTailNum(K_split);
};
auto all_have_same_tail_number = [&]() {
// Calculate TailNumber for one
auto tail_number = calculate_tail_number();
// Calculate TailNumber for every other arg and compare
for(size_t i = 1; i < arg.gemm_kernel_args_.size(); ++i)
{
index_t k_grain = arg.gemm_kernel_args_[i].karg_.KBatch * KPerBlock;
index_t K_split =
(arg.gemm_kernel_args_[i].karg_.K + k_grain - 1) / k_grain * KPerBlock;
if(tail_number != GridwiseGemm::CalculateKBlockLoopTailNum(K_split))
{
return false;
}
}
return true;
};
auto throw_error = [&]() {
std::ostringstream err;
err << "Not all gemms have same TailNumber value! ";
throw std::runtime_error(err.str());
};
if(all_have_main_k_block_loop)
{
// Tail number always full
@@ -485,19 +509,21 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
if(all_have_kbatch_gt_one)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
@@ -507,24 +533,19 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
{
if(all_have_kbatch_gt_one)
{
if(calculate_tail_number() == TailNumber::One)
if(tail_num == TailNumber::One)
{
if(all_have_same_tail_number())
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else
{
throw_error();
}
const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmTransKernelArg,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
//// TODO: Fix below as above!
else if(calculate_tail_number() == TailNumber::Full)
{
if(all_have_same_tail_number())
@@ -1094,11 +1115,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
std::vector<std::array<const void*, NumDTensor>>&,
std::vector<void*>& p_c_grid,
std::vector<GemmDesc> gemm_descs,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation)
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{p_a_grid, p_b_grid, p_c_grid, gemm_descs};
return Argument{
p_a_grid, p_b_grid, p_c_grid, gemm_descs, a_element_op, b_element_op, cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
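For context, a client of this device op would follow the usual CK grouped-GEMM flow; a minimal sketch with illustrative names (the pointer vectors, gemm_descs and the workspace buffer are placeholders, not from this commit):

// Hypothetical usage sketch; template arguments and buffer setup elided.
using DeviceOp = DeviceGroupedGemmXdlSplitKCShuffle</* ... */>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

DeviceOp gemm;
auto argument = gemm.MakeArgument(p_a_grid, p_b_grid, p_ds_grid, p_c_grid,
                                  gemm_descs,
                                  PassThrough{}, PassThrough{}, PassThrough{});

// The grouped kernel reads per-group arguments from a device-side
// workspace (arg.p_workspace_ in Invoker::Run above), so the caller
// must attach one before launching.
gemm.SetWorkSpacePointer(&argument, p_workspace_dev);

if(gemm.IsSupportedArgument(argument))
{
    auto invoker  = gemm.MakeInvoker();
    float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
}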
@@ -1110,11 +1132,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
std::vector<std::array<const void*, NumDTensor>>&,
std::vector<void*>& p_c_grid,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation) override
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(p_a_grid, p_b_grid, p_c_grid, gemm_descs);
return std::make_unique<Argument>(
p_a_grid, p_b_grid, p_c_grid, gemm_descs, a_element_op, b_element_op, cde_element_op);
}
// polymorphic