"...composable_kernel.git" did not exist on "29087570093f38075ed25d48b3f5c4d2885e47fa"
Commit 38962b98 authored by rocking
Browse files

Calculate the grid size according to the number of compute units (CUs)

parent 3550cefe
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/stream_utility.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -94,15 +95,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -94,15 +95,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
: p_dout_{p_dout}, : p_dout_{p_dout},
p_indices_{p_indices}, p_indices_{p_indices},
p_din_{p_din}, p_din_{p_din},
dout_length_raw_{dout_length},
din_length_raw_{din_length}, din_length_raw_{din_length},
blockSize_{256}, blockSize_{256},
gridSize_{104}, // FIXME - Calculate the grid size by number of CU in the future
windowOverlap_{false} windowOverlap_{false}
{ {
index_t loop_step = gridSize_ * blockSize_ * InOutVectorSize;
din_grid_desc_ = MakeDescriptor_M(din_length, loop_step);
dout_grid_desc_ = MakeDescriptor_M(dout_length, loop_step);
for(size_t i = 0; i < window_lengths.size(); ++i) for(size_t i = 0; i < window_lengths.size(); ++i)
{ {
windowOverlap_ |= window_lengths.at(i) > window_strides.at(i); windowOverlap_ |= window_lengths.at(i) > window_strides.at(i);
...@@ -112,18 +109,21 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -112,18 +109,21 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
const DOutDataType* p_dout_; const DOutDataType* p_dout_;
const IndexDataType* p_indices_; const IndexDataType* p_indices_;
DInDataType* p_din_; DInDataType* p_din_;
index_t dout_length_raw_;
index_t din_length_raw_; index_t din_length_raw_;
index_t blockSize_; index_t blockSize_;
index_t gridSize_;
bool windowOverlap_; bool windowOverlap_;
InOutGrid1dDesc din_grid_desc_;
InOutGrid1dDesc dout_grid_desc_;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
index_t gridSize = getAvailableComputeUnitCount(stream_config);
index_t loop_step = gridSize * arg.blockSize_ * InOutVectorSize;
InOutGrid1dDesc din_grid_desc = MakeDescriptor_M(arg.din_length_raw_, loop_step);
InOutGrid1dDesc dout_grid_desc = MakeDescriptor_M(arg.dout_length_raw_, loop_step);
if constexpr(is_same_v<DInDataType, float> || is_same_v<DInDataType, double>) if constexpr(is_same_v<DInDataType, float> || is_same_v<DInDataType, double>)
{ {
hip_check_error(hipMemsetAsync(arg.p_din_, hip_check_error(hipMemsetAsync(arg.p_din_,
...@@ -142,10 +142,10 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -142,10 +142,10 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
put_kernel, put_kernel,
dim3(arg.gridSize_), dim3(gridSize),
dim3(arg.blockSize_), dim3(arg.blockSize_),
0, 0,
arg.dout_grid_desc_, dout_grid_desc,
arg.p_dout_, arg.p_dout_,
arg.p_indices_, arg.p_indices_,
arg.p_din_, arg.p_din_,
...@@ -162,10 +162,10 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -162,10 +162,10 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
put_kernel, put_kernel,
dim3(arg.gridSize_), dim3(gridSize),
dim3(arg.blockSize_), dim3(arg.blockSize_),
0, 0,
arg.dout_grid_desc_, dout_grid_desc,
arg.p_dout_, arg.p_dout_,
arg.p_indices_, arg.p_indices_,
arg.p_din_, arg.p_din_,
...@@ -203,10 +203,10 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -203,10 +203,10 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
float elapsed_time = launch_and_time_kernel( float elapsed_time = launch_and_time_kernel(
stream_config, stream_config,
put_kernel, put_kernel,
dim3(arg.gridSize_), dim3(gridSize),
dim3(arg.blockSize_), dim3(arg.blockSize_),
0, 0,
arg.dout_grid_desc_, dout_grid_desc,
arg.p_dout_, arg.p_dout_,
arg.p_indices_, arg.p_indices_,
static_cast<DInDataType_AutomicAddPreCast*>(arg.p_workspace_), static_cast<DInDataType_AutomicAddPreCast*>(arg.p_workspace_),
...@@ -215,11 +215,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -215,11 +215,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
elapsed_time += launch_and_time_kernel( elapsed_time += launch_and_time_kernel(
stream_config, stream_config,
cast_kernel, cast_kernel,
dim3(arg.gridSize_), dim3(gridSize),
dim3(arg.blockSize_), dim3(arg.blockSize_),
0, 0,
ck::make_tuple(arg.din_grid_desc_), ck::make_tuple(din_grid_desc),
ck::make_tuple(arg.din_grid_desc_), ck::make_tuple(din_grid_desc),
static_cast<DInDataType_AutomicAddPreCast*>(arg.p_workspace_), static_cast<DInDataType_AutomicAddPreCast*>(arg.p_workspace_),
arg.p_din_, arg.p_din_,
UnaryConvert{}); UnaryConvert{});
...@@ -242,10 +242,10 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -242,10 +242,10 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
put_kernel, put_kernel,
dim3(arg.gridSize_), dim3(gridSize),
dim3(arg.blockSize_), dim3(arg.blockSize_),
0, 0,
arg.dout_grid_desc_, dout_grid_desc,
arg.p_dout_, arg.p_dout_,
arg.p_indices_, arg.p_indices_,
arg.p_din_, arg.p_din_,
...@@ -277,9 +277,8 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -277,9 +277,8 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
{ {
const Argument* pArg = dynamic_cast<const Argument*>(p_arg); const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
index_t din_length = pArg->din_grid_desc_.GetTransforms()[I0].GetUpperLengths()[I0]; if(pArg->din_length_raw_ % InOutVectorSize != 0 ||
index_t dout_length = pArg->dout_grid_desc_.GetTransforms()[I0].GetUpperLengths()[I0]; pArg->dout_length_raw_ % InOutVectorSize != 0)
if(din_length % InOutVectorSize != 0 || dout_length % InOutVectorSize != 0)
{ {
return false; return false;
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment