Unverified Commit ac76519a authored by Adam Osewski's avatar Adam Osewski Committed by GitHub
Browse files

Merge branch 'develop' into aosewski/gemm_tile_loop

parents a70c6283 578142db
add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp) add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp)
add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp) add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp)
add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp)
endif()
add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp) add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp)
endif()
add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp) add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp)
add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp)
endif()
add_example_executable(example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp) if(DL_KERNELS)
add_example_executable(example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp)
endif()
add_example_executable(example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp)
endif()
add_example_executable(example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp)
endif()
add_example_executable(example_maxpool2d_bwd_bf16 maxpool2d_bwd_bf16.cpp) if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_maxpool2d_bwd_fp16 maxpool2d_bwd_fp16.cpp) add_example_executable(example_maxpool2d_bwd_bf16 maxpool2d_bwd_bf16.cpp)
add_example_executable(example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp) endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_maxpool2d_bwd_fp16 maxpool2d_bwd_fp16.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp)
endif()
add_example_executable(example_put_element_fp16 put_element_fp16.cpp) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_put_element_fp16 put_element_fp16.cpp)
endif()
add_example_executable(example_avgpool3d_bwd_bf16 avgpool3d_bwd_bf16.cpp)
add_example_executable(example_avgpool3d_bwd_fp16 avgpool3d_bwd_fp16.cpp)
add_example_executable(example_avgpool3d_bwd_fp32 avgpool3d_bwd_fp32.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp"
#include "avgpool3d_bwd_common.hpp"
using DOutDataType = ck::bhalf_t;
using DInDataType = ck::bhalf_t;
using ComputeDataType = float;
#if 1
using DOutLayout = ck::tensor_layout::convolution::NDHWC;
using DInLayout = ck::tensor_layout::convolution::NDHWC;
#else
using DOutLayout = ck::tensor_layout::convolution::NCDHW;
using DInLayout = ck::tensor_layout::convolution::NCDHW;
#endif
using DevicePoolBwdInstance =
ck::tensor_operation::device::DeviceAvgPool3dBwd_NDHWC_NDHWC<DOutDataType,
DInDataType,
ComputeDataType,
64, // BlockSize
64, // ReduceMThreadClusterSize
1, // ReduceKThreadClusterSize
1, // ReduceMThreadSliceSize
1, // ReduceKThreadSliceSize
1>; // InSrcOutDstVectorSize
int main()
{
std::vector<ck::index_t> window_lengths = {5, 5, 5};
std::vector<ck::index_t> window_strides = {2, 2, 2};
std::vector<ck::index_t> window_dilations = {2, 2, 2};
std::vector<ck::index_t> dinput_left_pads = {0, 0, 0};
std::vector<ck::index_t> dinput_right_pads = {0, 0, 0};
ck::index_t N = 1;
ck::index_t C = 16;
ck::index_t Di = 40;
ck::index_t Hi = 40;
ck::index_t Wi = 40;
pool3d_bwd_test<DevicePoolBwdInstance, DOutDataType, DInDataType, DOutLayout, DInLayout>(
true,
false,
N,
C,
Di,
Hi,
Wi,
window_lengths,
window_strides,
window_dilations,
dinput_left_pads,
dinput_right_pads);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "ck/ck.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_avgpool_bwd.hpp"
template <typename TensorLayout>
std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
ck::index_t C_,
ck::index_t D,
ck::index_t H,
ck::index_t W,
TensorLayout layout)
{
using namespace ck::literals;
(void)N_;
if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NCDHW>::value)
return {C_ * D * H * W, D * H * W, H * W, W, 1_uz};
else if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NDHWC>::value)
return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_};
};
template <typename TensorLayout>
HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
std::size_t C_,
std::size_t D,
std::size_t H,
std::size_t W,
TensorLayout layout)
{
using namespace ck::literals;
if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NCDHW>::value)
{
return HostTensorDescriptor({N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz});
}
else if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NDHWC>::value)
{
return HostTensorDescriptor({N_, C_, D, H, W},
{D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_});
}
};
template <typename DevicePoolBwdInstance,
typename DOutDataType,
typename DInDataType,
typename DOutLayout,
typename DInLayout>
bool pool3d_bwd_test(bool do_verification,
bool time_kernel,
ck::index_t N,
ck::index_t C,
ck::index_t Di,
ck::index_t Hi,
ck::index_t Wi,
std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> dinput_left_pads,
std::vector<ck::index_t> dinput_right_pads)
{
auto OutSpatialLength = [&](auto InSpatialLength, int index) {
ck::index_t left_pad = dinput_left_pads[index];
ck::index_t right_pad = dinput_right_pads[index];
ck::index_t window_len = window_lengths[index];
ck::index_t stride = window_strides[index];
ck::index_t dilation = window_dilations[index];
ck::index_t eff = (window_len - 1) * dilation + 1;
return (InSpatialLength + left_pad + right_pad - eff) / stride + 1;
};
ck::index_t Do = OutSpatialLength(Di, 0);
ck::index_t Ho = OutSpatialLength(Hi, 1);
ck::index_t Wo = OutSpatialLength(Wi, 2);
Tensor<DOutDataType> dout(f_host_tensor_descriptor(N, C, Do, Ho, Wo, DOutLayout{}));
Tensor<DInDataType> din_dev(f_host_tensor_descriptor(N, C, Di, Hi, Wi, DInLayout{}));
Tensor<DInDataType> din_host(f_host_tensor_descriptor(N, C, Di, Hi, Wi, DInLayout{}));
std::cout << "dout: " << dout.mDesc << std::endl;
std::cout << "din_host: " << din_host.mDesc << std::endl;
dout.GenerateTensorValue(GeneratorTensor_3<DOutDataType>{0.0, 1.0});
DeviceMem dout_device_buf(sizeof(DOutDataType) * dout.mDesc.GetElementSpaceSize());
DeviceMem din_device_buf(sizeof(DInDataType) * din_dev.mDesc.GetElementSpaceSize());
dout_device_buf.ToDevice(dout.mData.data());
din_device_buf.SetZero();
auto pool = DevicePoolBwdInstance{};
auto invoker_ptr = pool.MakeInvokerPointer();
auto argument_ptr =
pool.MakeArgumentPointer(static_cast<DOutDataType*>(dout_device_buf.GetDeviceBuffer()),
static_cast<DInDataType*>(din_device_buf.GetDeviceBuffer()),
{N, C, Do, Ho, Wo},
{N, C, Di, Hi, Wi},
f_tensor_strides_ncdhw(N, C, Do, Ho, Wo, DOutLayout{}),
f_tensor_strides_ncdhw(N, C, Di, Hi, Wi, DInLayout{}),
window_lengths,
window_strides,
window_dilations,
dinput_left_pads,
dinput_right_pads);
if(!pool.IsSupportedArgument(argument_ptr.get()))
{
throw std::runtime_error("wrong! device_op with the specified compilation parameters does "
"not support this problem");
}
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::cout << "Perf: " << ave_time << std::endl;
bool pass = true;
if(do_verification)
{
auto ref_pool =
ck::tensor_operation::host::ReferenceAvgPoolBwd<3, DInDataType, DOutDataType>();
auto ref_invoker = ref_pool.MakeInvoker();
auto ref_argument = ref_pool.MakeArgument(din_host,
dout,
window_lengths,
window_strides,
window_dilations,
dinput_left_pads,
dinput_right_pads);
ref_invoker.Run(ref_argument);
din_device_buf.FromDevice(din_dev.mData.data());
pass = ck::utils::check_err(din_dev, din_host);
}
return pass;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp"
#include "avgpool3d_bwd_common.hpp"
using DOutDataType = ck::half_t;
using DInDataType = ck::half_t;
using ComputeDataType = float;
#if 1
using DOutLayout = ck::tensor_layout::convolution::NDHWC;
using DInLayout = ck::tensor_layout::convolution::NDHWC;
#else
using DOutLayout = ck::tensor_layout::convolution::NCDHW;
using DInLayout = ck::tensor_layout::convolution::NCDHW;
#endif
using DevicePoolBwdInstance =
ck::tensor_operation::device::DeviceAvgPool3dBwd_NDHWC_NDHWC<DOutDataType,
DInDataType,
ComputeDataType,
64, // BlockSize
64, // ReduceMThreadClusterSize
1, // ReduceKThreadClusterSize
1, // ReduceMThreadSliceSize
1, // ReduceKThreadSliceSize
1>; // InSrcOutDstVectorSize
int main()
{
std::vector<ck::index_t> window_lengths = {5, 5, 5};
std::vector<ck::index_t> window_strides = {2, 2, 2};
std::vector<ck::index_t> window_dilations = {2, 2, 2};
std::vector<ck::index_t> dinput_left_pads = {0, 0, 0};
std::vector<ck::index_t> dinput_right_pads = {0, 0, 0};
ck::index_t N = 1;
ck::index_t C = 16;
ck::index_t Di = 40;
ck::index_t Hi = 40;
ck::index_t Wi = 40;
pool3d_bwd_test<DevicePoolBwdInstance, DOutDataType, DInDataType, DOutLayout, DInLayout>(
true,
false,
N,
C,
Di,
Hi,
Wi,
window_lengths,
window_strides,
window_dilations,
dinput_left_pads,
dinput_right_pads);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp"
#include "avgpool3d_bwd_common.hpp"
using DOutDataType = float;
using DInDataType = float;
using ComputeDataType = float;
#if 1
using DOutLayout = ck::tensor_layout::convolution::NDHWC;
using DInLayout = ck::tensor_layout::convolution::NDHWC;
#else
using DOutLayout = ck::tensor_layout::convolution::NCDHW;
using DInLayout = ck::tensor_layout::convolution::NCDHW;
#endif
using DevicePoolBwdInstance =
ck::tensor_operation::device::DeviceAvgPool3dBwd_NDHWC_NDHWC<DOutDataType,
DInDataType,
ComputeDataType,
64, // BlockSize
64, // ReduceMThreadClusterSize
1, // ReduceKThreadClusterSize
1, // ReduceMThreadSliceSize
1, // ReduceKThreadSliceSize
1>; // InSrcOutDstVectorSize
int main()
{
std::vector<ck::index_t> window_lengths = {5, 5, 5};
std::vector<ck::index_t> window_strides = {2, 2, 2};
std::vector<ck::index_t> window_dilations = {2, 2, 2};
std::vector<ck::index_t> dinput_left_pads = {0, 0, 0};
std::vector<ck::index_t> dinput_right_pads = {0, 0, 0};
ck::index_t N = 1;
ck::index_t C = 16;
ck::index_t Di = 40;
ck::index_t Hi = 40;
ck::index_t Wi = 40;
pool3d_bwd_test<DevicePoolBwdInstance, DOutDataType, DInDataType, DOutLayout, DInLayout>(
true,
false,
N,
C,
Di,
Hi,
Wi,
window_lengths,
window_strides,
window_dilations,
dinput_left_pads,
dinput_right_pads);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t NDimSpatial,
typename DOutDataType,
typename DInDataType,
typename DOutLayout,
typename DInLayout>
struct DeviceAvgPoolBwd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_dout,
void* p_din,
std::vector<ck::index_t> dout_n_k_wos_lengths,
std::vector<ck::index_t> dout_n_k_wos_strides,
std::vector<ck::index_t> din_n_k_wos_length,
std::vector<ck::index_t> din_n_k_wos_strides,
std::vector<ck::index_t> window_k_c_xs_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -27,15 +27,12 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator ...@@ -27,15 +27,12 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
MakeArgumentPointer(const void* p_in, MakeArgumentPointer(const void* p_in,
void* p_wei, void* p_wei,
const void* p_out, const void* p_out,
const ck::index_t G, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
...@@ -123,7 +123,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -123,7 +123,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
ALayout, ALayout,
BLayout, BLayout,
CLayout, CLayout,
ADataType, // TODO: distinguish A/B datatype ADataType,
BDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
CDataType, CDataType,
...@@ -284,8 +285,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -284,8 +285,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, ADataType, CDataType, true>; ADataType,
BDataType,
CDataType,
true>;
ave_time += launch_and_time_kernel(stream_config, ave_time += launch_and_time_kernel(stream_config,
kernel, kernel,
...@@ -357,8 +361,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -357,8 +361,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
} }
else else
{ {
const auto kernel = const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, ADataType, CDataType, false>; ADataType,
BDataType,
CDataType,
false>;
ave_time += launch_and_time_kernel(stream_config, ave_time += launch_and_time_kernel(stream_config,
kernel, kernel,
......
...@@ -65,7 +65,8 @@ template <typename ALayout, ...@@ -65,7 +65,8 @@ template <typename ALayout,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(), LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1,
typename ComputeType = CDataType>
struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
BLayout, BLayout,
CLayout, CLayout,
...@@ -87,7 +88,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -87,7 +88,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
ALayout, ALayout,
BLayout, BLayout,
CLayout, CLayout,
ADataType, // TODO: distinguish A/B datatype ADataType,
BDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
CDataType, CDataType,
...@@ -128,7 +130,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -128,7 +130,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
PipelineVer>; PipelineVer,
ComputeType>;
using Argument = typename GridwiseGemm::Argument; using Argument = typename GridwiseGemm::Argument;
......
...@@ -784,15 +784,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -784,15 +784,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
Argument(const InDataType* p_in_grid, Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid, WeiDataType* p_wei_grid,
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
const ck::index_t G, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& /*a_g_n_c_wis_strides*/,
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& /*b_g_k_c_xs_strides*/,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& /*e_g_n_k_wos_strides*/,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& /*input_strides*/,
const std::array<ck::index_t, NDimSpatial + 3>& /*output_strides*/,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
...@@ -812,27 +809,38 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -812,27 +809,38 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
a_element_op_{out_element_op}, a_element_op_{out_element_op},
b_element_op_{wei_element_op}, b_element_op_{wei_element_op},
c_element_op_{in_element_op}, c_element_op_{in_element_op},
Conv_G_{G}, Conv_G_{a_g_n_c_wis_lengths[0]},
Conv_N_{N}, Conv_N_{a_g_n_c_wis_lengths[1]},
Conv_K_{K}, Conv_K_{b_g_k_c_xs_lengths[1]},
Conv_C_{C}, Conv_C_{a_g_n_c_wis_lengths[2]},
input_spatial_lengths_{input_spatial_lengths}, input_spatial_lengths_{},
filter_spatial_lengths_{filter_spatial_lengths}, filter_spatial_lengths_{},
output_spatial_lengths_{output_spatial_lengths}, output_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides}, conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations}, conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads}, input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}, input_right_pads_{input_right_pads},
k_batch_{split_k} k_batch_{split_k}
{ {
constexpr index_t spatial_offset = 3;
std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset,
end(a_g_n_c_wis_lengths),
begin(input_spatial_lengths_));
std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset,
end(b_g_k_c_xs_lengths),
begin(filter_spatial_lengths_));
std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset,
end(e_g_n_k_wos_lengths),
begin(output_spatial_lengths_));
const auto descs = const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>( DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
N, Conv_N_,
K, Conv_K_,
C, Conv_C_,
input_spatial_lengths, input_spatial_lengths_,
filter_spatial_lengths, filter_spatial_lengths_,
output_spatial_lengths, output_spatial_lengths_,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -856,21 +864,21 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -856,21 +864,21 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
// A/B/C Batch Stride // A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = compute_ptr_offset_of_batch_.BatchStrideA_ =
N * K * Conv_N_ * Conv_K_ *
std::accumulate(begin(output_spatial_lengths), std::accumulate(begin(output_spatial_lengths_),
end(output_spatial_lengths), end(output_spatial_lengths_),
index_t{1}, index_t{1},
std::multiplies<>{}); std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideB_ = compute_ptr_offset_of_batch_.BatchStrideB_ =
N * C * Conv_N_ * Conv_C_ *
std::accumulate(begin(input_spatial_lengths), std::accumulate(begin(input_spatial_lengths_),
end(input_spatial_lengths), end(input_spatial_lengths_),
index_t{1}, index_t{1},
std::multiplies<>{}); std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideC_ = compute_ptr_offset_of_batch_.BatchStrideC_ =
K * C * Conv_K_ * Conv_C_ *
std::accumulate(begin(filter_spatial_lengths), std::accumulate(begin(filter_spatial_lengths_),
end(filter_spatial_lengths), end(filter_spatial_lengths_),
index_t{1}, index_t{1},
std::multiplies<>{}); std::multiplies<>{});
} }
...@@ -904,9 +912,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -904,9 +912,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
const index_t Conv_K_; const index_t Conv_K_;
const index_t Conv_C_; const index_t Conv_C_;
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths_; std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_; std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_; std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_; const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_; const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_; const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
...@@ -1110,18 +1118,16 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1110,18 +1118,16 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument(const InDataType* p_in_grid, static auto
MakeArgument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid, WeiDataType* p_wei_grid,
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
const ck::index_t G, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
...@@ -1134,15 +1140,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1134,15 +1140,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
return Argument{p_in_grid, return Argument{p_in_grid,
p_wei_grid, p_wei_grid,
p_out_grid, p_out_grid,
G, a_g_n_c_wis_lengths, // input
N, a_g_n_c_wis_strides,
K, b_g_k_c_xs_lengths, // weight
C, b_g_k_c_xs_strides,
input_spatial_lengths, e_g_n_k_wos_lengths, // output
filter_spatial_lengths, e_g_n_k_wos_strides,
output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -1159,15 +1162,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1159,15 +1162,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
MakeArgumentPointer(const void* p_in_grid, MakeArgumentPointer(const void* p_in_grid,
void* p_wei_grid, void* p_wei_grid,
const void* p_out_grid, const void* p_out_grid,
const ck::index_t G, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
...@@ -1180,15 +1180,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1180,15 +1180,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid), return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
static_cast<WeiDataType*>(p_wei_grid), static_cast<WeiDataType*>(p_wei_grid),
static_cast<const OutDataType*>(p_out_grid), static_cast<const OutDataType*>(p_out_grid),
G, a_g_n_c_wis_lengths, // input
N, a_g_n_c_wis_strides,
K, b_g_k_c_xs_lengths, // weight
C, b_g_k_c_xs_strides,
input_spatial_lengths, e_g_n_k_wos_lengths, // output
filter_spatial_lengths, e_g_n_k_wos_strides,
output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment