Commit 160932b6 authored by Adam Osewski

Multiple fixes.

* Fix accumulation when there's only one workgroup per K dim (see the scalar sketch after the second hunk below).
* Update occupancy values after KBatch update and fix its calculation (query flow sketched below).
parent ccc38273
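
For context on the second fix: the occupancy and CU-count queries move out of the callers' hands, so `MakeArgument`/`MakeArgumentPointer` no longer compute and forward `occupancy_num_blocks`/`gpu_cu_count`; the `Argument` constructor now derives both itself. Below is a minimal self-contained sketch of that query flow; the trivial `some_kernel` and the `BlockSize` value are stand-ins for the real `kernel_grouped_gemm_xdl_splitk_v2` instantiation and its launch bounds.

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>

// Stand-in kernel; the real code queries the instantiated
// kernel_grouped_gemm_xdl_splitk_v2 template instead.
__global__ void some_kernel() {}

int main()
{
    constexpr int BlockSize = 256; // assumed workgroup size

    // Max resident workgroups per CU for this kernel/block size,
    // assuming 0 bytes of dynamic LDS (as in the patch).
    int occupancy_num_blocks = 0;
    (void)hipOccupancyMaxActiveBlocksPerMultiprocessor(
        &occupancy_num_blocks, some_kernel, BlockSize, 0);

    // Number of compute units on the current device.
    hipDevice_t dev;
    hipDeviceProp_t dev_prop;
    (void)hipGetDevice(&dev);
    (void)hipGetDeviceProperties(&dev_prop, dev);
    const int gpu_cu_count = dev_prop.multiProcessorCount;

    std::printf("%d blocks/CU x %d CUs = %d resident blocks\n",
                occupancy_num_blocks, gpu_cu_count,
                occupancy_num_blocks * gpu_cu_count);
    return 0;
}
```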
@@ -79,7 +79,7 @@ __global__ void
     uint32_t* const __restrict__ p_flags = reinterpret_cast<uint32_t* const __restrict__>(
         reinterpret_cast<char*>(p_workspace) +
-        Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(typename GridwiseGemm::AccType)));
+        Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(typename GridwiseGemm::CShuffleDataT)));

     StridedReductionTileLoop work_scheduler{tile_count, p_flags};
@@ -183,10 +183,9 @@ __global__ void
         acc_buff{};
     acc_buff.Clear();

-    // Accumulate only when there is at least two workgroups processing splitk data-tiles
-    // across same MN-output tile.
-    if(neighbour_count > 0)
-        GridwiseGemm::AccumulatePartials(p_workspace, acc_buff, neighbour_count + 1);
+    // TODO: Accumulate only when there is at least two workgroups processing splitk
+    // data-tiles across same MN-output tile. if(neighbour_count > 0)
+    GridwiseGemm::AccumulatePartials(p_workspace, acc_buff, neighbour_count + 1);

     // Signal waiting blocks that they can start use their workspace.
     work_scheduler.Reset(neighbour_count);
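
This hunk is the first fix from the commit message. The reading suggested by that message: with a K-split that yields a single workgroup per K dimension, `neighbour_count` is 0, so the old guard skipped `AccumulatePartials` entirely and the lone workgroup's partial result sitting in the workspace was never folded into `acc_buff`. A minimal scalar model of the two behaviours (hypothetical names; the real code accumulates whole output tiles):

```cpp
#include <cassert>

// Sum `partial_count` partial results from the workspace
// (scalar stand-in for GridwiseGemm::AccumulatePartials).
float AccumulateTile(const float* p_workspace, int partial_count)
{
    float acc = 0.f; // acc_buff.Clear()
    for(int i = 0; i < partial_count; ++i)
        acc += p_workspace[i];
    return acc;
}

int main()
{
    const float workspace[]   = {7.f}; // the single workgroup's partial
    const int neighbour_count = 0;     // one workgroup per K dim

    // Old: guarded call leaves acc at 0 when neighbour_count == 0.
    float acc_old = 0.f;
    if(neighbour_count > 0)
        acc_old = AccumulateTile(workspace, neighbour_count + 1);

    // Fixed: unconditional call picks up the lone partial.
    const float acc_new = AccumulateTile(workspace, neighbour_count + 1);

    assert(acc_old == 0.f && acc_new == 7.f);
    return 0;
}
```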
@@ -366,9 +365,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                  std::vector<GemmDesc>& gemm_descs,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
-                 CDEElementwiseOperation cde_element_op,
-                 int occupancy_num_blocks,
-                 int gpu_cu_count)
+                 CDEElementwiseOperation cde_element_op)
             : Argument(p_As,
                        p_Bs,
                        p_Ds,
@@ -377,9 +374,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                        a_element_op,
                        b_element_op,
                        cde_element_op,
-                       DefaultKBatch,
-                       occupancy_num_blocks,
-                       gpu_cu_count)
+                       DefaultKBatch)
         {
         }
@@ -391,15 +386,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
                  CDEElementwiseOperation cde_element_op,
-                 index_t kbatch,
-                 int occupancy_num_blocks,
-                 int gpu_cu_count)
+                 index_t kbatch)
             : K_BATCH{kbatch},
               group_count_{0},
               skipped_group_count_{0},
               tile_count_{0},
-              occupancy_num_blocks_{occupancy_num_blocks},
-              gpu_cu_count_{gpu_cu_count},
+              occupancy_num_blocks_{0},
+              gpu_cu_count_{0},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               cde_element_op_{cde_element_op}
@@ -459,6 +452,14 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                                      stride_ds,
                                      stride_e);
             }
+
+            UpdateOccupancy();
+
+            hipDeviceProp_t dev_prop;
+            hipDevice_t dev;
+            hip_check_error(hipGetDevice(&dev));
+            hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
+            gpu_cu_count_ = dev_prop.multiProcessorCount;
         }

         /**
@@ -480,6 +481,79 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                     b2c_tile_map.CalculateGridSize(gemm_arg.M, gemm_arg.N);
                 tile_count_ += grid_size_grp;
             }
+
+            UpdateOccupancy();
+        }
+
+        void UpdateOccupancy()
+        {
+            bool all_have_main_k_block_loop;
+            {
+                const auto a_grid_desc_kbatch_ak0_m_ak1 =
+                    GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(gemm_kernel_args_[0].M,
+                                                                       gemm_kernel_args_[0].K,
+                                                                       gemm_kernel_args_[0].StrideA,
+                                                                       K_BATCH);
+                all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(
+                    a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
+                    a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
+            }
+
+            for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
+            {
+                const auto& gemm_arg = gemm_kernel_args_[i];
+                auto kbatch          = K_BATCH;
+
+                const auto a_grid_desc_kbatch_ak0_m_ak1 =
+                    GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
+                        gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
+
+                bool not_all_have_main_k_block_loop_same =
+                    all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(
+                        a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
+                        a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
+
+                if(not_all_have_main_k_block_loop_same)
+                {
+                    std::ostringstream err;
+                    err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
+                        << ":" << __LINE__ << ", in function: " << __func__;
+                    throw std::runtime_error(err.str());
+                }
+            }
+
+            if(all_have_main_k_block_loop)
+            {
+                const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
+                                                                      KernelArguments,
+                                                                      ADataType,
+                                                                      BDataType,
+                                                                      EDataType,
+                                                                      DsDataType,
+                                                                      Block2ETileMapKSplit,
+                                                                      AElementwiseOperation,
+                                                                      BElementwiseOperation,
+                                                                      CDEElementwiseOperation,
+                                                                      true>;
+                hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
+                    &occupancy_num_blocks_, kernel, BlockSize, 0));
+            }
+            else
+            {
+                const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
+                                                                      KernelArguments,
+                                                                      ADataType,
+                                                                      BDataType,
+                                                                      EDataType,
+                                                                      DsDataType,
+                                                                      Block2ETileMapKSplit,
+                                                                      AElementwiseOperation,
+                                                                      BElementwiseOperation,
+                                                                      CDEElementwiseOperation,
+                                                                      false>;
+                hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
+                    &occupancy_num_blocks_, kernel, BlockSize, 0));
+            }
+        }

         // private:
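
`UpdateOccupancy()` caches `occupancy_num_blocks_` and the constructor caches `gpu_cu_count_`; together they presumably bound the persistent-kernel grid. A hedged sketch of how that consumption might look; only the final ceiling division is verbatim from the @@ -754,12 +828,10 @@ hunk below, while the `tiles_per_block` derivation is an assumption for illustration:

```cpp
// Hedged sketch (assumed): launch at most one full wave of persistent
// workgroups (occupancy_num_blocks per CU across gpu_cu_count CUs),
// each looping over a roughly equal share of the tiles.
int CalculatePersistentGridSize(int tile_count, int occupancy_num_blocks, int gpu_cu_count)
{
    const int max_resident_blocks = occupancy_num_blocks * gpu_cu_count;
    const int tiles_per_block = (tile_count + max_resident_blocks - 1) / max_resident_blocks;
    // Ceiling division below appears verbatim in the later hunk:
    // grid_size = (arg.tile_count_ + tiles_per_block - 1) / tiles_per_block;
    return (tile_count + tiles_per_block - 1) / tiles_per_block;
}
```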
@@ -657,7 +731,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                 const auto a_grid_desc_kbatch_ak0_m_ak1 =
                     GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
-                        gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, arg.K_BATCH);
+                        gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);

                 bool not_all_have_main_k_block_loop_same =
                     all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(
@@ -754,12 +828,10 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                 grid_size = (arg.tile_count_ + tiles_per_block - 1) / tiles_per_block;
             }

-            std::size_t acc_workspace_size_bytes = Block2ETileMapKSplit::GetAccWorkspaceSize(
-                sizeof(typename GridwiseGemm::AccType), grid_size);
-            void* p_flags = reinterpret_cast<char*>(dev_gemm_workspace) +
-                            Block2ETileMapKSplit::GetAccWorkspaceSize(
-                                sizeof(typename GridwiseGemm::AccType), grid_size);
+            std::size_t acc_workspace_size_bytes =
+                Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(CShuffleDataType), grid_size);
+            void* p_flags = reinterpret_cast<char*>(dev_gemm_workspace) + acc_workspace_size_bytes;
             std::size_t flag_count = grid_size;

             if(stream_config.log_level_ > 0)
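
For reference, this hunk together with the `p_flags` cast in the first hunk implies a two-region workspace layout: the per-tile accumulation area first, then one `uint32_t` synchronization flag per workgroup. A hedged sketch of that carving, with offsets assumed rather than verified against the full file:

```cpp
#include <cstddef>
#include <cstdint>

// Hedged layout sketch of dev_gemm_workspace:
//   [ acc partials : GetAccWorkspaceSize(sizeof(CShuffleDataType), grid_size) bytes ]
//   [ flags        : grid_size * sizeof(uint32_t) bytes                             ]
void CarveWorkspace(void* dev_gemm_workspace,
                    std::size_t acc_workspace_size_bytes,
                    std::size_t grid_size,
                    void** p_acc,
                    void** p_flags,
                    std::size_t* total_bytes)
{
    char* const base = reinterpret_cast<char*>(dev_gemm_workspace);
    *p_acc           = base;                            // partial C-shuffle tiles first
    *p_flags         = base + acc_workspace_size_bytes; // then grid_size uint32_t flags
    *total_bytes     = acc_workspace_size_bytes + grid_size * sizeof(uint32_t);
}
```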
@@ -858,27 +930,6 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                              BElementwiseOperation b_elementwise_op,
                              CDEElementwiseOperation cde_elementwise_op)
    {
-        const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
-                                                              KernelArguments,
-                                                              ADataType,
-                                                              BDataType,
-                                                              EDataType,
-                                                              DsDataType,
-                                                              Block2ETileMapKSplit,
-                                                              AElementwiseOperation,
-                                                              BElementwiseOperation,
-                                                              CDEElementwiseOperation,
-                                                              true>;
-        int occupancy, num_cu;
-        hip_check_error(
-            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
-
-        hipDeviceProp_t dev_prop;
-        hipDevice_t dev;
-        hip_check_error(hipGetDevice(&dev));
-        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
-        num_cu = dev_prop.multiProcessorCount;
-
        return Argument{p_As,
                        p_Bs,
                        p_Ds,
@@ -886,9 +937,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                        gemm_descs,
                        a_elementwise_op,
                        b_elementwise_op,
-                       cde_elementwise_op,
-                       occupancy,
-                       num_cu};
+                       cde_elementwise_op};
    }
    std::unique_ptr<BaseArgument>
@@ -901,27 +950,6 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                        BElementwiseOperation b_elementwise_op,
                        CDEElementwiseOperation cde_elementwise_op) override
    {
-        const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
-                                                              KernelArguments,
-                                                              ADataType,
-                                                              BDataType,
-                                                              EDataType,
-                                                              DsDataType,
-                                                              Block2ETileMapKSplit,
-                                                              AElementwiseOperation,
-                                                              BElementwiseOperation,
-                                                              CDEElementwiseOperation,
-                                                              true>;
-        int occupancy, num_cu;
-        hip_check_error(
-            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
-
-        hipDeviceProp_t dev_prop;
-        hipDevice_t dev;
-        hip_check_error(hipGetDevice(&dev));
-        hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
-        num_cu = dev_prop.multiProcessorCount;
-
        return std::make_unique<Argument>(p_As,
                                          p_Bs,
                                          p_Ds,
@@ -929,9 +957,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                                          gemm_descs,
                                          a_elementwise_op,
                                          b_elementwise_op,
-                                          cde_elementwise_op,
-                                          occupancy,
-                                          num_cu);
+                                          cde_elementwise_op);
    }

    static auto MakeInvoker() { return Invoker{}; }
...