Commit 38962b98 authored by rocking
Browse files

Calculate gridsize according to the number of CU

parent 3550cefe
......@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/stream_utility.hpp"
namespace ck {
namespace tensor_operation {
......@@ -94,15 +95,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
: p_dout_{p_dout},
p_indices_{p_indices},
p_din_{p_din},
dout_length_raw_{dout_length},
din_length_raw_{din_length},
blockSize_{256},
gridSize_{104}, // FIXME - Calculate the grid size by number of CU in the future
windowOverlap_{false}
{
index_t loop_step = gridSize_ * blockSize_ * InOutVectorSize;
din_grid_desc_ = MakeDescriptor_M(din_length, loop_step);
dout_grid_desc_ = MakeDescriptor_M(dout_length, loop_step);
for(size_t i = 0; i < window_lengths.size(); ++i)
{
windowOverlap_ |= window_lengths.at(i) > window_strides.at(i);
......@@ -112,18 +109,21 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
const DOutDataType* p_dout_;
const IndexDataType* p_indices_;
DInDataType* p_din_;
index_t dout_length_raw_;
index_t din_length_raw_;
index_t blockSize_;
index_t gridSize_;
bool windowOverlap_;
InOutGrid1dDesc din_grid_desc_;
InOutGrid1dDesc dout_grid_desc_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
index_t gridSize = getAvailableComputeUnitCount(stream_config);
index_t loop_step = gridSize * arg.blockSize_ * InOutVectorSize;
InOutGrid1dDesc din_grid_desc = MakeDescriptor_M(arg.din_length_raw_, loop_step);
InOutGrid1dDesc dout_grid_desc = MakeDescriptor_M(arg.dout_length_raw_, loop_step);
if constexpr(is_same_v<DInDataType, float> || is_same_v<DInDataType, double>)
{
hip_check_error(hipMemsetAsync(arg.p_din_,
......@@ -142,10 +142,10 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
return launch_and_time_kernel(stream_config,
put_kernel,
dim3(arg.gridSize_),
dim3(gridSize),
dim3(arg.blockSize_),
0,
arg.dout_grid_desc_,
dout_grid_desc,
arg.p_dout_,
arg.p_indices_,
arg.p_din_,
......@@ -162,10 +162,10 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
return launch_and_time_kernel(stream_config,
put_kernel,
dim3(arg.gridSize_),
dim3(gridSize),
dim3(arg.blockSize_),
0,
arg.dout_grid_desc_,
dout_grid_desc,
arg.p_dout_,
arg.p_indices_,
arg.p_din_,
......@@ -203,10 +203,10 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
float elapsed_time = launch_and_time_kernel(
stream_config,
put_kernel,
dim3(arg.gridSize_),
dim3(gridSize),
dim3(arg.blockSize_),
0,
arg.dout_grid_desc_,
dout_grid_desc,
arg.p_dout_,
arg.p_indices_,
static_cast<DInDataType_AutomicAddPreCast*>(arg.p_workspace_),
......@@ -215,11 +215,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
elapsed_time += launch_and_time_kernel(
stream_config,
cast_kernel,
dim3(arg.gridSize_),
dim3(gridSize),
dim3(arg.blockSize_),
0,
ck::make_tuple(arg.din_grid_desc_),
ck::make_tuple(arg.din_grid_desc_),
ck::make_tuple(din_grid_desc),
ck::make_tuple(din_grid_desc),
static_cast<DInDataType_AutomicAddPreCast*>(arg.p_workspace_),
arg.p_din_,
UnaryConvert{});
......@@ -242,10 +242,10 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
return launch_and_time_kernel(stream_config,
put_kernel,
dim3(arg.gridSize_),
dim3(gridSize),
dim3(arg.blockSize_),
0,
arg.dout_grid_desc_,
dout_grid_desc,
arg.p_dout_,
arg.p_indices_,
arg.p_din_,
......@@ -277,9 +277,8 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
index_t din_length = pArg->din_grid_desc_.GetTransforms()[I0].GetUpperLengths()[I0];
index_t dout_length = pArg->dout_grid_desc_.GetTransforms()[I0].GetUpperLengths()[I0];
if(din_length % InOutVectorSize != 0 || dout_length % InOutVectorSize != 0)
if(pArg->din_length_raw_ % InOutVectorSize != 0 ||
pArg->dout_length_raw_ % InOutVectorSize != 0)
{
return false;
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment