Commit 7f09b8a0 authored by rocking
Browse files

Support f16 and bf16

parent acd980fc
...@@ -174,6 +174,12 @@ bool maxpool_bwd_test(bool do_verification, ...@@ -174,6 +174,12 @@ bool maxpool_bwd_test(bool do_verification,
"not support this problem"); "not support this problem");
} }
size_t pool_bwd_workspace_sz = pool_bwd.GetWorkSpaceSize(pool_bwd_argument_ptr.get());
DeviceMem pool_bwd_workspace_device_buf(pool_bwd_workspace_sz);
pool_bwd_workspace_device_buf.SetZero();
pool_bwd.SetWorkSpacePointer(pool_bwd_argument_ptr.get(),
pool_bwd_workspace_device_buf.GetDeviceBuffer());
float ave_time_bwd = float ave_time_bwd =
pool_bwd_invoker_ptr->Run(pool_bwd_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); pool_bwd_invoker_ptr->Run(pool_bwd_argument_ptr.get(), StreamConfig{nullptr, time_kernel});
...@@ -204,7 +210,6 @@ bool maxpool_bwd_test(bool do_verification, ...@@ -204,7 +210,6 @@ bool maxpool_bwd_test(bool do_verification,
window_strides, window_strides,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads);
ref_pooling_fwd_invoker.Run(ref_pooling_fwd_argument); ref_pooling_fwd_invoker.Run(ref_pooling_fwd_argument);
using ReferencePoolingBwdInstance = ck::tensor_operation::host:: using ReferencePoolingBwdInstance = ck::tensor_operation::host::
......
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
#include "maxpool2d_bwd_common.hpp" #include "maxpool2d_bwd_common.hpp"
using InDataType = float; using InDataType = ck::half_t;
using OutDataType = float; using OutDataType = ck::half_t;
using IndexDataType = int32_t; using IndexDataType = int32_t;
using ComputeDataType = float; using ComputeDataType = float;
using DInDataType = float; using DInDataType = float;
...@@ -29,12 +29,12 @@ int main() ...@@ -29,12 +29,12 @@ int main()
// Pool shape // Pool shape
ck::index_t N = 1; ck::index_t N = 1;
ck::index_t C = 1; ck::index_t C = 1;
ck::index_t Y = 2; ck::index_t Y = 3;
ck::index_t X = 2; ck::index_t X = 3;
ck::index_t Hi = 31; ck::index_t Hi = 31;
ck::index_t Wi = 31; ck::index_t Wi = 31;
ck::index_t window_stride_h = 2; ck::index_t window_stride_h = 1;
ck::index_t window_stride_w = 2; ck::index_t window_stride_w = 1;
ck::index_t in_left_pad_h = 0; ck::index_t in_left_pad_h = 0;
ck::index_t in_left_pad_w = 0; ck::index_t in_left_pad_w = 0;
ck::index_t in_right_pad_h = 1; ck::index_t in_right_pad_h = 1;
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp" #include "ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -23,22 +24,24 @@ namespace device { ...@@ -23,22 +24,24 @@ namespace device {
template <typename DOutDataType, template <typename DOutDataType,
typename IndexDataType, typename IndexDataType,
typename DInDataType, typename DInDataType,
ck::index_t InVectorSize> ck::index_t InOutVectorSize>
struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDataType, DInDataType> struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDataType, DInDataType>
{ {
static_assert(is_same_v<DInDataType, float> || is_same_v<DInDataType, double>, using DInDataType_AutomicAddPreCast =
"Data type is not supported!"); conditional_t<is_same_v<DInDataType, float> || is_same_v<DInDataType, double>,
DInDataType,
float>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using UnaryConvert = ck::tensor_operation::element_wise::UnaryConvert;
static constexpr auto I0 = Number<0>{};
template <typename Desc_M> template <typename Desc_M>
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize) static auto PadDescriptor_M_1d(Desc_M desc_m, index_t loop_step)
{ {
constexpr auto I0 = Number<0>{}; const auto m = desc_m.GetLength(I0);
const auto pad = math::integer_least_multiple(m, loop_step) - m;
const auto m = desc_m.GetLength(I0);
const index_t loop_step = gridSize * blockSize * InVectorSize;
const auto pad = math::integer_least_multiple(m, loop_step) - m;
const auto desc_m_pad = const auto desc_m_pad =
transform_tensor_descriptor(desc_m, transform_tensor_descriptor(desc_m,
make_tuple(make_right_pad_transform(m, pad)), make_tuple(make_right_pad_transform(m, pad)),
...@@ -47,29 +50,38 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -47,29 +50,38 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
return desc_m_pad; return desc_m_pad;
} }
static auto MakeDescriptor_M(index_t length, index_t gridSize, index_t blockSize) static auto MakeDescriptor_M(index_t length, index_t loop_step)
{ {
const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length)); const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length));
return PadDescriptor_M_1d(desc_m, gridSize, blockSize); return PadDescriptor_M_1d(desc_m, loop_step);
} }
using OutGrid1dDesc = decltype(MakeDescriptor_M(1, 1, 1)); using InOutGrid1dDesc = decltype(MakeDescriptor_M(1, 1));
using GridwisePutElementSet = GridwisePutElement_1D<OutGrid1dDesc, using GridwisePutElementSet = GridwisePutElement_1D<InOutGrid1dDesc,
DOutDataType, DOutDataType,
IndexDataType, IndexDataType,
DInDataType, DInDataType,
PassThrough, PassThrough,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
InVectorSize>; InOutVectorSize>;
using GridwisePutElementAtomicAdd = GridwisePutElement_1D<OutGrid1dDesc, using GridwisePutElementAtomicAdd = GridwisePutElement_1D<InOutGrid1dDesc,
DOutDataType, DOutDataType,
IndexDataType, IndexDataType,
DInDataType, DInDataType_AutomicAddPreCast,
PassThrough, PassThrough,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
InVectorSize>; InOutVectorSize>;
// Gridwise elementwise operation instantiated with UnaryConvert: it takes one
// input of type `const DInDataType_AutomicAddPreCast*` and one output of type
// `DInDataType*`, both described by InOutGrid1dDesc.
// Invoker::Run launches it after the AtomicAdd put-element kernel to convert
// the values accumulated in the workspace into the caller's p_din buffer.
using GridwiseCasting = GridwiseElementwise_1D<Tuple<InOutGrid1dDesc>,
Tuple<InOutGrid1dDesc>,
Tuple<const DInDataType_AutomicAddPreCast*>,
Tuple<DInDataType*>,
UnaryConvert,
InOutVectorSize,
Sequence<InOutVectorSize>,
Sequence<InOutVectorSize>>;
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -77,6 +89,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -77,6 +89,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
const IndexDataType* p_indices, const IndexDataType* p_indices,
DInDataType* p_din, DInDataType* p_din,
index_t dout_length, index_t dout_length,
index_t din_length,
const std::vector<ck::index_t>& window_lengths, const std::vector<ck::index_t>& window_lengths,
const std::vector<ck::index_t>& window_strides) const std::vector<ck::index_t>& window_strides)
: p_dout_{p_dout}, : p_dout_{p_dout},
...@@ -86,7 +99,9 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -86,7 +99,9 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
gridSize_{104}, // FIXME - Calculate the grid size by number of CU in the future gridSize_{104}, // FIXME - Calculate the grid size by number of CU in the future
windowOverlap_{false} windowOverlap_{false}
{ {
dout_grid_desc_ = MakeDescriptor_M(dout_length, gridSize_, blockSize_); index_t loop_step = gridSize_ * blockSize_ * InOutVectorSize;
din_grid_desc_ = MakeDescriptor_M(din_length, loop_step);
dout_grid_desc_ = MakeDescriptor_M(dout_length, loop_step);
for(size_t i = 0; i < window_lengths.size(); ++i) for(size_t i = 0; i < window_lengths.size(); ++i)
{ {
...@@ -100,45 +115,126 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -100,45 +115,126 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
index_t blockSize_; index_t blockSize_;
index_t gridSize_; index_t gridSize_;
bool windowOverlap_; bool windowOverlap_;
OutGrid1dDesc dout_grid_desc_; InOutGrid1dDesc din_grid_desc_;
InOutGrid1dDesc dout_grid_desc_;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
constexpr auto KernelSelector(bool windowOverlap) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
if(windowOverlap) if constexpr(is_same_v<DInDataType, float> || is_same_v<DInDataType, double>)
return kernel_put_element_1d<GridwisePutElementAtomicAdd, {
OutGrid1dDesc, if(arg.windowOverlap_)
DOutDataType, {
IndexDataType, const auto put_kernel = kernel_put_element_1d<GridwisePutElementAtomicAdd,
DInDataType, InOutGrid1dDesc,
PassThrough>; DOutDataType,
IndexDataType,
DInDataType,
PassThrough>;
return launch_and_time_kernel(stream_config,
put_kernel,
dim3(arg.gridSize_),
dim3(arg.blockSize_),
0,
arg.dout_grid_desc_,
arg.p_dout_,
arg.p_indices_,
arg.p_din_,
PassThrough{});
}
else
{
const auto put_kernel = kernel_put_element_1d<GridwisePutElementSet,
InOutGrid1dDesc,
DOutDataType,
IndexDataType,
DInDataType,
PassThrough>;
return launch_and_time_kernel(stream_config,
put_kernel,
dim3(arg.gridSize_),
dim3(arg.blockSize_),
0,
arg.dout_grid_desc_,
arg.p_dout_,
arg.p_indices_,
arg.p_din_,
PassThrough{});
}
}
else else
return kernel_put_element_1d<GridwisePutElementSet, {
OutGrid1dDesc, if(arg.windowOverlap_)
DOutDataType, {
IndexDataType, if(arg.p_workspace_ == nullptr)
DInDataType, throw std::runtime_error("wrong! WorkSpace pointer has not been set");
PassThrough>;
}
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) const auto put_kernel = kernel_put_element_1d<GridwisePutElementAtomicAdd,
{ InOutGrid1dDesc,
const auto kernel = KernelSelector(arg.windowOverlap_); DOutDataType,
IndexDataType,
float elapsed_time = launch_and_time_kernel(stream_config, DInDataType_AutomicAddPreCast,
kernel, PassThrough>;
dim3(arg.gridSize_),
dim3(arg.blockSize_), const auto cast_kernel =
0, kernel_elementwise_1d<GridwiseCasting,
arg.dout_grid_desc_, Tuple<InOutGrid1dDesc>,
arg.p_dout_, Tuple<InOutGrid1dDesc>,
arg.p_indices_, Tuple<const DInDataType_AutomicAddPreCast*>,
arg.p_din_, Tuple<DInDataType*>,
PassThrough{}); UnaryConvert>;
return elapsed_time; float elapsed_time = launch_and_time_kernel(
stream_config,
put_kernel,
dim3(arg.gridSize_),
dim3(arg.blockSize_),
0,
arg.dout_grid_desc_,
arg.p_dout_,
arg.p_indices_,
static_cast<DInDataType_AutomicAddPreCast*>(arg.p_workspace_),
PassThrough{});
elapsed_time += launch_and_time_kernel(
stream_config,
cast_kernel,
dim3(arg.gridSize_),
dim3(arg.blockSize_),
0,
ck::make_tuple(arg.din_grid_desc_),
ck::make_tuple(arg.din_grid_desc_),
static_cast<DInDataType_AutomicAddPreCast*>(arg.p_workspace_),
arg.p_din_,
UnaryConvert{});
return elapsed_time;
}
else
{
const auto put_kernel = kernel_put_element_1d<GridwisePutElementSet,
InOutGrid1dDesc,
DOutDataType,
IndexDataType,
DInDataType,
PassThrough>;
return launch_and_time_kernel(stream_config,
put_kernel,
dim3(arg.gridSize_),
dim3(arg.blockSize_),
0,
arg.dout_grid_desc_,
arg.p_dout_,
arg.p_indices_,
arg.p_din_,
PassThrough{});
}
}
} }
float Run(const BaseArgument* p_arg, float Run(const BaseArgument* p_arg,
...@@ -148,11 +244,31 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -148,11 +244,31 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
} }
}; };
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
{
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
bool needCast = pArg_->windowOverlap_ &&
!(is_same_v<DInDataType, float> || is_same_v<DInDataType, double>);
if(!needCast)
return 0;
else
{
index_t din_length = pArg_->din_grid_desc_.GetTransforms()[I0].GetUpperLengths()[I0];
return din_length * sizeof(DInDataType_AutomicAddPreCast);
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
{ {
const Argument* pArg = dynamic_cast<const Argument*>(p_arg); const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
// TODO index_t din_length = pArg->din_grid_desc_.GetTransforms()[I0].GetUpperLengths()[I0];
ignore = pArg; index_t dout_length = pArg->dout_grid_desc_.GetTransforms()[I0].GetUpperLengths()[I0];
if(din_length % InOutVectorSize != 0 || dout_length % InOutVectorSize != 0)
{
return false;
}
return true; return true;
} }
...@@ -161,7 +277,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -161,7 +277,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
const void* p_indices, const void* p_indices,
void* p_din, void* p_din,
index_t dout_length, index_t dout_length,
index_t, index_t din_length,
std::vector<ck::index_t> window_lengths, std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> window_strides) override std::vector<ck::index_t> window_strides) override
{ {
...@@ -169,6 +285,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat ...@@ -169,6 +285,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
static_cast<const IndexDataType*>(p_indices), static_cast<const IndexDataType*>(p_indices),
static_cast<DInDataType*>(p_din), static_cast<DInDataType*>(p_din),
dout_length, dout_length,
din_length,
window_lengths, window_lengths,
window_strides); window_strides);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment