"vscode:/vscode.git/clone" did not exist on "eb898ad6cac26fecfcd024ccbe863d655f1dfd48"
Commit 400cb28e authored by rocking's avatar rocking
Browse files

Support different stride

parent 0bd6d2ce
......@@ -4,6 +4,7 @@
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
......@@ -20,6 +21,46 @@ using ComputeDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DOutLayout = ck::tensor_layout::convolution::NDHWC;
using DInLayout = ck::tensor_layout::convolution::NDHWC;
template <typename TensorLayout>
std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
ck::index_t C_,
ck::index_t D,
ck::index_t H,
ck::index_t W,
TensorLayout layout)
{
using namespace ck::literals;
(void)N_;
if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NCDHW>::value)
return {C_ * D * H * W, D * H * W, H * W, W, 1_uz};
else if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NDHWC>::value)
return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_};
};
template <typename TensorLayout>
HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
std::size_t C_,
std::size_t D,
std::size_t H,
std::size_t W,
TensorLayout layout)
{
using namespace ck::literals;
if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NCDHW>::value)
{
return HostTensorDescriptor({N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz});
}
else if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NDHWC>::value)
{
return HostTensorDescriptor({N_, C_, D, H, W},
{D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_});
}
};
bool pool3d_bwd_test(bool do_verification,
bool time_kernel,
ck::index_t N,
......@@ -42,7 +83,8 @@ bool pool3d_bwd_test(bool do_verification,
1, // ReduceKThreadClusterSize
1, // ReduceMThreadSliceSize
1, // ReduceKThreadSliceSize
1>; // InSrcOutDstVectorSize
1, // InSrcOutDstVectorSize
true>;
auto OutSpatialLength = [&](auto InSpatialLength, int index) {
ck::index_t left_pad = dinput_left_pads[index];
......@@ -56,16 +98,9 @@ bool pool3d_bwd_test(bool do_verification,
ck::index_t Ho = OutSpatialLength(Hi, 1);
ck::index_t Wo = OutSpatialLength(Wi, 1);
auto f_host_tensor_descriptor =
[](std::size_t N_, std::size_t C_, std::size_t D, std::size_t H, std::size_t W) {
using namespace ck::literals;
return HostTensorDescriptor({N_, C_, D, H, W},
{D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_});
};
Tensor<DOutDataType> dout(f_host_tensor_descriptor(N, C, Do, Ho, Wo));
Tensor<DInDataType> din_dev(f_host_tensor_descriptor(N, C, Di, Hi, Wi));
Tensor<DInDataType> din_host(f_host_tensor_descriptor(N, C, Di, Hi, Wi));
Tensor<DOutDataType> dout(f_host_tensor_descriptor(N, C, Do, Ho, Wo, DOutLayout{}));
Tensor<DInDataType> din_dev(f_host_tensor_descriptor(N, C, Di, Hi, Wi, DInLayout{}));
Tensor<DInDataType> din_host(f_host_tensor_descriptor(N, C, Di, Hi, Wi, DInLayout{}));
std::cout << "dout: " << dout.mDesc << std::endl;
std::cout << "din_host: " << din_host.mDesc << std::endl;
......@@ -76,6 +111,7 @@ bool pool3d_bwd_test(bool do_verification,
DeviceMem din_device_buf(sizeof(DInDataType) * din_dev.mDesc.GetElementSpaceSize());
dout_device_buf.ToDevice(dout.mData.data());
din_device_buf.SetZero();
auto pool = DevicePoolBwdInstance{};
auto invoker_ptr = pool.MakeInvokerPointer();
......@@ -84,8 +120,8 @@ bool pool3d_bwd_test(bool do_verification,
static_cast<DInDataType*>(din_device_buf.GetDeviceBuffer()),
{N, C, Do, Ho, Wo},
{N, C, Di, Hi, Wi},
{Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C},
{Di * C * Hi * Wi, 1, C * Hi * Wi, Wi * C, C},
f_tensor_strides_ncdhw(N, C, Do, Ho, Wo, DOutLayout{}),
f_tensor_strides_ncdhw(N, C, Di, Hi, Wi, DInLayout{}),
window_lengths,
window_strides,
window_dilations,
......@@ -131,7 +167,7 @@ int main()
{
std::vector<ck::index_t> window_lengths = {5, 5, 5};
std::vector<ck::index_t> window_strides = {2, 2, 2};
std::vector<ck::index_t> window_dilations = {1, 1, 1};
std::vector<ck::index_t> window_dilations = {2, 2, 2};
std::vector<ck::index_t> dinput_left_pads = {0, 0, 0};
std::vector<ck::index_t> dinput_right_pads = {0, 0, 0};
......
......@@ -18,6 +18,11 @@ namespace ck {
namespace tensor_operation {
namespace device {
// In and Din = [N, C, Di, Hi, Wi]
// Out and Dout = [N, C, Do, Ho, Wo]
// Out = AvgPoolFwd(In)
// Din = AvgPoolBwd(Dout)
// Pooling dimension = D, H, W
template <typename DOutDataType,
typename DInDataType,
typename ComputeDataType,
......@@ -26,7 +31,8 @@ template <typename DOutDataType,
ck::index_t KThreadClusterSize,
ck::index_t MThreadSliceSize,
ck::index_t KThreadSliceSize,
ck::index_t InSrcOutDstVectorSize>
ck::index_t InSrcOutDstVectorSize,
bool IsFastestDimReduced>
struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataType>
{
static constexpr index_t NDimSpatial = 3;
......@@ -312,7 +318,7 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
// FIXME
// for NDHWC, the dim C is the vector Dim for both input and output in memory, which is not
// reduced. Assume C is the fastest dimension
static constexpr index_t InSrcOutDstVectorDim = 0;
static constexpr index_t OutSrcInDstVectorDim = IsFastestDimReduced ? 1 : 0;
using PassThrough = tensor_operation::element_wise::PassThrough;
using Div = tensor_operation::element_wise::UnaryDivide;
......@@ -331,7 +337,7 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
BlockSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcOutDstVectorDim,
OutSrcInDstVectorDim,
InSrcOutDstVectorSize,
InSrcOutDstVectorSize>;
......@@ -413,9 +419,6 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
ignore = arg;
ignore = stream_config;
float ave_time = 0;
for(index_t i = 0; i < arg.num_reduce_; i++)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment