Commit 160932b6 authored by Adam Osewski

Multiple fixes.

* Fix accumulation when there is only one workgroup per K dim.
* Update occupancy values after KBatch update and fix its calculation.
parent ccc38273
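For context on the occupancy bullet: the queries this commit consolidates into the Argument constructor and the new UpdateOccupancy() member come down to two HIP runtime calls. A minimal standalone sketch, assuming a placeholder kernel and block size (the device op queries kernel_grouped_gemm_xdl_splitk_v2 with its own tuned BlockSize):

#include <hip/hip_runtime.h>
#include <cstdio>

// Placeholder for kernel_grouped_gemm_xdl_splitk_v2, which takes many template
// parameters; any __global__ function works for the occupancy query.
__global__ void dummy_kernel() {}

int main()
{
    constexpr int BlockSize = 256; // assumed; the real op uses its own BlockSize

    // Maximum number of resident workgroups per compute unit for this
    // kernel / block-size pair (0 bytes of dynamic LDS).
    int occupancy_num_blocks = 0;
    if(hipOccupancyMaxActiveBlocksPerMultiprocessor(
           &occupancy_num_blocks, dummy_kernel, BlockSize, 0) != hipSuccess)
        return 1;

    // Number of compute units on the active device.
    hipDevice_t dev;
    hipDeviceProp_t dev_prop;
    if(hipGetDevice(&dev) != hipSuccess)
        return 1;
    if(hipGetDeviceProperties(&dev_prop, dev) != hipSuccess)
        return 1;
    const int gpu_cu_count = dev_prop.multiProcessorCount;

    // Their product bounds how many workgroups can be resident at once.
    printf("max resident workgroups: %d\n", occupancy_num_blocks * gpu_cu_count);
    return 0;
}

Before this commit, an equivalent query ran once in MakeArgument with the main-K-block-loop template flag hard-coded to true; moving it into UpdateOccupancy lets the stored value track both K_BATCH changes and the actual flag value.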
@@ -79,7 +79,7 @@ __global__ void
uint32_t* const __restrict__ p_flags = reinterpret_cast<uint32_t* const __restrict__>(
reinterpret_cast<char*>(p_workspace) +
Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(typename GridwiseGemm::AccType)));
Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(typename GridwiseGemm::CShuffleDataT)));
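// Note: this offset must match the host-side layout, which sizes the
// accumulation workspace with sizeof(CShuffleDataType).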
StridedReductionTileLoop work_scheduler{tile_count, p_flags};
@@ -183,9 +183,8 @@ __global__ void
acc_buff{};
acc_buff.Clear();
// Accumulate only when there is at least two workgroups processing splitk data-tiles
// across same MN-output tile.
if(neighbour_count > 0)
// TODO: Accumulate only when there are at least two workgroups processing splitk
// data-tiles across the same MN-output tile. if(neighbour_count > 0)
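// With the guard removed, partials are accumulated even when a single
// workgroup covered the whole K range (neighbour_count == 0).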
GridwiseGemm::AccumulatePartials(p_workspace, acc_buff, neighbour_count + 1);
// Signal waiting blocks that they can start using their workspace.
@@ -366,9 +365,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
int occupancy_num_blocks,
int gpu_cu_count)
CDEElementwiseOperation cde_element_op)
: Argument(p_As,
p_Bs,
p_Ds,
@@ -377,9 +374,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
a_element_op,
b_element_op,
cde_element_op,
DefaultKBatch,
occupancy_num_blocks,
gpu_cu_count)
DefaultKBatch)
{
}
@@ -391,15 +386,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
index_t kbatch,
int occupancy_num_blocks,
int gpu_cu_count)
index_t kbatch)
: K_BATCH{kbatch},
group_count_{0},
skipped_group_count_{0},
tile_count_{0},
occupancy_num_blocks_{occupancy_num_blocks},
gpu_cu_count_{gpu_cu_count},
occupancy_num_blocks_{0},
gpu_cu_count_{0},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
@@ -459,6 +452,14 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
stride_ds,
stride_e);
}
UpdateOccupancy();
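// Cache the CU count of the active device once at construction; unlike
// occupancy it does not depend on K_BATCH.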
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
gpu_cu_count_ = dev_prop.multiProcessorCount;
}
/**
@@ -480,6 +481,79 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
b2c_tile_map.CalculateGridSize(gemm_arg.M, gemm_arg.N);
tile_count_ += grid_size_grp;
}
UpdateOccupancy();
}
void UpdateOccupancy()
{
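// Recompute occupancy for the current K_BATCH. The main-K-block-loop flag is
// a compile-time template parameter of the kernel, so first verify that all
// groups agree on it, then query the matching kernel instantiation.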
bool all_have_main_k_block_loop;
{
const auto a_grid_desc_kbatch_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(gemm_kernel_args_[0].M,
gemm_kernel_args_[0].K,
gemm_kernel_args_[0].StrideA,
K_BATCH);
all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
}
for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
{
const auto& gemm_arg = gemm_kernel_args_[i];
auto kbatch = K_BATCH;
const auto a_grid_desc_kbatch_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
bool not_all_have_main_k_block_loop_same =
all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
if(not_all_have_main_k_block_loop_same)
{
std::ostringstream err;
err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
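// All groups agree; query occupancy for the instantiation that matches the
// shared main-K-block-loop value.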
if(all_have_main_k_block_loop)
{
const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
KernelArguments,
ADataType,
BDataType,
EDataType,
DsDataType,
Block2ETileMapKSplit,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
true>;
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy_num_blocks_, kernel, BlockSize, 0));
}
else
{
const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
KernelArguments,
ADataType,
BDataType,
EDataType,
DsDataType,
Block2ETileMapKSplit,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
false>;
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy_num_blocks_, kernel, BlockSize, 0));
}
}
// private:
@@ -657,7 +731,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
const auto a_grid_desc_kbatch_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, arg.K_BATCH);
gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
bool not_all_have_main_k_block_loop_same =
all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(
@@ -754,12 +828,10 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
grid_size = (arg.tile_count_ + tiles_per_block - 1) / tiles_per_block;
}
std::size_t acc_workspace_size_bytes = Block2ETileMapKSplit::GetAccWorkspaceSize(
sizeof(typename GridwiseGemm::AccType), grid_size);
std::size_t acc_workspace_size_bytes =
Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(CShuffleDataType), grid_size);
void* p_flags = reinterpret_cast<char*>(dev_gemm_workspace) +
Block2ETileMapKSplit::GetAccWorkspaceSize(
sizeof(typename GridwiseGemm::AccType), grid_size);
void* p_flags = reinterpret_cast<char*>(dev_gemm_workspace) + acc_workspace_size_bytes;
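// The workspace-ready flags (one per launched workgroup) live immediately
// after the partial-accumulation buffer in the same allocation.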
std::size_t flag_count = grid_size;
if(stream_config.log_level_ > 0)
@@ -858,27 +930,6 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
BElementwiseOperation b_elementwise_op,
CDEElementwiseOperation cde_elementwise_op)
{
const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
KernelArguments,
ADataType,
BDataType,
EDataType,
DsDataType,
Block2ETileMapKSplit,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
true>;
int occupancy, num_cu;
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
num_cu = dev_prop.multiProcessorCount;
return Argument{p_As,
p_Bs,
p_Ds,
@@ -886,9 +937,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
gemm_descs,
a_elementwise_op,
b_elementwise_op,
cde_elementwise_op,
occupancy,
num_cu};
cde_elementwise_op};
}
std::unique_ptr<BaseArgument>
@@ -901,27 +950,6 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
BElementwiseOperation b_elementwise_op,
CDEElementwiseOperation cde_elementwise_op) override
{
const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
KernelArguments,
ADataType,
BDataType,
EDataType,
DsDataType,
Block2ETileMapKSplit,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
true>;
int occupancy, num_cu;
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
num_cu = dev_prop.multiProcessorCount;
return std::make_unique<Argument>(p_As,
p_Bs,
p_Ds,
@@ -929,9 +957,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
gemm_descs,
a_elementwise_op,
b_elementwise_op,
cde_elementwise_op,
occupancy,
num_cu);
cde_elementwise_op);
}
static auto MakeInvoker() { return Invoker{}; }