Commit 400cb28e authored by rocking's avatar rocking
Browse files

Support different stride

parent 0bd6d2ce
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <iostream> #include <iostream>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_impl.hpp" #include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...@@ -20,6 +21,46 @@ using ComputeDataType = float; ...@@ -20,6 +21,46 @@ using ComputeDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DOutLayout = ck::tensor_layout::convolution::NDHWC;
using DInLayout = ck::tensor_layout::convolution::NDHWC;
// Compute the packed 5-D strides, reported in (N, C, D, H, W) order, for a
// tensor stored in the given layout. Only the NCDHW and NDHWC layouts are
// supported; any other layout fails to compile instead of invoking undefined
// behaviour by falling off the end of the function.
//
// N_ is unused because a fully packed tensor's strides do not depend on the
// outermost (slowest-varying) dimension's extent.
template <typename TensorLayout>
std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
                                                ck::index_t C_,
                                                ck::index_t D,
                                                ck::index_t H,
                                                ck::index_t W,
                                                TensorLayout layout)
{
    using namespace ck::literals;
    (void)N_;
    (void)layout; // tag parameter, used only for compile-time dispatch
    if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NCDHW>::value)
    {
        // Packed NCDHW: W is the fastest-varying dimension.
        return {C_ * D * H * W, D * H * W, H * W, W, 1_uz};
    }
    else if constexpr(ck::is_same<decltype(layout),
                                  ck::tensor_layout::convolution::NDHWC>::value)
    {
        // Packed NDHWC: C is the fastest-varying dimension; strides are still
        // listed in (N, C, D, H, W) order to match the lengths arrays.
        return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_};
    }
    else
    {
        // Dependent condition: only fires when instantiated with an
        // unsupported layout, turning silent UB into a compile error.
        static_assert(
            ck::is_same<decltype(layout), ck::tensor_layout::convolution::NCDHW>::value,
            "f_tensor_strides_ncdhw: unsupported tensor layout");
    }
}
// Build a HostTensorDescriptor with lengths {N, C, D, H, W} and the packed
// strides matching the given layout. Only NCDHW and NDHWC are supported; any
// other layout fails to compile instead of falling off the end of a non-void
// function (undefined behaviour).
template <typename TensorLayout>
HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
                                              std::size_t C_,
                                              std::size_t D,
                                              std::size_t H,
                                              std::size_t W,
                                              TensorLayout layout)
{
    using namespace ck::literals;
    (void)layout; // tag parameter, used only for compile-time dispatch
    if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NCDHW>::value)
    {
        // Packed NCDHW: W is the fastest-varying dimension.
        return HostTensorDescriptor({N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz});
    }
    else if constexpr(ck::is_same<decltype(layout),
                                  ck::tensor_layout::convolution::NDHWC>::value)
    {
        // Packed NDHWC: C is the fastest-varying dimension; strides are still
        // listed in (N, C, D, H, W) order to match the lengths.
        return HostTensorDescriptor({N_, C_, D, H, W},
                                    {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_});
    }
    else
    {
        // Dependent condition: only fires when instantiated with an
        // unsupported layout, turning silent UB into a compile error.
        static_assert(
            ck::is_same<decltype(layout), ck::tensor_layout::convolution::NCDHW>::value,
            "f_host_tensor_descriptor: unsupported tensor layout");
    }
}
bool pool3d_bwd_test(bool do_verification, bool pool3d_bwd_test(bool do_verification,
bool time_kernel, bool time_kernel,
ck::index_t N, ck::index_t N,
...@@ -42,7 +83,8 @@ bool pool3d_bwd_test(bool do_verification, ...@@ -42,7 +83,8 @@ bool pool3d_bwd_test(bool do_verification,
1, // ReduceKThreadClusterSize 1, // ReduceKThreadClusterSize
1, // ReduceMThreadSliceSize 1, // ReduceMThreadSliceSize
1, // ReduceKThreadSliceSize 1, // ReduceKThreadSliceSize
1>; // InSrcOutDstVectorSize 1, // InSrcOutDstVectorSize
true>;
auto OutSpatialLength = [&](auto InSpatialLength, int index) { auto OutSpatialLength = [&](auto InSpatialLength, int index) {
ck::index_t left_pad = dinput_left_pads[index]; ck::index_t left_pad = dinput_left_pads[index];
...@@ -56,16 +98,9 @@ bool pool3d_bwd_test(bool do_verification, ...@@ -56,16 +98,9 @@ bool pool3d_bwd_test(bool do_verification,
ck::index_t Ho = OutSpatialLength(Hi, 1); ck::index_t Ho = OutSpatialLength(Hi, 1);
ck::index_t Wo = OutSpatialLength(Wi, 1); ck::index_t Wo = OutSpatialLength(Wi, 1);
auto f_host_tensor_descriptor = Tensor<DOutDataType> dout(f_host_tensor_descriptor(N, C, Do, Ho, Wo, DOutLayout{}));
[](std::size_t N_, std::size_t C_, std::size_t D, std::size_t H, std::size_t W) { Tensor<DInDataType> din_dev(f_host_tensor_descriptor(N, C, Di, Hi, Wi, DInLayout{}));
using namespace ck::literals; Tensor<DInDataType> din_host(f_host_tensor_descriptor(N, C, Di, Hi, Wi, DInLayout{}));
return HostTensorDescriptor({N_, C_, D, H, W},
{D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_});
};
Tensor<DOutDataType> dout(f_host_tensor_descriptor(N, C, Do, Ho, Wo));
Tensor<DInDataType> din_dev(f_host_tensor_descriptor(N, C, Di, Hi, Wi));
Tensor<DInDataType> din_host(f_host_tensor_descriptor(N, C, Di, Hi, Wi));
std::cout << "dout: " << dout.mDesc << std::endl; std::cout << "dout: " << dout.mDesc << std::endl;
std::cout << "din_host: " << din_host.mDesc << std::endl; std::cout << "din_host: " << din_host.mDesc << std::endl;
...@@ -76,6 +111,7 @@ bool pool3d_bwd_test(bool do_verification, ...@@ -76,6 +111,7 @@ bool pool3d_bwd_test(bool do_verification,
DeviceMem din_device_buf(sizeof(DInDataType) * din_dev.mDesc.GetElementSpaceSize()); DeviceMem din_device_buf(sizeof(DInDataType) * din_dev.mDesc.GetElementSpaceSize());
dout_device_buf.ToDevice(dout.mData.data()); dout_device_buf.ToDevice(dout.mData.data());
din_device_buf.SetZero();
auto pool = DevicePoolBwdInstance{}; auto pool = DevicePoolBwdInstance{};
auto invoker_ptr = pool.MakeInvokerPointer(); auto invoker_ptr = pool.MakeInvokerPointer();
...@@ -84,8 +120,8 @@ bool pool3d_bwd_test(bool do_verification, ...@@ -84,8 +120,8 @@ bool pool3d_bwd_test(bool do_verification,
static_cast<DInDataType*>(din_device_buf.GetDeviceBuffer()), static_cast<DInDataType*>(din_device_buf.GetDeviceBuffer()),
{N, C, Do, Ho, Wo}, {N, C, Do, Ho, Wo},
{N, C, Di, Hi, Wi}, {N, C, Di, Hi, Wi},
{Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C}, f_tensor_strides_ncdhw(N, C, Do, Ho, Wo, DOutLayout{}),
{Di * C * Hi * Wi, 1, C * Hi * Wi, Wi * C, C}, f_tensor_strides_ncdhw(N, C, Di, Hi, Wi, DInLayout{}),
window_lengths, window_lengths,
window_strides, window_strides,
window_dilations, window_dilations,
...@@ -131,7 +167,7 @@ int main() ...@@ -131,7 +167,7 @@ int main()
{ {
std::vector<ck::index_t> window_lengths = {5, 5, 5}; std::vector<ck::index_t> window_lengths = {5, 5, 5};
std::vector<ck::index_t> window_strides = {2, 2, 2}; std::vector<ck::index_t> window_strides = {2, 2, 2};
std::vector<ck::index_t> window_dilations = {1, 1, 1}; std::vector<ck::index_t> window_dilations = {2, 2, 2};
std::vector<ck::index_t> dinput_left_pads = {0, 0, 0}; std::vector<ck::index_t> dinput_left_pads = {0, 0, 0};
std::vector<ck::index_t> dinput_right_pads = {0, 0, 0}; std::vector<ck::index_t> dinput_right_pads = {0, 0, 0};
......
...@@ -18,6 +18,11 @@ namespace ck { ...@@ -18,6 +18,11 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
// In and Din = [N, C, Di, Hi, Wi]
// Out and Dout = [N, C, Do, Ho, Wo]
// Out = AvgPoolFwd(In)
// Din = AvgPoolBwd(Dout)
// Pooling dimension = D, H, W
template <typename DOutDataType, template <typename DOutDataType,
typename DInDataType, typename DInDataType,
typename ComputeDataType, typename ComputeDataType,
...@@ -26,7 +31,8 @@ template <typename DOutDataType, ...@@ -26,7 +31,8 @@ template <typename DOutDataType,
ck::index_t KThreadClusterSize, ck::index_t KThreadClusterSize,
ck::index_t MThreadSliceSize, ck::index_t MThreadSliceSize,
ck::index_t KThreadSliceSize, ck::index_t KThreadSliceSize,
ck::index_t InSrcOutDstVectorSize> ck::index_t InSrcOutDstVectorSize,
bool IsFastestDimReduced>
struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataType> struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataType>
{ {
static constexpr index_t NDimSpatial = 3; static constexpr index_t NDimSpatial = 3;
...@@ -312,7 +318,7 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp ...@@ -312,7 +318,7 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
// FIXME // FIXME
// for NDHWC, the dim C is the vector Dim for both input and output in memory, which is not // for NDHWC, the dim C is the vector Dim for both input and output in memory, which is not
// reduced. Assume C is the fastest dimension // reduced. Assume C is the fastest dimension
static constexpr index_t InSrcOutDstVectorDim = 0; static constexpr index_t OutSrcInDstVectorDim = IsFastestDimReduced ? 1 : 0;
using PassThrough = tensor_operation::element_wise::PassThrough; using PassThrough = tensor_operation::element_wise::PassThrough;
using Div = tensor_operation::element_wise::UnaryDivide; using Div = tensor_operation::element_wise::UnaryDivide;
...@@ -331,7 +337,7 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp ...@@ -331,7 +337,7 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
BlockSize, BlockSize,
MThreadSliceSize, MThreadSliceSize,
KThreadSliceSize, KThreadSliceSize,
InSrcOutDstVectorDim, OutSrcInDstVectorDim,
InSrcOutDstVectorSize, InSrcOutDstVectorSize,
InSrcOutDstVectorSize>; InSrcOutDstVectorSize>;
...@@ -413,9 +419,6 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp ...@@ -413,9 +419,6 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
{ {
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
ignore = arg;
ignore = stream_config;
float ave_time = 0; float ave_time = 0;
for(index_t i = 0; i < arg.num_reduce_; i++) for(index_t i = 0; i < arg.num_reduce_; i++)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment