Unverified Commit 1ee99dca authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Support NHWGC conv2d_bwd_weight (#769)



* Support NHWGC conv2d_bwd_weight

* Fix client example

* Fix client example

* Fix comments

* Redesign grouped_conv_bwd_weight instances

* Clang format fix

---------
Co-authored-by: default avatarzjing14 <zhangjing14@gmail.com>
parent 87f2bbcf
...@@ -101,13 +101,15 @@ template <ck::index_t NumDimSpatial, ...@@ -101,13 +101,15 @@ template <ck::index_t NumDimSpatial,
typename WeiLayout, typename WeiLayout,
typename OutLayout> typename OutLayout>
bool run_grouped_conv_bwd_weight( bool run_grouped_conv_bwd_weight(
ck::index_t G, const ck::index_t G,
ck::index_t N, const ck::index_t N,
ck::index_t K, const ck::index_t K,
ck::index_t C, const ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths, const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths, const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths, const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial + 3>& input_strides,
const std::array<ck::index_t, NumDimSpatial + 3>& output_strides,
const std::array<ck::index_t, NumDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NumDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NumDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NumDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NumDimSpatial>& input_left_pads, const std::array<ck::index_t, NumDimSpatial>& input_left_pads,
...@@ -157,6 +159,8 @@ bool run_grouped_conv_bwd_weight( ...@@ -157,6 +159,8 @@ bool run_grouped_conv_bwd_weight(
input_spatial_lengths, input_spatial_lengths,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -224,6 +228,8 @@ bool run_grouped_conv_bwd_weight( ...@@ -224,6 +228,8 @@ bool run_grouped_conv_bwd_weight(
input_spatial_lengths, input_spatial_lengths,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
......
...@@ -22,6 +22,15 @@ static constexpr ck::index_t C = 192; ...@@ -22,6 +22,15 @@ static constexpr ck::index_t C = 192;
static constexpr ck::index_t X = 3; static constexpr ck::index_t X = 3;
static constexpr ck::index_t Wi = 28; static constexpr ck::index_t Wi = 28;
static constexpr ck::index_t Wo = 28; static constexpr ck::index_t Wo = 28;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{N * Wi * C, Wi* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{N * Wo * K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1};
int main() int main()
{ {
...@@ -31,7 +40,19 @@ int main() ...@@ -31,7 +40,19 @@ int main()
OutDataType, OutDataType,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout>(G, N, K, C, {Wi}, {X}, {Wo}, {1}, {1}, {1}, {1}) OutLayout>(G,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)
? EXIT_SUCCESS ? EXIT_SUCCESS
: EXIT_FAILURE; : EXIT_FAILURE;
} }
...@@ -25,6 +25,17 @@ static constexpr ck::index_t Hi = 28; ...@@ -25,6 +25,17 @@ static constexpr ck::index_t Hi = 28;
static constexpr ck::index_t Wi = 28; static constexpr ck::index_t Wi = 28;
static constexpr ck::index_t Ho = 28; static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 28; static constexpr ck::index_t Wo = 28;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Hi, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Y, X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
N * Hi * Wi * C, Hi* Wi* C, Wi* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
N * Ho * Wo * K, Ho* Wo* K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1};
int main() int main()
{ {
...@@ -34,8 +45,19 @@ int main() ...@@ -34,8 +45,19 @@ int main()
OutDataType, OutDataType,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout>( OutLayout>(G,
G, N, K, C, {Hi, Wi}, {Y, X}, {Ho, Wo}, {1, 1}, {1, 1}, {1, 1}, {1, 1}) N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)
? EXIT_SUCCESS ? EXIT_SUCCESS
: EXIT_FAILURE; : EXIT_FAILURE;
} }
...@@ -28,6 +28,17 @@ static constexpr ck::index_t Wi = 3; ...@@ -28,6 +28,17 @@ static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28; static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28; static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3; static constexpr ck::index_t Wo = 3;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di, Hi, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z, Y, X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1, 1};
int main() int main()
{ {
...@@ -41,13 +52,15 @@ int main() ...@@ -41,13 +52,15 @@ int main()
N, N,
K, K,
C, C,
{Di, Hi, Wi}, input_spatial_lengths,
{Z, Y, X}, filter_spatial_lengths,
{Do, Ho, Wo}, output_spatial_lengths,
{1, 1, 1}, input_strides,
{1, 1, 1}, output_strides,
{1, 1, 1}, conv_filter_strides,
{1, 1, 1}) conv_filter_dilations,
input_left_pads,
input_right_pads)
? EXIT_SUCCESS ? EXIT_SUCCESS
: EXIT_FAILURE; : EXIT_FAILURE;
} }
...@@ -28,6 +28,17 @@ static constexpr ck::index_t Wi = 3; ...@@ -28,6 +28,17 @@ static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28; static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28; static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3; static constexpr ck::index_t Wo = 3;
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di, Hi, Wi};
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z, Y, X};
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1};
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1, 1};
int main() int main()
{ {
...@@ -37,17 +48,20 @@ int main() ...@@ -37,17 +48,20 @@ int main()
OutDataType, OutDataType,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout>(G, OutLayout>(
N, G,
K, N,
C, K,
{Di, Hi, Wi}, C,
{Z, Y, X}, {Di, Hi, Wi},
{Do, Ho, Wo}, {Z, Y, X},
{1, 1, 1}, {Do, Ho, Wo},
{1, 1, 1}, {N * Di * Hi * Wi * C, Di * Hi * Wi * C, Hi * Wi * C, Wi * C, C, 1},
{1, 1, 1}, {N * Do * Ho * Wo * K, Do * Ho * Wo * K, Ho * Wo * K, Wo * K, K, 1},
{1, 1, 1}) {1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1})
? EXIT_SUCCESS ? EXIT_SUCCESS
: EXIT_FAILURE; : EXIT_FAILURE;
} }
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "common.hpp" #include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
using InDataType = BF16; using InDataType = BF16;
// bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory // bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory
...@@ -17,8 +17,20 @@ using OutElementOp = PassThrough; ...@@ -17,8 +17,20 @@ using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance = using DeviceConvBwdWeightInstance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
NDimSpatial, // NDimSpatial NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWC,
ck::tensor_layout::convolution::GNHWC,
ck::tensor_layout::convolution::GNDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GKXC,
ck::tensor_layout::convolution::GKYXC,
ck::tensor_layout::convolution::GKZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWK,
ck::tensor_layout::convolution::GNHWK,
ck::tensor_layout::convolution::GNDHWK>>,
InDataType, // InDataType InDataType, // InDataType
WeiDataType, // WeiDataType WeiDataType, // WeiDataType
OutDataType, // OutDataType OutDataType, // OutDataType
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "common.hpp" #include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
using InDataType = F16; using InDataType = F16;
using WeiDataType = F16; using WeiDataType = F16;
...@@ -16,8 +16,20 @@ using OutElementOp = PassThrough; ...@@ -16,8 +16,20 @@ using OutElementOp = PassThrough;
template <ck::index_t NDimSpatial> template <ck::index_t NDimSpatial>
using DeviceConvBwdWeightInstance = using DeviceConvBwdWeightInstance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
NDimSpatial, // NDimSpatial NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWC,
ck::tensor_layout::convolution::GNHWC,
ck::tensor_layout::convolution::GNDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GKXC,
ck::tensor_layout::convolution::GKYXC,
ck::tensor_layout::convolution::GKZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::GNWK,
ck::tensor_layout::convolution::GNHWK,
ck::tensor_layout::convolution::GNDHWK>>,
InDataType, // InDataType InDataType, // InDataType
WeiDataType, // WeiDataType WeiDataType, // WeiDataType
OutDataType, // OutDataType OutDataType, // OutDataType
......
...@@ -75,6 +75,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, ...@@ -75,6 +75,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{}; std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
std::array<ck::index_t, NDimSpatial + 3> input_strides{};
std::array<ck::index_t, NDimSpatial + 3> output_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{}; std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{}; std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{}; std::array<ck::index_t, NDimSpatial> input_left_pads{};
...@@ -85,6 +87,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, ...@@ -85,6 +87,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths)); range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths));
range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths)); range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths)); range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides)); range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations)); range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
range_copy(conv_param.input_left_pads_, begin(input_left_pads)); range_copy(conv_param.input_left_pads_, begin(input_left_pads));
...@@ -103,6 +107,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, ...@@ -103,6 +107,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
input_spatial_lengths, input_spatial_lengths,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
......
...@@ -27,17 +27,19 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator ...@@ -27,17 +27,19 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
MakeArgumentPointer(const void* p_in, MakeArgumentPointer(const void* p_in,
void* p_wei, void* p_wei,
const void* p_out, const void* p_out,
ck::index_t G, const ck::index_t G,
ck::index_t N, const ck::index_t N,
ck::index_t K, const ck::index_t K,
ck::index_t C, const ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
std::array<ck::index_t, NDimSpatial> input_left_pads, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial> input_right_pads, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
......
...@@ -195,17 +195,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -195,17 +195,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
ck::index_t N, const ck::index_t N,
ck::index_t K, const ck::index_t K,
ck::index_t C, const ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads, const std::array<ck::index_t, NDimSpatial>& input_right_pads,
ck::index_t batch_k) const ck::index_t batch_k)
{ {
using namespace ck; using namespace ck;
...@@ -347,17 +347,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -347,17 +347,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
} // function end } // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
ck::index_t N, const ck::index_t N,
ck::index_t K, const ck::index_t K,
ck::index_t C, const ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads, const std::array<ck::index_t, NDimSpatial>& input_right_pads,
ck::index_t batch_k) const ck::index_t batch_k)
{ {
using namespace ck; using namespace ck;
...@@ -515,17 +515,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -515,17 +515,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
ck::index_t N, const ck::index_t N,
ck::index_t K, const ck::index_t K,
ck::index_t C, const ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
std::array<ck::index_t, NDimSpatial> input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
std::array<ck::index_t, NDimSpatial> input_right_pads, const std::array<ck::index_t, NDimSpatial>& input_right_pads,
ck::index_t batch_k) const ck::index_t batch_k)
{ {
using namespace ck; using namespace ck;
...@@ -784,17 +784,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -784,17 +784,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
Argument(const InDataType* p_in_grid, Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid, WeiDataType* p_wei_grid,
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
ck::index_t G, const ck::index_t G,
ck::index_t N, const ck::index_t N,
ck::index_t K, const ck::index_t K,
ck::index_t C, const ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides, const std::array<ck::index_t, NDimSpatial + 3>& /*input_strides*/,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations, const std::array<ck::index_t, NDimSpatial + 3>& /*output_strides*/,
std::array<ck::index_t, NDimSpatial> input_left_pads, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial> input_right_pads, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
...@@ -897,18 +899,18 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -897,18 +899,18 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
InElementwiseOperation c_element_op_; InElementwiseOperation c_element_op_;
// for checking IsSupportedArgument() // for checking IsSupportedArgument()
index_t Conv_G_; const index_t Conv_G_;
index_t Conv_N_; const index_t Conv_N_;
index_t Conv_K_; const index_t Conv_K_;
index_t Conv_C_; const index_t Conv_C_;
std::array<ck::index_t, NDimSpatial> input_spatial_lengths_; const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths_;
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_; const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_;
std::array<ck::index_t, NDimSpatial> output_spatial_lengths_; const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_;
std::array<ck::index_t, NDimSpatial> conv_filter_strides_; const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
std::array<ck::index_t, NDimSpatial> conv_filter_dilations_; const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
std::array<ck::index_t, NDimSpatial> input_left_pads_; const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
std::array<ck::index_t, NDimSpatial> input_right_pads_; const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
index_t k_batch_; index_t k_batch_;
}; };
...@@ -1111,17 +1113,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1111,17 +1113,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
static auto MakeArgument(const InDataType* p_in_grid, static auto MakeArgument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid, WeiDataType* p_wei_grid,
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
ck::index_t G, const ck::index_t G,
ck::index_t N, const ck::index_t N,
ck::index_t K, const ck::index_t K,
ck::index_t C, const ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
std::array<ck::index_t, NDimSpatial> input_left_pads, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial> input_right_pads, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
...@@ -1137,6 +1141,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1137,6 +1141,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
input_spatial_lengths, input_spatial_lengths,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -1153,17 +1159,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1153,17 +1159,19 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
MakeArgumentPointer(const void* p_in_grid, MakeArgumentPointer(const void* p_in_grid,
void* p_wei_grid, void* p_wei_grid,
const void* p_out_grid, const void* p_out_grid,
ck::index_t G, const ck::index_t G,
ck::index_t N, const ck::index_t N,
ck::index_t K, const ck::index_t K,
ck::index_t C, const ck::index_t C,
std::array<ck::index_t, NDimSpatial> input_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
std::array<ck::index_t, NDimSpatial> output_spatial_lengths, const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
std::array<ck::index_t, NDimSpatial> conv_filter_strides, const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
std::array<ck::index_t, NDimSpatial> conv_filter_dilations, const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
std::array<ck::index_t, NDimSpatial> input_left_pads, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
std::array<ck::index_t, NDimSpatial> input_right_pads, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
...@@ -1179,6 +1187,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -1179,6 +1187,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
input_spatial_lengths, input_spatial_lengths,
filter_spatial_lengths, filter_spatial_lengths,
output_spatial_lengths, output_spatial_lengths,
input_strides,
output_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
......
...@@ -91,6 +91,42 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( ...@@ -91,6 +91,42 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
PassThrough, PassThrough,
PassThrough>>>& instances); PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
F32,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F32,
F32,
F32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
// conv3d backward weight // conv3d backward weight
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
...@@ -162,66 +198,103 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -162,66 +198,103 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, GNWC> && if constexpr(NumDimSpatial == 1)
is_same_v<WeiLayout, GKXC> && is_same_v<OutLayout, GNWK>)
{ {
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && if constexpr(is_same_v<InLayout, GNWC> && is_same_v<WeiLayout, GKXC> &&
is_same_v<OutDataType, float>) is_same_v<OutLayout, GNWK>)
{
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{ {
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
op_ptrs); is_same_v<OutDataType, float>)
{
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
op_ptrs);
}
} }
} }
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> && else if constexpr(NumDimSpatial == 2)
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
{ {
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutDataType, float>) is_same_v<OutLayout, GNHWK>)
{
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{ {
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
op_ptrs);
}
} }
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> && else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutDataType, ck::bhalf_t>) is_same_v<OutLayout, NHWGK>)
{ {
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
op_ptrs); is_same_v<OutDataType, float>)
{
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
op_ptrs);
}
} }
} }
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> && else if constexpr(NumDimSpatial == 3)
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
{ {
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> && if(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutDataType, float>) is_same_v<OutLayout, GNDHWK>)
{
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{ {
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
op_ptrs); is_same_v<OutDataType, float>)
{
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
op_ptrs);
}
} }
} }
......
...@@ -2,5 +2,8 @@ add_instance_library(device_grouped_conv2d_bwd_weight_instance ...@@ -2,5 +2,8 @@ add_instance_library(device_grouped_conv2d_bwd_weight_instance
device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp
device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment