Commit 58d84615 authored by rocking's avatar rocking
Browse files

Implement invoker

parent 055acace
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp" #include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -308,6 +309,32 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp ...@@ -308,6 +309,32 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
using DoutGridDesc_M_K = remove_cvref_t<tuple_element_t<0, DoutDinGridDesc>>; using DoutGridDesc_M_K = remove_cvref_t<tuple_element_t<0, DoutDinGridDesc>>;
using DinGridDesc_M = remove_cvref_t<tuple_element_t<1, DoutDinGridDesc>>; using DinGridDesc_M = remove_cvref_t<tuple_element_t<1, DoutDinGridDesc>>;
// FIXME
// for NDHWC, the dim C is the vector Dim for both input and output in memory, which is not
// reduced. Assume C is the fastest dimension
static constexpr index_t InSrcOutDstVectorDim = 0;
using PassThrough = tensor_operation::element_wise::PassThrough;
using Div = tensor_operation::element_wise::UnaryDivide;
using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise<DOutDataType,
DInDataType,
ComputeDataType,
int,
DoutGridDesc_M_K,
DinGridDesc_M,
reduce::Add,
PassThrough,
Div,
InMemoryDataOperationEnum::Set,
false, // propagate_nan
BlockSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcOutDstVectorDim,
InSrcOutDstVectorSize,
InSrcOutDstVectorSize>;
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const DOutDataType* p_dout, Argument(const DOutDataType* p_dout,
...@@ -321,7 +348,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp ...@@ -321,7 +348,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
std::vector<ck::index_t> window_dilations, std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) std::vector<ck::index_t> input_right_pads)
: p_dout_grid_{p_dout}, p_din_grid_{p_din}, num_reduce_{1} : p_dout_grid_{p_dout},
p_din_grid_{p_din},
num_reduce_{1},
div_element_op_{window_lengths[0] * window_lengths[1] * window_lengths[2]}
{ {
std::vector<ck::index_t> Tildes(NDimSpatial); std::vector<ck::index_t> Tildes(NDimSpatial);
for(int i = 0; i < NDimSpatial; ++i) for(int i = 0; i < NDimSpatial; ++i)
...@@ -369,35 +399,67 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp ...@@ -369,35 +399,67 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
} }
} }
void Print() const
{
for(index_t i = 0; i < num_reduce_; i++)
{
std::cout << "dout_grid_desc_m_k_container_" << dout_grid_desc_m_k_container_[i]
<< std::endl;
std::cout << "din_grid_desc_m_container_" << din_grid_desc_m_container_[i]
<< std::endl;
}
}
// pointer
const DOutDataType* p_dout_grid_; const DOutDataType* p_dout_grid_;
DInDataType* p_din_grid_; DInDataType* p_din_grid_;
int num_reduce_; int num_reduce_;
std::vector<DoutGridDesc_M_K> dout_grid_desc_m_k_container_; std::vector<DoutGridDesc_M_K> dout_grid_desc_m_k_container_;
std::vector<DinGridDesc_M> din_grid_desc_m_container_; std::vector<DinGridDesc_M> din_grid_desc_m_container_;
Div div_element_op_;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
ignore = arg;
ignore = stream_config;
float ave_time = 0;
for(index_t i = 0; i < arg.num_reduce_; i++)
{
const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
false,
false,
false, // don't have index input
DOutDataType,
DInDataType,
ComputeDataType,
int,
DoutGridDesc_M_K,
DinGridDesc_M,
PassThrough,
Div>;
ck::index_t M = arg.dout_grid_desc_m_k_container_[i].GetLength(I0);
const index_t grid_size = (M / M_BlockTileSize);
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.dout_grid_desc_m_k_container_[i],
arg.din_grid_desc_m_container_[i],
PassThrough{},
arg.div_element_op_,
float(1),
arg.p_dout_grid_,
nullptr,
float(0),
arg.p_din_grid_,
nullptr);
}
return ave_time;
}
float Run(const BaseArgument* p_arg, float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override const StreamConfig& stream_config = StreamConfig{}) override
{ {
ignore = p_arg; return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
ignore = stream_config;
return 0;
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment