Commit 70a814f1 (unverified), authored by Bartłomiej Kocot, committed by GitHub
Browse files

Refactor transform conv to gemm fwd (#1391)

* Refactor transform conv to gemm fwd

* fixes codegen

* wmma fixes

* fix wmma

* Fix copyright
parent ab250afd
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
......@@ -95,16 +98,27 @@ auto transform_conv(ck::index_t num_dim,
ck::Array<ck::index_t, 5> out_lengths,
ck::Array<ck::index_t, 5> out_strides)
{
ck::Array<ck::index_t, 5> dummy_dims;
ck::Array<ck::index_t, 2> dummy_spatial_dims;
if(num_dim == 2 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
{
ck::tensor_operation::TransformConvFwdToGemm<
2,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
conv_fwd;
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
return res.transform_func(conv_fwd);
}
if(num_dim == 2 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
......@@ -112,10 +126,19 @@ auto transform_conv(ck::index_t num_dim,
ck::tensor_operation::TransformConvFwdToGemm<
2,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
conv_fwd;
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
return res.transform_func(conv_fwd);
}
if(num_dim == 2 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
......@@ -123,20 +146,38 @@ auto transform_conv(ck::index_t num_dim,
ck::tensor_operation::TransformConvFwdToGemm<
2,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
conv_fwd;
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
return res.transform_func(conv_fwd);
}
if(num_dim == 2 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
{
ck::tensor_operation::TransformConvFwdToGemm<
2,
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
conv_fwd;
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
return res.transform_func(conv_fwd);
}
throw std::runtime_error("Incorrect conv spec");
}
......@@ -146,16 +187,28 @@ auto transform_conv_3d(ck::index_t num_dim,
ck::Array<ck::index_t, 6> out_lengths,
ck::Array<ck::index_t, 6> out_strides)
{
ck::Array<ck::index_t, 6> dummy_dims;
ck::Array<ck::index_t, 3> dummy_spatial_dims;
if(num_dim == 3 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
{
ck::tensor_operation::TransformConvFwdToGemm<
3,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
conv_fwd;
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
return res.transform_func(conv_fwd);
}
if(num_dim == 3 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
......@@ -163,10 +216,19 @@ auto transform_conv_3d(ck::index_t num_dim,
ck::tensor_operation::TransformConvFwdToGemm<
3,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
conv_fwd;
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
return res.transform_func(conv_fwd);
}
if(num_dim == 3 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
......@@ -174,20 +236,38 @@ auto transform_conv_3d(ck::index_t num_dim,
ck::tensor_operation::TransformConvFwdToGemm<
3,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
conv_fwd;
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
return res.transform_func(conv_fwd);
}
if(num_dim == 3 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
{
ck::tensor_operation::TransformConvFwdToGemm<
3,
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
conv_fwd;
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
return res.transform_func(conv_fwd);
}
throw std::runtime_error("Incorrect conv spec");
}
......@@ -197,16 +277,28 @@ auto transform_conv_1d(ck::index_t num_dim,
ck::Array<ck::index_t, 4> out_lengths,
ck::Array<ck::index_t, 4> out_strides)
{
ck::Array<ck::index_t, 4> dummy_dims;
ck::Array<ck::index_t, 1> dummy_spatial_dims;
if(num_dim == 1 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
{
ck::tensor_operation::TransformConvFwdToGemm<
1,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
conv_fwd;
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
return res.transform_func(conv_fwd);
}
if(num_dim == 1 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
......@@ -214,10 +306,19 @@ auto transform_conv_1d(ck::index_t num_dim,
ck::tensor_operation::TransformConvFwdToGemm<
1,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
conv_fwd;
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
return res.transform_func(conv_fwd);
}
if(num_dim == 1 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
......@@ -225,20 +326,38 @@ auto transform_conv_1d(ck::index_t num_dim,
ck::tensor_operation::TransformConvFwdToGemm<
1,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
conv_fwd;
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
return res.transform_func(conv_fwd);
}
if(num_dim == 1 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
{
ck::tensor_operation::TransformConvFwdToGemm<
1,
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
conv_fwd;
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd);
return res.transform_func(conv_fwd);
}
throw std::runtime_error("Incorrect dims or conv spec");
}
......
......@@ -359,36 +359,17 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
__host__ __device__ static auto
MakeAGridDescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
const ck::Array<index_t, NDimSpatial>& input_left_pads,
const ck::Array<index_t, NDimSpatial>& input_right_pads)
MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -398,12 +379,10 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template <typename BLay>
__host__ __device__ static auto
MakeBGridDescriptor_N_K(const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides);
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
......@@ -413,12 +392,10 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template <typename ELay>
__host__ __device__ static auto
MakeEGridDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides);
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......@@ -428,26 +405,27 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// Shape of Ds and E must be aligned. Strides can be different.
// Pass e_g_n_k_wos_lengths for logical broadcast.
//
// Returns the result of generate_tuple(...) invoked with Number<NumDTensor>{}: for each
// index i the lambda looks up the i-th layout in DsLayout and forwards it as the template
// argument of DeviceOp::MakeEGridDescriptor_M_N.
//
// NOTE(review): this span carries two signature lines and two return statements (an
// array-based form and a transformer-based form) — confirm which one is live.
// NOTE(review): the transformer-based signature line has no __host__ __device__ qualifiers,
// unlike the array-based line and the sibling Make{A,B,E}GridDescriptor helpers in this
// struct — confirm that this is intended.
// NOTE(review): in the transformer-based form the same transformer object is handed to every
// D layout, so per-D strides are not distinguished here.
__host__ __device__ static auto MakeDsGridDescriptor_M_N(
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
return generate_tuple(
[&](auto i) {
// Layout type of the i-th D tensor, taken from the DsLayout tuple.
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(e_g_n_k_wos_lengths,
ds_g_n_k_wos_strides[i]);
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
},
Number<NumDTensor>{});
}
// desc for problem definition
// Each alias below is the cv/ref-stripped return type of the matching Make*GridDescriptor
// helper, obtained through decltype(...).
// NOTE(review): both the brace-argument forms ({}, {}, ...) and the
// dummy_conv_to_gemm_transformer forms of every alias are present in this span — confirm
// which set is live.
using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
// Transformer declared without an initializer; in the aliases that follow it is used only as
// a decltype operand, i.e. to name the descriptor types.
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_M_K =
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using BGridDesc_N_K =
remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>(dummy_conv_to_gemm_transformer))>;
using DsGridDesc_M_N =
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))>;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
// If we are using multiAB and one of the template datatype parameters is not a tuple, convert
// it to one
......@@ -533,21 +511,23 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides)},
conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads},
a_grid_desc_m_k_{
DeviceOp::MakeAGridDescriptor_M_K<ALayout>(conv_to_gemm_transformer_)},
b_grid_desc_n_k_{
DeviceOp::MakeBGridDescriptor_N_K<BLayout>(conv_to_gemm_transformer_)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
e_grid_desc_m_n_{
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
......@@ -637,9 +617,20 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// D batch stride
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
ds_g_n_k_wos_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
// D desc
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]);
ds_grid_desc_m_n_(i) =
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
});
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
......@@ -694,6 +685,9 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// tensor descriptors for problem definition
index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
......
......@@ -8,7 +8,6 @@
#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
......@@ -65,8 +64,8 @@ struct DeviceColumnToImageImpl
static constexpr auto spatial_offset = Number<3>{};
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>{};
using GemmToConvFwdTransformer =
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpecialization::MKPadding, index_t, index_t, index_t>{
MPerBlock, 0 /* NPerBlock*/, KPerBlock};
......@@ -234,21 +233,21 @@ struct DeviceColumnToImageImpl
: independent_filter_stride;
}
GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
image_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
{}, // not needed for A Descriptor
c_g_n_k_wos_lengths,
{}, // not needed for A Descriptor
// conv_filter_strides,
independent_filter_strides,
conv_filter_dilations,
input_left_pads_with_offset,
input_right_pads};
// Calculate image form descriptor for the modified convolution problem
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>(
a_g_n_c_wis_lengths,
image_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
{}, // not needed for A Descriptor
c_g_n_k_wos_lengths,
{}, // not needed for A Descriptor
// conv_filter_strides,
independent_filter_strides,
conv_filter_dilations,
input_left_pads_with_offset,
input_right_pads,
N);
conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......
......@@ -238,37 +238,17 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
template <typename ALay>
static auto
MakeAGridDescriptor_AK0_M_AK1(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_g_n_c_wis_lengths[I1]);
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -286,12 +266,10 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
template <typename BLay>
static auto
MakeBGridDescriptor_BK0_N_BK1(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides);
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
......@@ -309,13 +287,10 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
}
template <typename ELay>
static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......@@ -323,27 +298,27 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
return out_gemmm_gemmn_desc;
}
// Returns the result of generate_tuple(...) invoked with Number<NumDTensor>{}: for each
// index i the lambda looks up the i-th layout in DsLayout and forwards it as the template
// argument of DeviceOp::MakeEGridDescriptor_M_N.
//
// NOTE(review): this span carries two signature lines and two return statements (a
// per-D lengths/strides array form and a transformer-based form) — confirm which one is live.
// NOTE(review): in the transformer-based form the same transformer object is handed to every
// D layout, whereas the array-based form indexes lengths and strides per D tensor.
static auto MakeDsGridDescriptor_M_N(
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
return generate_tuple(
[&](auto i) {
// Layout type of the i-th D tensor, taken from the DsLayout tuple.
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(ds_g_n_k_wos_lengths[i],
ds_g_n_k_wos_strides[i]);
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
},
Number<NumDTensor>{});
}
// desc for problem definition
// Transformer declared without an initializer; in the aliases that follow it is used only as
// a decltype operand, i.e. to name the descriptor types.
// NOTE(review): continuation lines of both the brace-argument forms ({}, {}, ...) and the
// dummy_conv_to_gemm_transformer forms are interleaved in this span — confirm which set is
// live.
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
dummy_conv_to_gemm_transformer))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
dummy_conv_to_gemm_transformer))>;
using DsGridDesc_M_N =
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))>;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
// GridwiseGemm
using GridwiseGemm =
......@@ -426,21 +401,22 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]},
conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
DeviceOp::MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer_)},
b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>(conv_to_gemm_transformer_)},
e_grid_desc_m_n_{
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_)},
a_grid_desc_k0_m0_m1_k1_{},
b_grid_desc_k0_n0_n1_k1_{},
ds_grid_desc_m0_m10_m11_n0_n10_n11_{},
......@@ -471,6 +447,17 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths[i],
ds_g_n_k_wos_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
......@@ -478,8 +465,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
// D desc
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]);
ds_grid_desc_m_n_(i) =
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
});
// populate desc for Ds/E
......@@ -523,6 +510,9 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// tensor descriptors for problem definition
index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
DsGridDesc_M_N ds_grid_desc_m_n_;
......
......@@ -234,37 +234,17 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
template <typename ALay>
static auto
MakeAGridDescriptor_AK0_M_AK1(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
c_g_n_k_wos_lengths,
c_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_g_n_c_wis_lengths[I1]);
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -283,12 +263,10 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
template <typename BLay>
static auto
MakeBGridDescriptor_BK0_N_BK1(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides);
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
......@@ -306,13 +284,10 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
}
template <typename CLay>
static auto
MakeCGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
static auto MakeCGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>(
c_g_n_k_wos_lengths, c_g_n_k_wos_strides, c_g_n_k_wos_lengths[I1]);
conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......@@ -321,11 +296,13 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
}
// desc for problem definition
// Transformer declared without an initializer; in the aliases that follow it is used only as
// a decltype operand, i.e. to name the descriptor types.
// NOTE(review): continuation lines of both the brace-argument forms ({}, {}, ...) and the
// dummy_conv_to_gemm_transformer forms are interleaved in this span — confirm which set is
// live.
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N<CLayout>({}, {}))>;
dummy_conv_to_gemm_transformer))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
dummy_conv_to_gemm_transformer))>;
using CGridDesc_M_N =
remove_cvref_t<decltype(MakeCGridDescriptor_M_N<CLayout>(dummy_conv_to_gemm_transformer))>;
// GridwiseGemm
using GridwiseGemm =
......@@ -396,21 +373,22 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
p_b_grid_{static_cast<const BDataType*>(p_b)},
p_c_grid_{static_cast<CDataType*>(p_c)},
num_group_{a_g_n_c_wis_lengths[0]},
conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
c_g_n_k_wos_lengths,
c_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N<CLayout>(c_g_n_k_wos_lengths,
c_g_n_k_wos_strides)},
DeviceOp::MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer_)},
b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>(conv_to_gemm_transformer_)},
c_grid_desc_m_n_{
DeviceOp::MakeCGridDescriptor_M_N<CLayout>(conv_to_gemm_transformer_)},
a_grid_desc_k0_m0_m1_k1_{},
b_grid_desc_k0_n0_n1_k1_{},
c_grid_desc_m0_m10_m11_n0_n10_n11_{},
......@@ -473,6 +451,9 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
// tensor descriptors for problem definition
index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
......
......@@ -316,38 +316,21 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization, NumGroupsToMerge>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization,
true /*SplitN*/,
ALayout,
ELayout,
NumGroupsToMerge>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t Conv_N)
static auto MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
Conv_N);
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -356,13 +339,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
template <typename BLay>
static auto
MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
static auto MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides);
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
......@@ -371,14 +351,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
template <typename ELay>
static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const index_t Conv_N)
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N);
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......@@ -388,27 +364,27 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// Shape of Ds and E must be aligned. Strides can be different.
// Pass e_g_n_k_wos_lengths for logical broadcast.
static auto MakeDsGridDescriptor_M_N(
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const index_t Conv_N)
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], Conv_N);
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
},
Number<NumDTensor>{});
}
// desc for problem definition
using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>;
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, 1))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}, 1))>;
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_M_K =
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using BGridDesc_N_K =
remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>(dummy_conv_to_gemm_transformer))>;
using DsGridDesc_M_N =
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))>;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
// If we are using multiAB and one of the template datatype parameters is not a tuple, convert
// it to one
......@@ -496,28 +472,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]},
conv_N_per_block_{
conv_to_gemm_transformer.template GetSplitedNSize<ADataType, EDataType>(
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
conv_N_per_block_)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides)},
conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads},
conv_N_per_block_{conv_to_gemm_transformer_.N_},
a_grid_desc_m_k_{
DeviceOp::MakeAGridDescriptor_M_K<ALayout>(conv_to_gemm_transformer_)},
b_grid_desc_n_k_{
DeviceOp::MakeBGridDescriptor_N_K<BLayout>(conv_to_gemm_transformer_)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)},
e_grid_desc_m_n_{
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
......@@ -623,9 +595,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
ds_g_n_k_wos_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
// D desc
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], conv_N_per_block_);
ds_grid_desc_m_n_(i) =
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
});
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
......@@ -690,6 +673,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// tensor descriptors for problem definition
index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
index_t conv_N_per_block_;
AGridDesc_M_K a_grid_desc_m_k_;
......
......@@ -293,39 +293,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization,
true /*SplitN*/,
ADataType,
EDataType>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
static auto
MakeAGridDescriptor_AK0_M_AK1(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t Conv_N)
MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
Conv_N);
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -344,12 +327,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
template <typename BLay>
static auto
MakeBGridDescriptor_BK0_N_BK1(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides);
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
......@@ -367,15 +348,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
template <typename ELay>
static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const index_t Conv_N)
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N);
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......@@ -384,7 +361,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
// desc for problem definition
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}, 1))>;
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
#define GridwiseGemmV3TemplateParams \
tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \
......@@ -417,9 +396,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
dummy_conv_to_gemm_transformer))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
dummy_conv_to_gemm_transformer))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
......@@ -450,27 +429,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
p_b_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]},
conv_N_per_block_{
conv_to_gemm_transformer.template GetSplitedNSize<ADataType, EDataType>(
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
a_grid_desc_ak0_m_ak1_{MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
conv_N_per_block_)},
conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads},
conv_N_per_block_{conv_to_gemm_transformer_.N_},
a_grid_desc_ak0_m_ak1_{
MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer_)},
b_grid_desc_bk0_n_bk1_{
MakeBGridDescriptor_BK0_N_BK1<BLayout>(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)},
MakeBGridDescriptor_BK0_N_BK1<BLayout>(conv_to_gemm_transformer_)},
e_grid_desc_m_n_{
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_)},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
compute_ptr_offset_of_groups_{},
compute_ptr_offset_of_n_{},
......@@ -519,6 +494,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// tensor descriptors for problem definition
index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
index_t conv_N_per_block_;
// tensor descriptors for block/thread-wise copy
......
......@@ -309,37 +309,16 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
static auto MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_g_n_c_wis_lengths[I1]);
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -348,13 +327,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
}
template <typename BLay>
static auto
MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
static auto MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides);
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
......@@ -363,13 +339,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
}
template <typename ELay>
static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......@@ -447,11 +420,14 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo);
}
using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<DELayout>({}, {}))>;
using RGridDesc_M = remove_cvref_t<decltype(MakeRGridDescriptor_M<RLayout>({}, {}))>;
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_M_K =
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using BGridDesc_N_K =
remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>(dummy_conv_to_gemm_transformer))>;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<DELayout>(dummy_conv_to_gemm_transformer))>;
using RGridDesc_M = remove_cvref_t<decltype(MakeRGridDescriptor_M<RLayout>({}, {}))>;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
......@@ -551,21 +527,23 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
p_rs_grid_{}, // FIXME
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides)},
conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads},
a_grid_desc_m_k_{
DeviceOp::MakeAGridDescriptor_M_K<ALayout>(conv_to_gemm_transformer_)},
b_grid_desc_n_k_{
DeviceOp::MakeBGridDescriptor_N_K<BLayout>(conv_to_gemm_transformer_)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<DELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
e_grid_desc_m_n_{
DeviceOp::MakeEGridDescriptor_M_N<DELayout>(conv_to_gemm_transformer_)},
r_grid_desc_m_{
DeviceOp::MakeRGridDescriptor_M<RLayout>(r_g_n_wos_lengths, r_g_n_wos_strides)},
a_grid_desc_ak0_m_ak1_{
......@@ -621,9 +599,20 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
// D batch stride
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths[i],
ds_g_n_k_wos_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
// D desc
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DELayout>(
ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]);
ds_grid_desc_m_n_(i) =
DeviceOp::MakeEGridDescriptor_M_N<DELayout>(conv_to_gemm_transformer_d);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
......@@ -660,6 +649,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
EDataType* p_e_grid_;
typename GridwiseGemm::RsGridPointer p_rs_grid_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
// tensor descriptors for problem definiton
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
......
......@@ -135,36 +135,16 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
static constexpr auto BEnableLds =
BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1);
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
static auto MakeAGridDescriptor(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
static auto MakeAGridDescriptor(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_g_n_c_wis_lengths[I1]);
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -205,12 +185,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
template <typename BLay>
static auto MakeBGridDescriptor(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
static auto MakeBGridDescriptor(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides);
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
......@@ -251,13 +229,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
template <typename ELay>
static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......@@ -265,26 +240,27 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
return out_gemmm_gemmn_desc;
}
static auto MakeDsGridDescriptor_M_N(
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(ds_g_n_k_wos_lengths[i],
ds_g_n_k_wos_strides[i]);
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
},
Number<NumDTensor>{});
}
// desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc =
decltype(DeviceOp::MakeAGridDescriptor<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}));
using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor<BLayout>({}, {}));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
decltype(DeviceOp::MakeAGridDescriptor<ALayout>(dummy_conv_to_gemm_transformer));
using BGridDesc =
decltype(DeviceOp::MakeBGridDescriptor<BLayout>(dummy_conv_to_gemm_transformer));
using DsGridDesc_M_N =
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))>;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
// GridwiseOp
using GridwiseOp = GridwiseGemmMultipleD_Wmma<
......@@ -373,21 +349,21 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]},
conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
a_grid_desc_{DeviceOp::MakeAGridDescriptor<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)},
b_grid_desc_{
DeviceOp::MakeBGridDescriptor<BLayout>(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
e_grid_desc_m_n_{
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_)},
a_grid_desc_{DeviceOp::MakeAGridDescriptor<ALayout>(conv_to_gemm_transformer_)},
b_grid_desc_{DeviceOp::MakeBGridDescriptor<BLayout>(conv_to_gemm_transformer_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)},
......@@ -426,8 +402,24 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
});
// D desc
ds_grid_desc_m_n_ =
DeviceOp::MakeDsGridDescriptor_M_N(ds_g_n_k_wos_lengths, ds_g_n_k_wos_strides);
ds_grid_desc_m_n_ = generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths[i],
ds_g_n_k_wos_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
},
Number<NumDTensor>{});
// populate desc for Ds/E
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
......@@ -455,6 +447,9 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// tensor descriptors for problem definition
index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
......
......@@ -57,8 +57,8 @@ struct DeviceImageToColumnImpl
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>{};
using GemmToConvFwdTransformer =
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpecialization::MKPadding, index_t, index_t, index_t>{
......@@ -97,19 +97,19 @@ struct DeviceImageToColumnImpl
b_g_k_c_xs_lengths[I2] = C;
c_g_n_k_wos_lengths[I1] = N;
GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
image_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
{}, // not needed for A Descriptor
c_g_n_k_wos_lengths,
{}, // not needed for A Descriptor
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>(
a_g_n_c_wis_lengths,
image_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
{}, // not needed for A Descriptor
c_g_n_k_wos_lengths,
{}, // not needed for A Descriptor
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
N);
conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......
......@@ -14,22 +14,15 @@
namespace ck {
namespace tensor_operation {
// Device-callable product fold: multiplies `init` by each of the `count`
// elements starting at `first` and returns the result (the multiplicative
// counterpart of std::accumulate over [first, first + count)).
template <typename T, typename ForwardIterator, typename Size>
__host__ __device__ auto mult_accumulate_n(ForwardIterator first, Size count, T init)
{
    const ForwardIterator last = first + count;
    while(first != last)
    {
        init *= *first;
        ++first;
    }
    return init;
}
template <index_t NDimSpatial,
device::ConvolutionForwardSpecialization ConvForwardSpecialization,
bool SplitN = false,
typename ADataType = float,
typename CDataType = float,
index_t NumGroupsToMerge = 1>
struct TransformConvFwdToGemm
{
private:
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
......@@ -37,10 +30,10 @@ struct TransformConvFwdToGemm
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static long_index_t
calculate_element_space_size_impl(const std::array<index_t, NDimSpatial + 3>& lengths,
const std::array<index_t, NDimSpatial + 3>& strides,
index_t i)
template <typename ConvDimsType>
static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
const ConvDimsType& strides,
index_t i)
{
long_index_t acc = 1;
for(; i < (NDimSpatial + 3); i++)
......@@ -52,11 +45,11 @@ struct TransformConvFwdToGemm
return acc;
}
template <typename ADataType, typename CDataType>
static index_t GetSplitedNSize(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
template <typename ConvDimsType>
static index_t GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
const ConvDimsType& a_g_n_c_wis_strides,
const ConvDimsType& c_g_n_k_wos_lengths,
const ConvDimsType& c_g_n_k_wos_strides)
{
const long_index_t a_element_space_size =
calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
......@@ -102,6 +95,216 @@ struct TransformConvFwdToGemm
}
}
public:
// Default constructor: empty body and no member-initializer list of its own.
// NOTE(review): device ops appear to default-construct a transformer only to name
// descriptor types through decltype -- confirm that no caller reads the members of
// a default-constructed instance.
__host__ __device__ constexpr TransformConvFwdToGemm() {}
// 1D-convolution constructor. The enable_if on NDim (which defaults to NDimSpatial)
// makes this overload available only when NDim == 1.
//
// Caches problem sizes, tensor strides and convolution parameters taken from the
// argument arrays. Quantities along D and H, which have no 1D counterpart, are set to
// neutral values: lengths, strides, conv strides and dilations to 1, paddings to 0.
//
// Index convention used on the length/stride arrays below:
//   [I0] = G, [I1] = N (A, C) or K (B), [I2] = C (A, B) or K (C), [I3] = the single
//   spatial dimension (Wi / X / Wo).
// Spatial parameter arrays (strides, dilations, pads) are indexed with [I0] only.
//
// ConvDimsType must be std::array or ck::Array of index_t with NDimSpatial + 3 elements;
// ConvSpatialDimsType must be std::array or ck::Array of index_t with NDimSpatial
// elements (both enforced by the static_asserts in the body).
template <typename ConvDimsType,
typename ConvSpatialDimsType,
index_t NDim = NDimSpatial,
typename std::enable_if<NDim == 1, bool>::type = false>
__host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
const ConvDimsType& a_g_n_c_wis_strides,
const ConvDimsType& b_g_k_c_xs_lengths,
const ConvDimsType& b_g_k_c_xs_strides,
const ConvDimsType& c_g_n_k_wos_lengths,
const ConvDimsType& c_g_n_k_wos_strides,
const ConvSpatialDimsType& conv_filter_strides,
const ConvSpatialDimsType& conv_filter_dilations,
const ConvSpatialDimsType& input_left_pads,
const ConvSpatialDimsType& input_right_pads)
// Input, output and filter spatial lengths (only W / Wo / X are real in 1D).
: Di_{I1},
Hi_{I1},
Wi_{a_g_n_c_wis_lengths[I3]},
Do_{I1},
Ho_{I1},
Wo_{c_g_n_k_wos_lengths[I3]},
Z_{I1},
Y_{I1},
X_{b_g_k_c_xs_lengths[I3]},
// Channel counts: K from the output tensor, C from the weight tensor.
K_{c_g_n_k_wos_lengths[I2]},
C_{b_g_k_c_xs_lengths[I2]},
// Spatial strides of A (input), C (output) and B (weight).
DiStride_{I1},
HiStride_{I1},
WiStride_{a_g_n_c_wis_strides[I3]},
WoStride_{c_g_n_k_wos_strides[I3]},
XStride_{b_g_k_c_xs_strides[I3]},
// Channel, batch and group strides per tensor.
CStrideTensorA_{a_g_n_c_wis_strides[I2]},
CStrideTensorB_{b_g_k_c_xs_strides[I2]},
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
KStrideTensorC_{c_g_n_k_wos_strides[I2]},
NStrideTensorA_{a_g_n_c_wis_strides[I1]},
GStrideTensorA_{a_g_n_c_wis_strides[I0]},
GStrideTensorB_{b_g_k_c_xs_strides[I0]},
GStrideTensorC_{c_g_n_k_wos_strides[I0]},
// Convolution stride / dilation / padding; only the W entries come from the caller.
ConvStrideD_{I1},
ConvStrideH_{I1},
ConvStrideW_{conv_filter_strides[I0]},
ConvDilationD_{I1},
ConvDilationH_{I1},
ConvDilationW_{conv_filter_dilations[I0]},
InLeftPadD_{I0},
InLeftPadH_{I0},
InLeftPadW_{input_left_pads[I0]},
InRightPadD_{I0},
InRightPadH_{I0},
InRightPadW_{input_right_pads[I0]},
// NOTE(review): this reads the member X_, so X_ must be declared before ZYX_ in the
// class for the value to be initialized -- member declarations are not visible here; confirm.
ZYX_{X_}
{
// Accept only std::array / ck::Array of index_t with the expected element counts.
static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
// With SplitN, N_ is whatever GetSplitedNSize returns for the A and C
// lengths/strides; otherwise it is the N length of the output tensor.
if constexpr(SplitN)
{
N_ = GetSplitedNSize(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
}
else
{
N_ = c_g_n_k_wos_lengths[I1];
}
// Do_ and Ho_ are 1 in this overload, so the product reduces to N_ * Wo_.
NDoHoWo_ = N_ * Wo_;
}
// Constructor overload for two spatial dimensions: the enable_if below only admits it when
// NDim (defaulted to NDimSpatial) equals 2.
//
// It copies the convolution problem description into members (N_, Hi_, Wi_, C_, the strides,
// the paddings, ...), so the parameterless MakeADescriptor_M_K() can read them from the object.
//
// Indices read from the arguments:
//   a_g_n_c_wis_lengths / _strides : [I3] -> H, [I4] -> W; strides [I0] -> G, [I1] -> N, [I2] -> C
//   b_g_k_c_xs_lengths  / _strides : [I2] -> C, [I3] -> Y, [I4] -> X; strides [I0] -> G, [I1] -> K
//   c_g_n_k_wos_lengths / _strides : [I1] -> N, [I2] -> K, [I3] -> Ho, [I4] -> Wo; strides [I0] -> G
//   conv_filter_strides, conv_filter_dilations, input_left_pads, input_right_pads :
//                                    [I0] -> H, [I1] -> W
//
// ConvDimsType must be std::array or ck::Array of index_t with NDimSpatial + 3 elements and
// ConvSpatialDimsType the same with NDimSpatial elements (enforced by the static_asserts).
template <typename ConvDimsType,
typename ConvSpatialDimsType,
index_t NDim = NDimSpatial,
typename std::enable_if<NDim == 2, bool>::type = false>
__host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
const ConvDimsType& a_g_n_c_wis_strides,
const ConvDimsType& b_g_k_c_xs_lengths,
const ConvDimsType& b_g_k_c_xs_strides,
const ConvDimsType& c_g_n_k_wos_lengths,
const ConvDimsType& c_g_n_k_wos_strides,
const ConvSpatialDimsType& conv_filter_strides,
const ConvSpatialDimsType& conv_filter_dilations,
const ConvSpatialDimsType& input_left_pads,
const ConvSpatialDimsType& input_right_pads)
// There is no depth axis in 2D: the D/Z lengths, stride and dilation are fixed to 1 and the
// D paddings to 0.
: Di_{I1},
Hi_{a_g_n_c_wis_lengths[I3]},
Wi_{a_g_n_c_wis_lengths[I4]},
Do_{I1},
Ho_{c_g_n_k_wos_lengths[I3]},
Wo_{c_g_n_k_wos_lengths[I4]},
Z_{I1},
Y_{b_g_k_c_xs_lengths[I3]},
X_{b_g_k_c_xs_lengths[I4]},
// K is taken from the output lengths and C from the weight lengths;
// a_g_n_c_wis_lengths[I2] is not read here.
K_{c_g_n_k_wos_lengths[I2]},
C_{b_g_k_c_xs_lengths[I2]},
DiStride_{I1},
HiStride_{a_g_n_c_wis_strides[I3]},
WiStride_{a_g_n_c_wis_strides[I4]},
// For the output and weight tensors only the last spatial stride (index I4) is stored.
WoStride_{c_g_n_k_wos_strides[I4]},
XStride_{b_g_k_c_xs_strides[I4]},
CStrideTensorA_{a_g_n_c_wis_strides[I2]},
CStrideTensorB_{b_g_k_c_xs_strides[I2]},
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
KStrideTensorC_{c_g_n_k_wos_strides[I2]},
NStrideTensorA_{a_g_n_c_wis_strides[I1]},
GStrideTensorA_{a_g_n_c_wis_strides[I0]},
GStrideTensorB_{b_g_k_c_xs_strides[I0]},
GStrideTensorC_{c_g_n_k_wos_strides[I0]},
ConvStrideD_{I1},
ConvStrideH_{conv_filter_strides[I0]},
ConvStrideW_{conv_filter_strides[I1]},
ConvDilationD_{I1},
ConvDilationH_{conv_filter_dilations[I0]},
ConvDilationW_{conv_filter_dilations[I1]},
InLeftPadD_{I0},
InLeftPadH_{input_left_pads[I0]},
InLeftPadW_{input_left_pads[I1]},
InRightPadD_{I0},
InRightPadH_{input_right_pads[I0]},
InRightPadW_{input_right_pads[I1]},
// Filter window size. This reads Y_ and X_, so it relies on both being initialized first.
// NOTE(review): the member declaration order is not visible here -- confirm Y_ and X_ are
// declared before ZYX_.
ZYX_{Y_ * X_}
{
// Reject any container type other than std::array / ck::Array of the expected length.
static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
// N_ is assigned in the body rather than in the initializer list because its source depends
// on the compile-time SplitN flag.
if constexpr(SplitN)
{
// NOTE(review): presumably returns a reduced per-launch batch size -- confirm against
// GetSplitedNSize, which is defined outside this view.
N_ = GetSplitedNSize(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
}
else
{
// Without splitting, N is the batch length of the output tensor.
N_ = c_g_n_k_wos_lengths[I1];
}
// Product of N and the output spatial lengths; the Filter1x1Stride1Pad0 branch of
// MakeADescriptor_M_K() uses it as the first descriptor length.
NDoHoWo_ = N_ * Ho_ * Wo_;
}
// Constructor overload for three spatial dimensions: the enable_if below only admits it when
// NDim (defaulted to NDimSpatial) equals 3.
//
// It copies the convolution problem description into members (N_, Di_, Hi_, Wi_, C_, the
// strides, the paddings, ...), so the parameterless MakeADescriptor_M_K() can read them from
// the object.
//
// Indices read from the arguments:
//   a_g_n_c_wis_lengths / _strides : [I3] -> D, [I4] -> H, [I5] -> W;
//                                    strides [I0] -> G, [I1] -> N, [I2] -> C
//   b_g_k_c_xs_lengths  / _strides : [I2] -> C, [I3] -> Z, [I4] -> Y, [I5] -> X;
//                                    strides [I0] -> G, [I1] -> K
//   c_g_n_k_wos_lengths / _strides : [I1] -> N, [I2] -> K, [I3] -> Do, [I4] -> Ho, [I5] -> Wo;
//                                    strides [I0] -> G
//   conv_filter_strides, conv_filter_dilations, input_left_pads, input_right_pads :
//                                    [I0] -> D, [I1] -> H, [I2] -> W
//
// ConvDimsType must be std::array or ck::Array of index_t with NDimSpatial + 3 elements and
// ConvSpatialDimsType the same with NDimSpatial elements (enforced by the static_asserts).
template <typename ConvDimsType,
typename ConvSpatialDimsType,
index_t NDim = NDimSpatial,
typename std::enable_if<NDim == 3, bool>::type = false>
__host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
const ConvDimsType& a_g_n_c_wis_strides,
const ConvDimsType& b_g_k_c_xs_lengths,
const ConvDimsType& b_g_k_c_xs_strides,
const ConvDimsType& c_g_n_k_wos_lengths,
const ConvDimsType& c_g_n_k_wos_strides,
const ConvSpatialDimsType& conv_filter_strides,
const ConvSpatialDimsType& conv_filter_dilations,
const ConvSpatialDimsType& input_left_pads,
const ConvSpatialDimsType& input_right_pads)
: Di_{a_g_n_c_wis_lengths[I3]},
Hi_{a_g_n_c_wis_lengths[I4]},
Wi_{a_g_n_c_wis_lengths[I5]},
Do_{c_g_n_k_wos_lengths[I3]},
Ho_{c_g_n_k_wos_lengths[I4]},
Wo_{c_g_n_k_wos_lengths[I5]},
Z_{b_g_k_c_xs_lengths[I3]},
Y_{b_g_k_c_xs_lengths[I4]},
X_{b_g_k_c_xs_lengths[I5]},
// K is taken from the output lengths and C from the weight lengths;
// a_g_n_c_wis_lengths[I2] is not read here.
K_{c_g_n_k_wos_lengths[I2]},
C_{b_g_k_c_xs_lengths[I2]},
DiStride_{a_g_n_c_wis_strides[I3]},
HiStride_{a_g_n_c_wis_strides[I4]},
WiStride_{a_g_n_c_wis_strides[I5]},
// For the output and weight tensors only the last spatial stride (index I5) is stored.
WoStride_{c_g_n_k_wos_strides[I5]},
XStride_{b_g_k_c_xs_strides[I5]},
CStrideTensorA_{a_g_n_c_wis_strides[I2]},
CStrideTensorB_{b_g_k_c_xs_strides[I2]},
KStrideTensorB_{b_g_k_c_xs_strides[I1]},
KStrideTensorC_{c_g_n_k_wos_strides[I2]},
NStrideTensorA_{a_g_n_c_wis_strides[I1]},
GStrideTensorA_{a_g_n_c_wis_strides[I0]},
GStrideTensorB_{b_g_k_c_xs_strides[I0]},
GStrideTensorC_{c_g_n_k_wos_strides[I0]},
ConvStrideD_{conv_filter_strides[I0]},
ConvStrideH_{conv_filter_strides[I1]},
ConvStrideW_{conv_filter_strides[I2]},
ConvDilationD_{conv_filter_dilations[I0]},
ConvDilationH_{conv_filter_dilations[I1]},
ConvDilationW_{conv_filter_dilations[I2]},
InLeftPadD_{input_left_pads[I0]},
InLeftPadH_{input_left_pads[I1]},
InLeftPadW_{input_left_pads[I2]},
InRightPadD_{input_right_pads[I0]},
InRightPadH_{input_right_pads[I1]},
InRightPadW_{input_right_pads[I2]},
// Filter window size. This reads Z_, Y_ and X_, so it relies on all three being initialized
// first. NOTE(review): the member declaration order is not visible here -- confirm they are
// declared before ZYX_.
ZYX_{Z_ * Y_ * X_}
{
// Reject any container type other than std::array / ck::Array of the expected length.
static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
// N_ is assigned in the body rather than in the initializer list because its source depends
// on the compile-time SplitN flag.
if constexpr(SplitN)
{
// NOTE(review): presumably returns a reduced per-launch batch size -- confirm against
// GetSplitedNSize, which is defined outside this view.
N_ = GetSplitedNSize(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
}
else
{
// Without splitting, N is the batch length of the output tensor.
N_ = c_g_n_k_wos_lengths[I1];
}
// Product of N and the output spatial lengths; the Filter1x1Stride1Pad0 branch of
// MakeADescriptor_M_K() uses it as the first descriptor length.
NDoHoWo_ = N_ * Do_ * Ho_ * Wo_;
}
// TODO: implement ck::tensor_layout::convolution that describes packed/strided dimensions as
// properties
template <typename ALayout,
......@@ -110,53 +313,26 @@ struct TransformConvFwdToGemm
is_same_v<ALayout, tensor_layout::convolution::NWGC> ||
is_same_v<ALayout, tensor_layout::convolution::GNWC>),
bool>::type = false>
static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t N)
__host__ __device__ auto MakeADescriptor_M_K() const
{
const index_t C = a_g_n_c_wis_lengths[I2];
const index_t Wi = a_g_n_c_wis_lengths[I3];
const index_t Wo = c_g_n_k_wos_lengths[I3];
const index_t ConvStrideW = conv_filter_strides[I0];
const index_t GStride = a_g_n_c_wis_strides[I0];
const index_t NStride = a_g_n_c_wis_strides[I1];
const auto CStride = a_g_n_c_wis_strides[I2];
const index_t WiStride = a_g_n_c_wis_strides[I3];
if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NHoWo, C),
make_tuple(WiStride, CStride));
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
make_tuple(WiStride_, CStrideTensorA_));
}
else
{
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(NHoWo, NumGroupsToMerge, C), make_tuple(WiStride, GStride, CStride));
make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
return transform_tensor_descriptor(
in_gemmm_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(NHoWo, NumGroupsToMerge)),
make_pass_through_transform(C)),
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
......@@ -164,35 +340,30 @@ struct TransformConvFwdToGemm
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter3x3)
{
const index_t ConvDilationW = conv_filter_dilations[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
if constexpr(NumGroupsToMerge == 1)
{
const auto in_n_wi_c_desc =
make_naive_tensor_descriptor(make_tuple(N, Wi), make_tuple(NStride, WiStride));
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wi_), make_tuple(NStrideTensorA_, WiStride_));
const auto in_n_wip_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
in_n_wip_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Number<3>{}, Wo),
make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Number<3>{}, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
return transform_tensor_descriptor(
in_n_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
make_pass_through_transform(Number<3>{})),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -200,28 +371,29 @@ struct TransformConvFwdToGemm
else
{
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Wi, NumGroupsToMerge), make_tuple(NStride, WiStride, GStride));
make_tuple(N_, Wi_, NumGroupsToMerge),
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_));
const auto in_n_wip_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
in_n_wip_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Number<3>{}, Wo),
make_tuple(ConvDilationW, ConvStrideW)),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Number<3>{}, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return transform_tensor_descriptor(
in_n_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo, NumGroupsToMerge)),
make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
make_pass_through_transform(Number<3>{})),
make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -233,110 +405,108 @@ struct TransformConvFwdToGemm
if constexpr(NumGroupsToMerge == 1)
{
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
make_tuple(N_, Wi_, C_),
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
const auto in_n_wo_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return transform_tensor_descriptor(
in_n_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_pass_through_transform(C)),
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
const auto in_n_wi_c_desc =
make_naive_tensor_descriptor(make_tuple(N, Wi, NumGroupsToMerge, C),
make_tuple(NStride, WiStride, GStride, CStride));
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
const auto in_n_wo_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
return transform_tensor_descriptor(
in_n_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo, NumGroupsToMerge)),
make_pass_through_transform(C)),
make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
else
{
const index_t X = b_g_k_c_xs_lengths[3];
const index_t ConvDilationW = conv_filter_dilations[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
if constexpr(NumGroupsToMerge == 1)
{
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
make_tuple(N_, Wi_, C_),
make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
const auto in_n_wip_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
in_n_wip_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(X, Wo),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return transform_tensor_descriptor(
in_n_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo)),
make_merge_transform(make_tuple(X, C))),
make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
make_merge_transform(make_tuple(X_, C_))),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
const auto in_n_wi_c_desc =
make_naive_tensor_descriptor(make_tuple(N, Wi, NumGroupsToMerge, C),
make_tuple(NStride, WiStride, GStride, CStride));
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
const auto in_n_wip_c_desc = transform_tensor_descriptor(
in_n_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
in_n_wip_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(X, Wo),
make_tuple(ConvDilationW, ConvStrideW)),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4>{}));
return transform_tensor_descriptor(
in_n_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Wo, NumGroupsToMerge)),
make_merge_transform(make_tuple(X, C))),
make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
make_merge_transform(make_tuple(X_, C_))),
make_tuple(Sequence<0, 2, 3>{}, Sequence<1, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
......@@ -349,57 +519,27 @@ struct TransformConvFwdToGemm
is_same_v<ALayout, tensor_layout::convolution::NHWGC> ||
is_same_v<ALayout, tensor_layout::convolution::GNHWC>),
bool>::type = false>
static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t N)
__host__ __device__ auto MakeADescriptor_M_K() const
{
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Hi = a_g_n_c_wis_lengths[3];
const index_t Wi = a_g_n_c_wis_lengths[4];
const index_t Ho = c_g_n_k_wos_lengths[3];
const index_t Wo = c_g_n_k_wos_lengths[4];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
const index_t GStride = a_g_n_c_wis_strides[I0];
const index_t NStride = a_g_n_c_wis_strides[I1];
const index_t CStride = a_g_n_c_wis_strides[I2];
const index_t HiStride = a_g_n_c_wis_strides[I3];
const index_t WiStride = a_g_n_c_wis_strides[I4];
if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NHoWo, C),
make_tuple(WiStride, CStride));
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
make_tuple(WiStride_, CStrideTensorA_));
}
else
{
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(NHoWo, NumGroupsToMerge, C), make_tuple(WiStride, GStride, CStride));
make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
return transform_tensor_descriptor(
in_gemmm_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(NHoWo, NumGroupsToMerge)),
make_pass_through_transform(C)),
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
......@@ -407,73 +547,65 @@ struct TransformConvFwdToGemm
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter3x3)
{
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
if constexpr(NumGroupsToMerge == 1)
{
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi), make_tuple(NStride, HiStride, WiStride));
make_tuple(N_, Hi_, Wi_), make_tuple(NStrideTensorA_, HiStride_, WiStride_));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Number<3>{}, Ho),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(Number<3>{}, Wo),
make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Number<3>{}, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(Number<3>{}, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}));
return transform_tensor_descriptor(
in_n_y_ho_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
const auto in_n_hi_wi_groups_c_desc =
make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, NumGroupsToMerge),
make_tuple(NStride, HiStride, WiStride, GStride));
const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge),
make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_));
const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
in_n_hi_wi_groups_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor(
in_n_hip_wip_groups_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Number<3>{}, Ho),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(Number<3>{}, Wo),
make_tuple(ConvDilationW, ConvStrideW)),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Number<3>{}, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(Number<3>{}, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
return transform_tensor_descriptor(
in_n_y_ho_x_wo_groups_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, NumGroupsToMerge)),
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))),
make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -485,37 +617,39 @@ struct TransformConvFwdToGemm
if constexpr(NumGroupsToMerge == 1)
{
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
make_tuple(N_, Hi_, Wi_, C_),
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
const auto in_n_ho_wo_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
return transform_tensor_descriptor(
in_n_ho_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, NumGroupsToMerge, C),
make_tuple(NStride, HiStride, WiStride, GStride, CStride));
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
make_tuple(
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
const auto in_n_ho_wo_groups_c_desc = transform_tensor_descriptor(
in_n_hi_wi_groups_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C)),
make_pass_through_transform(C_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
......@@ -523,55 +657,44 @@ struct TransformConvFwdToGemm
return transform_tensor_descriptor(
in_n_ho_wo_groups_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, NumGroupsToMerge)),
make_pass_through_transform(C)),
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
else
{
const index_t Y = b_g_k_c_xs_lengths[3];
const index_t X = b_g_k_c_xs_lengths[4];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
if constexpr(NumGroupsToMerge == 1)
{
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
make_tuple(N_, Hi_, Wi_, C_),
make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Y_, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
return transform_tensor_descriptor(
in_n_y_ho_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
make_merge_transform(make_tuple(Y, X, C))),
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
make_merge_transform(make_tuple(Y_, X_, C_))),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
......@@ -579,16 +702,17 @@ struct TransformConvFwdToGemm
{
const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, NumGroupsToMerge, C),
make_tuple(NStride, HiStride, WiStride, GStride, CStride));
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
make_tuple(
NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
in_n_hi_wi_groups_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C)),
make_pass_through_transform(C_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
......@@ -596,13 +720,13 @@ struct TransformConvFwdToGemm
const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor(
in_n_hip_wip_groups_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo),
make_tuple(ConvDilationW, ConvStrideW)),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Y_, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C)),
make_pass_through_transform(C_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
......@@ -613,8 +737,8 @@ struct TransformConvFwdToGemm
return transform_tensor_descriptor(
in_n_y_ho_x_wo_groups_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, NumGroupsToMerge)),
make_merge_transform(make_tuple(Y, X, C))),
make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
make_merge_transform(make_tuple(Y_, X_, C_))),
make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
......@@ -627,63 +751,27 @@ struct TransformConvFwdToGemm
is_same_v<ALayout, tensor_layout::convolution::NDHWGC> ||
is_same_v<ALayout, tensor_layout::convolution::GNDHWC>),
bool>::type = false>
static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides*/,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t N)
__host__ __device__ auto MakeADescriptor_M_K() const
{
const index_t C = a_g_n_c_wis_lengths[2];
const index_t Di = a_g_n_c_wis_lengths[3];
const index_t Hi = a_g_n_c_wis_lengths[4];
const index_t Wi = a_g_n_c_wis_lengths[5];
const index_t Do = c_g_n_k_wos_lengths[3];
const index_t Ho = c_g_n_k_wos_lengths[4];
const index_t Wo = c_g_n_k_wos_lengths[5];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
const index_t GStride = a_g_n_c_wis_strides[I0];
const index_t NStride = a_g_n_c_wis_strides[I1];
const index_t CStride = a_g_n_c_wis_strides[I2];
const index_t DiStride = a_g_n_c_wis_strides[I3];
const index_t HiStride = a_g_n_c_wis_strides[I4];
const index_t WiStride = a_g_n_c_wis_strides[I5];
if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
const index_t NDoHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NDoHoWo, C),
make_tuple(WiStride, CStride));
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
make_tuple(WiStride_, CStrideTensorA_));
}
else
{
const auto in_gemmm_groups_gemmk_desc =
make_naive_tensor_descriptor(make_tuple(NDoHoWo, NumGroupsToMerge, C),
make_tuple(WiStride, GStride, CStride));
const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
return transform_tensor_descriptor(
in_gemmm_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
make_pass_through_transform(C)),
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
......@@ -691,41 +779,30 @@ struct TransformConvFwdToGemm
else if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter3x3)
{
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
if constexpr(NumGroupsToMerge == 1)
{
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi), make_tuple(NStride, DiStride, HiStride, WiStride));
make_tuple(N_, Di_, Hi_, Wi_),
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Number<3>{}, Do),
make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Number<3>{}, Ho),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(Number<3>{}, Wo),
make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Number<3>{}, Do_),
make_tuple(ConvDilationD_, ConvStrideD_)),
make_embed_transform(make_tuple(Number<3>{}, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(Number<3>{}, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}));
......@@ -733,7 +810,7 @@ struct TransformConvFwdToGemm
return transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_desc,
make_tuple(
make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -741,15 +818,15 @@ struct TransformConvFwdToGemm
else
{
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, NumGroupsToMerge),
make_tuple(NStride, DiStride, HiStride, WiStride, GStride));
make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge),
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
......@@ -758,13 +835,13 @@ struct TransformConvFwdToGemm
const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Number<3>{}, Do),
make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Number<3>{}, Ho),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(Number<3>{}, Wo),
make_tuple(ConvDilationW, ConvStrideW)),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Number<3>{}, Do_),
make_tuple(ConvDilationD_, ConvStrideD_)),
make_embed_transform(make_tuple(Number<3>{}, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(Number<3>{}, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
......@@ -777,7 +854,7 @@ struct TransformConvFwdToGemm
return transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_desc,
make_tuple(
make_merge_transform(make_tuple(N, Do, Ho, Wo, NumGroupsToMerge)),
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))),
make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -789,16 +866,16 @@ struct TransformConvFwdToGemm
if constexpr(NumGroupsToMerge == 1)
{
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
make_tuple(N_, Di_, Hi_, Wi_, C_),
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)),
make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
......@@ -806,25 +883,30 @@ struct TransformConvFwdToGemm
return transform_tensor_descriptor(
in_n_do_ho_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C),
make_tuple(NStride, DiStride, HiStride, WiStride, GStride, CStride));
make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStrideTensorA_,
DiStride_,
HiStride_,
WiStride_,
GStrideTensorA_,
CStrideTensorA_));
const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)),
make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -840,43 +922,28 @@ struct TransformConvFwdToGemm
return transform_tensor_descriptor(
in_n_do_ho_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, NumGroupsToMerge)),
make_pass_through_transform(C)),
make_tuple(
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2, 3, 4>{}, Sequence<5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
else
{
const index_t Z = b_g_k_c_xs_lengths[3];
const index_t Y = b_g_k_c_xs_lengths[4];
const index_t X = b_g_k_c_xs_lengths[5];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
if constexpr(NumGroupsToMerge == 1)
{
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, C),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
make_tuple(N_, Di_, Hi_, Wi_, C_),
make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(C_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
......@@ -884,14 +951,14 @@ struct TransformConvFwdToGemm
const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do),
make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Z_, Do_),
make_tuple(ConvDilationD_, ConvStrideD_)),
make_embed_transform(make_tuple(Y_, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(C_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
......@@ -902,25 +969,30 @@ struct TransformConvFwdToGemm
return transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_merge_transform(make_tuple(Z, Y, X, C))),
make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
make_merge_transform(make_tuple(Z_, Y_, X_, C_))),
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C),
make_tuple(NStride, DiStride, HiStride, WiStride, GStride, CStride));
make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStrideTensorA_,
DiStride_,
HiStride_,
WiStride_,
GStrideTensorA_,
CStrideTensorA_));
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_tuple(make_pass_through_transform(N_),
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -936,15 +1008,15 @@ struct TransformConvFwdToGemm
const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do),
make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo),
make_tuple(ConvDilationW, ConvStrideW)),
make_tuple(make_pass_through_transform(N_),
make_embed_transform(make_tuple(Z_, Do_),
make_tuple(ConvDilationD_, ConvStrideD_)),
make_embed_transform(make_tuple(Y_, Ho_),
make_tuple(ConvDilationH_, ConvStrideH_)),
make_embed_transform(make_tuple(X_, Wo_),
make_tuple(ConvDilationW_, ConvStrideW_)),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C)),
make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -960,8 +1032,9 @@ struct TransformConvFwdToGemm
return transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_desc,
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, NumGroupsToMerge)),
make_merge_transform(make_tuple(Z, Y, X, C))),
make_tuple(
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
make_merge_transform(make_tuple(Z_, Y_, X_, C_))),
make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5, 8>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
......@@ -973,19 +1046,8 @@ struct TransformConvFwdToGemm
is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKZYXC>,
bool>::type = false>
static auto MakeBDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
__host__ __device__ auto MakeBDescriptor_N_K() const
{
const index_t K = b_g_k_c_xs_lengths[1];
const index_t C = b_g_k_c_xs_lengths[2];
const index_t YX = ck::accumulate_n<index_t>(
b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
const index_t GStride = b_g_k_c_xs_strides[I0];
const index_t KStride = b_g_k_c_xs_strides[I1];
const index_t CStride = b_g_k_c_xs_strides[I2];
if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter3x3)
{
......@@ -996,17 +1058,17 @@ struct TransformConvFwdToGemm
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor_packed(make_tuple(K, FilterSizeNumType{}));
return make_naive_tensor_descriptor_packed(make_tuple(K_, FilterSizeNumType{}));
}
else
{
const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(K, NumGroupsToMerge, FilterSizeNumType{}),
make_tuple(KStride, GStride, CStride));
make_tuple(K_, NumGroupsToMerge, FilterSizeNumType{}),
make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_));
return transform_tensor_descriptor(
wei_gemmn_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(K, NumGroupsToMerge)),
make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
make_pass_through_transform(FilterSizeNumType{})),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -1016,16 +1078,17 @@ struct TransformConvFwdToGemm
{
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor_packed(make_tuple(K, YX * C));
return make_naive_tensor_descriptor_packed(make_tuple(K_, ZYX_ * C_));
}
else
{
const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
make_tuple(K, NumGroupsToMerge, YX * C), make_tuple(KStride, GStride, CStride));
make_tuple(K_, NumGroupsToMerge, ZYX_ * C_),
make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_));
return transform_tensor_descriptor(
wei_gemmn_groups_gemmk_desc,
make_tuple(make_merge_transform(make_tuple(K, NumGroupsToMerge)),
make_pass_through_transform(YX * C)),
make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
make_pass_through_transform(ZYX_ * C_)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
......@@ -1041,25 +1104,14 @@ struct TransformConvFwdToGemm
is_same_v<BLayout, tensor_layout::convolution::KYXGC> ||
is_same_v<BLayout, tensor_layout::convolution::KZYXGC>,
bool>::type = false>
static auto MakeBDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
__host__ __device__ auto MakeBDescriptor_N_K() const
{
const index_t K = b_g_k_c_xs_lengths[1];
const index_t C = b_g_k_c_xs_lengths[2];
const index_t YX = ck::accumulate_n<index_t>(
b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
const index_t KStride = b_g_k_c_xs_strides[1];
const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial];
const auto CStride = I1;
const auto wei_k_yx_c_desc = make_naive_tensor_descriptor(
make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride));
make_tuple(K_, ZYX_, C_), make_tuple(KStrideTensorB_, XStride_, CStrideTensorB_));
const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor(
wei_k_yx_c_desc,
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))),
make_tuple(make_pass_through_transform(K_), make_merge_transform(make_tuple(ZYX_, C_))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......@@ -1071,24 +1123,14 @@ struct TransformConvFwdToGemm
is_same_v<CLayout, tensor_layout::convolution::GNHWK> ||
is_same_v<CLayout, tensor_layout::convolution::GNDHWK>,
bool>::type = false>
static auto
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
const index_t N)
__host__ __device__ auto MakeCDescriptor_M_N() const
{
const index_t K = c_g_n_k_wos_lengths[2];
const index_t NHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K));
return out_gemmm_gemmn_desc;
return make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo_, K_));
}
template <
typename CLayout,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
......@@ -1096,39 +1138,28 @@ struct TransformConvFwdToGemm
is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
bool>::type = false>
static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const index_t N)
__host__ __device__ auto MakeCDescriptor_M_N() const
{
const index_t K = c_g_n_k_wos_lengths[2];
const index_t KStride = I1;
const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2];
const index_t GStride = c_g_n_k_wos_strides[0];
const index_t NHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
if constexpr(NumGroupsToMerge == 1)
{
return make_naive_tensor_descriptor(make_tuple(NHoWo, K),
make_tuple(WoStride, KStride));
return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_),
make_tuple(WoStride_, KStrideTensorC_));
}
else
{
const auto nhwo_groups_k_1_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, NumGroupsToMerge, K, 1),
make_tuple(WoStride, GStride, KStride, GStride));
const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor(
make_tuple(NDoHoWo_, NumGroupsToMerge, K_, 1),
make_tuple(WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
// Pad 1 to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
nhwo_groups_k_1_desc,
make_tuple(make_pass_through_transform(NHoWo),
make_tuple(make_pass_through_transform(NDoHoWo_),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(K),
make_pass_through_transform(K_),
make_pad_transform(1, 0, NumGroupsToMerge - 1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// We need only matrices from diagonal. Xor returns 0 for the same
// We need only matrices from diagonal. Xor returns 0 for the same
// values. So if matrices is not on diagonal then it will be stored in padding.
// To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
......@@ -1136,16 +1167,16 @@ struct TransformConvFwdToGemm
NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_pass_through_transform(NHoWo),
make_tuple(make_pass_through_transform(NDoHoWo_),
make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
make_pass_through_transform(K)),
make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
// Merge To M, N
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NHoWo, NumGroupsToMerge)),
make_merge_transform(make_tuple(K, NumGroupsToMerge))),
make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
......@@ -1155,542 +1186,34 @@ struct TransformConvFwdToGemm
// GEMM C (output) descriptor for the G_K output layout; this overload is enabled only
// when CLayout is tensor_layout::convolution::G_K.
// NOTE(review): this span comes from a rendered diff, so the pre-refactor and
// post-refactor lines of the same definition appear back to back without +/- markers.
template <typename CLayout,
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
bool>::type = false>
// Pre-refactor form: static function taking the output lengths/strides and N as arguments.
static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const index_t N)
// Post-refactor form: const member function with no arguments; it uses NDoHoWo_, K_ and
// KStrideTensorC_ (presumably data members set up elsewhere in the struct - confirm).
__host__ __device__ auto MakeCDescriptor_M_N() const
{
// Pre-refactor locals: K and its stride come from index 2 of the output arrays.
const index_t K = c_g_n_k_wos_lengths[2];
const index_t KStride = c_g_n_k_wos_strides[2];
// N times the product of the NDimSpatial output lengths that start at index 3.
const index_t NHoWo =
N * ck::accumulate_n<index_t>(
c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
// Two-dimensional descriptor whose first-dimension stride is I0 and whose second-dimension
// stride is the K stride. NOTE(review): I0 is presumably the compile-time constant 0, which
// would make every row address the same K elements - confirm.
const auto out_gemmm_gemmn_desc =
make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, KStride));
make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_), make_tuple(I0, KStrideTensorC_));
return out_gemmm_gemmn_desc;
}
// Overloaded functions for hipRTC purposes
// Builds the two-dimensional [GemmM, GemmK] descriptor of the 1D convolution input
// (GEMM operand A) from ck::Array-based problem sizes.
//
// Values read from the arguments:
//   a_g_n_c_wis_lengths   : [1] = N, [2] = C, [3] = Wi
//   a_g_n_c_wis_strides   : [1] = N stride, [3] (or [2 + NDimSpatial]) = Wi stride
//   b_g_k_c_xs_lengths    : [3] = X (generic branch only)
//   c_g_n_k_wos_lengths   : [3] = Wo
//   conv_filter_strides, conv_filter_dilations, input_left_pads, input_right_pads : [0]
// The weight-stride and output-stride arrays are accepted but never read.
//
// The channel stride is always I1 and is never taken from a_g_n_c_wis_strides.
// NOTE(review): this presumably requires C to be the innermost input dimension - confirm.
template <typename ALayout,
          typename std::enable_if<NDimSpatial == 1 &&
                                      (is_same_v<ALayout, tensor_layout::convolution::G_NW_C> ||
                                       is_same_v<ALayout, tensor_layout::convolution::NWGC> ||
                                       is_same_v<ALayout, tensor_layout::convolution::GNWC>),
                                  bool>::type = false>
__host__ __device__ static auto
MakeADescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                    const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                    const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
                    const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
                    const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
                    const ck::Array<index_t, NDimSpatial>& input_left_pads,
                    const ck::Array<index_t, NDimSpatial>& input_right_pads)
{
    const index_t batch    = a_g_n_c_wis_lengths[1];
    const index_t channels = a_g_n_c_wis_lengths[2];

    // Used as the channel stride in every branch.
    const auto channel_stride = I1;

    if constexpr(ConvForwardSpecialization ==
                 device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
    {
        // No padding or embedding needed: address the input directly as [N * Wo, C].
        const index_t out_spatial = ck::accumulate_n<index_t>(
            c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
        const index_t gemm_m       = batch * out_spatial;
        const index_t width_stride = a_g_n_c_wis_strides[2 + NDimSpatial];

        return make_naive_tensor_descriptor(make_tuple(gemm_m, channels),
                                            make_tuple(width_stride, channel_stride));
    }
    else
    {
        const index_t in_width  = a_g_n_c_wis_lengths[3];
        const index_t out_width = c_g_n_k_wos_lengths[3];
        const index_t stride_w  = conv_filter_strides[0];

        const index_t batch_stride = a_g_n_c_wis_strides[1];
        const index_t width_stride = a_g_n_c_wis_strides[3];

        // [N, Wi, C] view of the input; both remaining specializations start from it.
        const auto input_desc = make_naive_tensor_descriptor(
            make_tuple(batch, in_width, channels),
            make_tuple(batch_stride, width_stride, channel_stride));

        if constexpr(ConvForwardSpecialization ==
                     device::ConvolutionForwardSpecialization::Filter1x1Pad0)
        {
            // Map the Wi axis to Wo with the convolution stride as coefficient: [N, Wo, C].
            const auto strided_desc = transform_tensor_descriptor(
                input_desc,
                make_tuple(make_pass_through_transform(batch),
                           make_embed_transform(make_tuple(out_width), make_tuple(stride_w)),
                           make_pass_through_transform(channels)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            // GemmM = merge(N, Wo), GemmK = C.
            return transform_tensor_descriptor(
                strided_desc,
                make_tuple(make_merge_transform(make_tuple(batch, out_width)),
                           make_pass_through_transform(channels)),
                make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            const index_t filter_w   = b_g_k_c_xs_lengths[3];
            const index_t dilation_w = conv_filter_dilations[0];
            const index_t pad_left   = input_left_pads[0];
            const index_t pad_right  = input_right_pads[0];

            // Pad the width axis on both sides: [N, Wip, C].
            const auto padded_desc = transform_tensor_descriptor(
                input_desc,
                make_tuple(make_pass_through_transform(batch),
                           make_pad_transform(in_width, pad_left, pad_right),
                           make_pass_through_transform(channels)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            // Split the padded width into (X, Wo) with (dilation, stride) coefficients:
            // [N, X, Wo, C].
            const auto windowed_desc = transform_tensor_descriptor(
                padded_desc,
                make_tuple(make_pass_through_transform(batch),
                           make_embed_transform(make_tuple(filter_w, out_width),
                                                make_tuple(dilation_w, stride_w)),
                           make_pass_through_transform(channels)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

            // GemmM = merge(N, Wo), GemmK = merge(X, C).
            return transform_tensor_descriptor(
                windowed_desc,
                make_tuple(make_merge_transform(make_tuple(batch, out_width)),
                           make_merge_transform(make_tuple(filter_w, channels))),
                make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }
}
// Builds the two-dimensional [GemmM, GemmK] descriptor of the 2D convolution input
// (GEMM operand A) from ck::Array-based problem sizes.
//
// Values read from the arguments:
//   a_g_n_c_wis_lengths   : [1] = N, [2] = C, [3] = Hi, [4] = Wi
//   a_g_n_c_wis_strides   : [1] = N stride, [3] = Hi stride, [4] (or [2 + NDimSpatial]) = Wi stride
//   b_g_k_c_xs_lengths    : [3] = Y, [4] = X (generic branch only)
//   c_g_n_k_wos_lengths   : [3] = Ho, [4] = Wo
//   conv_filter_strides, conv_filter_dilations, input_left_pads, input_right_pads :
//                           [0] = H component, [1] = W component
// The weight-stride and output-stride arrays are accepted but never read.
//
// The channel stride is always I1 and is never taken from a_g_n_c_wis_strides.
// NOTE(review): this presumably requires C to be the innermost input dimension - confirm.
template <typename ALayout,
          typename std::enable_if<
              NDimSpatial == 2 && (is_same_v<ALayout, tensor_layout::convolution::G_NHW_C> ||
                                   is_same_v<ALayout, tensor_layout::convolution::NHWGC> ||
                                   is_same_v<ALayout, tensor_layout::convolution::GNHWC>),
              bool>::type = false>
__host__ __device__ static auto
MakeADescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                    const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                    const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
                    const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
                    const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
                    const ck::Array<index_t, NDimSpatial>& input_left_pads,
                    const ck::Array<index_t, NDimSpatial>& input_right_pads)
{
    const index_t batch    = a_g_n_c_wis_lengths[1];
    const index_t channels = a_g_n_c_wis_lengths[2];

    // Used as the channel stride in every branch.
    const auto channel_stride = I1;

    if constexpr(ConvForwardSpecialization ==
                 device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
    {
        // No padding or embedding needed: address the input directly as [N * Ho * Wo, C].
        const index_t out_spatial = ck::accumulate_n<index_t>(
            c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
        const index_t gemm_m       = batch * out_spatial;
        const index_t width_stride = a_g_n_c_wis_strides[2 + NDimSpatial];

        return make_naive_tensor_descriptor(make_tuple(gemm_m, channels),
                                            make_tuple(width_stride, channel_stride));
    }
    else
    {
        const index_t in_height  = a_g_n_c_wis_lengths[3];
        const index_t in_width   = a_g_n_c_wis_lengths[4];
        const index_t out_height = c_g_n_k_wos_lengths[3];
        const index_t out_width  = c_g_n_k_wos_lengths[4];
        const index_t stride_h   = conv_filter_strides[0];
        const index_t stride_w   = conv_filter_strides[1];

        const index_t batch_stride  = a_g_n_c_wis_strides[1];
        const index_t height_stride = a_g_n_c_wis_strides[3];
        const index_t width_stride  = a_g_n_c_wis_strides[4];

        // [N, Hi, Wi, C] view of the input; both remaining specializations start from it.
        const auto input_desc = make_naive_tensor_descriptor(
            make_tuple(batch, in_height, in_width, channels),
            make_tuple(batch_stride, height_stride, width_stride, channel_stride));

        if constexpr(ConvForwardSpecialization ==
                     device::ConvolutionForwardSpecialization::Filter1x1Pad0)
        {
            // Map Hi/Wi to Ho/Wo with the convolution strides as coefficients:
            // [N, Ho, Wo, C].
            const auto strided_desc = transform_tensor_descriptor(
                input_desc,
                make_tuple(make_pass_through_transform(batch),
                           make_embed_transform(make_tuple(out_height), make_tuple(stride_h)),
                           make_embed_transform(make_tuple(out_width), make_tuple(stride_w)),
                           make_pass_through_transform(channels)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

            // GemmM = merge(N, Ho, Wo), GemmK = C.
            return transform_tensor_descriptor(
                strided_desc,
                make_tuple(make_merge_transform(make_tuple(batch, out_height, out_width)),
                           make_pass_through_transform(channels)),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            const index_t filter_h    = b_g_k_c_xs_lengths[3];
            const index_t filter_w    = b_g_k_c_xs_lengths[4];
            const index_t dilation_h  = conv_filter_dilations[0];
            const index_t dilation_w  = conv_filter_dilations[1];
            const index_t pad_left_h  = input_left_pads[0];
            const index_t pad_left_w  = input_left_pads[1];
            const index_t pad_right_h = input_right_pads[0];
            const index_t pad_right_w = input_right_pads[1];

            // Pad both spatial axes: [N, Hip, Wip, C].
            const auto padded_desc = transform_tensor_descriptor(
                input_desc,
                make_tuple(make_pass_through_transform(batch),
                           make_pad_transform(in_height, pad_left_h, pad_right_h),
                           make_pad_transform(in_width, pad_left_w, pad_right_w),
                           make_pass_through_transform(channels)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

            // Split each padded axis into (filter tap, output position) with
            // (dilation, stride) coefficients: [N, Y, Ho, X, Wo, C].
            const auto windowed_desc = transform_tensor_descriptor(
                padded_desc,
                make_tuple(make_pass_through_transform(batch),
                           make_embed_transform(make_tuple(filter_h, out_height),
                                                make_tuple(dilation_h, stride_h)),
                           make_embed_transform(make_tuple(filter_w, out_width),
                                                make_tuple(dilation_w, stride_w)),
                           make_pass_through_transform(channels)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

            // GemmM = merge(N, Ho, Wo), GemmK = merge(Y, X, C).
            return transform_tensor_descriptor(
                windowed_desc,
                make_tuple(make_merge_transform(make_tuple(batch, out_height, out_width)),
                           make_merge_transform(make_tuple(filter_h, filter_w, channels))),
                make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }
}
// Builds the GEMM A-matrix descriptor for a 3D convolution input tensor.
//
// Selected when NDimSpatial == 3 and ALayout is G_NDHW_C, NDHWGC or GNDHWC.
// In every lengths/strides array, index 1 is read as N, index 2 as C (input)
// and indices 3..5 as the three spatial extents (D/H/W for input and output,
// Z/Y/X for the weights). The weight strides and output strides are ignored.
//
// The descriptor that is returned depends on ConvForwardSpecialization:
//  - Filter1x1Stride1Pad0: a plain 2D descriptor of shape (N*Do*Ho*Wo, C) whose
//    row stride is the input W stride.
//  - Filter1x1Pad0: the input is sub-sampled with the convolution strides and
//    then N/Do/Ho/Wo are merged into the first dimension; C is passed through.
//  - anything else: the input is padded, every spatial axis is unfolded into a
//    (filter tap, output position) pair, and the result is merged into
//    (N*Do*Ho*Wo, Z*Y*X*C).
template <typename ALayout,
          typename std::enable_if<
              NDimSpatial == 3 && (is_same_v<ALayout, tensor_layout::convolution::G_NDHW_C> ||
                                   is_same_v<ALayout, tensor_layout::convolution::NDHWGC> ||
                                   is_same_v<ALayout, tensor_layout::convolution::GNDHWC>),
              bool>::type = false>
static auto
MakeADescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                    const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                    const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
                    const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
                    const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
                    const ck::Array<index_t, NDimSpatial>& input_left_pads,
                    const ck::Array<index_t, NDimSpatial>& input_right_pads)
{
    // Every branch uses I1 as the stride of the channel dimension.
    const auto unit_stride = I1;

    const index_t batch    = a_g_n_c_wis_lengths[1];
    const index_t channels = a_g_n_c_wis_lengths[2];

    if constexpr(ConvForwardSpecialization ==
                 device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
    {
        // Output positions map one-to-one onto input positions, so the input
        // can be viewed directly as a 2D (rows, C) matrix.
        const index_t out_positions = ck::accumulate_n<index_t>(
            c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
        const index_t gemm_rows = batch * out_positions;

        // This is different
        const index_t row_stride = a_g_n_c_wis_strides[2 + NDimSpatial];

        return make_naive_tensor_descriptor(make_tuple(gemm_rows, channels),
                                            make_tuple(row_stride, unit_stride));
    }
    else if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Pad0)
    {
        const index_t in_depth  = a_g_n_c_wis_lengths[3];
        const index_t in_height = a_g_n_c_wis_lengths[4];
        const index_t in_width  = a_g_n_c_wis_lengths[5];

        const index_t out_depth  = c_g_n_k_wos_lengths[3];
        const index_t out_height = c_g_n_k_wos_lengths[4];
        const index_t out_width  = c_g_n_k_wos_lengths[5];

        const index_t step_d = conv_filter_strides[0];
        const index_t step_h = conv_filter_strides[1];
        const index_t step_w = conv_filter_strides[2];

        // This is different
        const index_t stride_n = a_g_n_c_wis_strides[1];
        const index_t stride_d = a_g_n_c_wis_strides[3];
        const index_t stride_h = a_g_n_c_wis_strides[4];
        const index_t stride_w = a_g_n_c_wis_strides[5];

        const auto input_desc = make_naive_tensor_descriptor(
            make_tuple(batch, in_depth, in_height, in_width, channels),
            make_tuple(stride_n, stride_d, stride_h, stride_w, unit_stride));

        // Keep only the input positions that the strided 1x1 filter visits.
        const auto sampled_desc = transform_tensor_descriptor(
            input_desc,
            make_tuple(make_pass_through_transform(batch),
                       make_embed_transform(make_tuple(out_depth), make_tuple(step_d)),
                       make_embed_transform(make_tuple(out_height), make_tuple(step_h)),
                       make_embed_transform(make_tuple(out_width), make_tuple(step_w)),
                       make_pass_through_transform(channels)),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        // Collapse (N, Do, Ho, Wo) into the GEMM row index; C is the column index.
        return transform_tensor_descriptor(
            sampled_desc,
            make_tuple(
                make_merge_transform(make_tuple(batch, out_depth, out_height, out_width)),
                make_pass_through_transform(channels)),
            make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
    else
    {
        const index_t in_depth  = a_g_n_c_wis_lengths[3];
        const index_t in_height = a_g_n_c_wis_lengths[4];
        const index_t in_width  = a_g_n_c_wis_lengths[5];

        const index_t out_depth  = c_g_n_k_wos_lengths[3];
        const index_t out_height = c_g_n_k_wos_lengths[4];
        const index_t out_width  = c_g_n_k_wos_lengths[5];

        const index_t filter_depth  = b_g_k_c_xs_lengths[3];
        const index_t filter_height = b_g_k_c_xs_lengths[4];
        const index_t filter_width  = b_g_k_c_xs_lengths[5];

        const index_t step_d = conv_filter_strides[0];
        const index_t step_h = conv_filter_strides[1];
        const index_t step_w = conv_filter_strides[2];

        const index_t dilation_d = conv_filter_dilations[0];
        const index_t dilation_h = conv_filter_dilations[1];
        const index_t dilation_w = conv_filter_dilations[2];

        const index_t pad_front_d = input_left_pads[0];
        const index_t pad_front_h = input_left_pads[1];
        const index_t pad_front_w = input_left_pads[2];

        const index_t pad_back_d = input_right_pads[0];
        const index_t pad_back_h = input_right_pads[1];
        const index_t pad_back_w = input_right_pads[2];

        // This is different
        const index_t stride_n = a_g_n_c_wis_strides[1];
        const index_t stride_d = a_g_n_c_wis_strides[3];
        const index_t stride_h = a_g_n_c_wis_strides[4];
        const index_t stride_w = a_g_n_c_wis_strides[5];

        const auto input_desc = make_naive_tensor_descriptor(
            make_tuple(batch, in_depth, in_height, in_width, channels),
            make_tuple(stride_n, stride_d, stride_h, stride_w, unit_stride));

        // Step 1: pad the three spatial axes.
        const auto padded_desc = transform_tensor_descriptor(
            input_desc,
            make_tuple(make_pass_through_transform(batch),
                       make_pad_transform(in_depth, pad_front_d, pad_back_d),
                       make_pad_transform(in_height, pad_front_h, pad_back_h),
                       make_pad_transform(in_width, pad_front_w, pad_back_w),
                       make_pass_through_transform(channels)),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        // Step 2: split every padded spatial axis into (filter tap, output
        // position), addressed with (dilation, convolution stride).
        const auto unfolded_desc = transform_tensor_descriptor(
            padded_desc,
            make_tuple(make_pass_through_transform(batch),
                       make_embed_transform(make_tuple(filter_depth, out_depth),
                                            make_tuple(dilation_d, step_d)),
                       make_embed_transform(make_tuple(filter_height, out_height),
                                            make_tuple(dilation_h, step_h)),
                       make_embed_transform(make_tuple(filter_width, out_width),
                                            make_tuple(dilation_w, step_w)),
                       make_pass_through_transform(channels)),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{},
                       Sequence<1, 2>{},
                       Sequence<3, 4>{},
                       Sequence<5, 6>{},
                       Sequence<7>{}));

        // Step 3: merge (N, Do, Ho, Wo) into GEMM M and (Z, Y, X, C) into GEMM K.
        return transform_tensor_descriptor(
            unfolded_desc,
            make_tuple(
                make_merge_transform(make_tuple(batch, out_depth, out_height, out_width)),
                make_merge_transform(
                    make_tuple(filter_depth, filter_height, filter_width, channels))),
            make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
}
// Builds the GEMM B-matrix descriptor for weights in the GKXC, GKYXC or GKZYXC
// layouts.
//
// Index 1 of the lengths array is read as K and index 2 as C; the NDimSpatial
// entries starting at index 3 are multiplied together to get the filter volume.
// The result is a packed 2D descriptor of shape (K, filter_volume * C).
// The strides argument is not used by this overload.
template <typename BLayout,
          typename std::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKXC> ||
                                      is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
                                      is_same_v<BLayout, tensor_layout::convolution::GKZYXC>,
                                  bool>::type = false>
__host__ __device__ static auto
MakeBDescriptor_N_K(const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */)
{
    const index_t num_filters = b_g_k_c_xs_lengths[1];
    const index_t channels    = b_g_k_c_xs_lengths[2];

    // Product of the spatial filter extents.
    const index_t filter_volume =
        mult_accumulate_n<index_t>(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1);

    return make_naive_tensor_descriptor_packed(
        make_tuple(num_filters, filter_volume * channels));
}
// Builds the GEMM B-matrix descriptor for weights in the G_K_X_C, G_K_YX_C,
// G_K_ZYX_C, KXGC, KYXGC or KZYXGC layouts.
//
// The weights are first described as a 3D (K, filter_volume, C) tensor using
// the K stride (strides[1]), the stride of the last spatial dimension
// (strides[2 + NDimSpatial]) and I1 for C. The last two dimensions are then
// merged, giving a 2D descriptor of shape (K, filter_volume * C).
template <
    typename BLayout,
    typename std::enable_if<is_same_v<BLayout, tensor_layout::convolution::G_K_X_C> ||
                                is_same_v<BLayout, tensor_layout::convolution::G_K_YX_C> ||
                                is_same_v<BLayout, tensor_layout::convolution::G_K_ZYX_C> ||
                                is_same_v<BLayout, tensor_layout::convolution::KXGC> ||
                                is_same_v<BLayout, tensor_layout::convolution::KYXGC> ||
                                is_same_v<BLayout, tensor_layout::convolution::KZYXGC>,
                            bool>::type = false>
__host__ __device__ static auto
MakeBDescriptor_N_K(const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{
    // Extents.
    const index_t num_filters = b_g_k_c_xs_lengths[1];
    const index_t channels    = b_g_k_c_xs_lengths[2];
    const index_t filter_volume =
        mult_accumulate_n<index_t>(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1);

    // Strides.
    const index_t filter_stride = b_g_k_c_xs_strides[1];
    const index_t tap_stride    = b_g_k_c_xs_strides[2 + NDimSpatial];
    const auto channel_stride   = I1;

    const auto weights_3d_desc =
        make_naive_tensor_descriptor(make_tuple(num_filters, filter_volume, channels),
                                     make_tuple(filter_stride, tap_stride, channel_stride));

    // K stays as GEMM N; (filter_volume, C) collapse into GEMM K.
    return transform_tensor_descriptor(
        weights_3d_desc,
        make_tuple(make_pass_through_transform(num_filters),
                   make_merge_transform(make_tuple(filter_volume, channels))),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));
}
// Builds the GEMM C-matrix descriptor for outputs in the GNWK, GNHWK or GNDHWK
// layouts.
//
// Index 1 of the lengths array is read as N and index 2 as K; the NDimSpatial
// entries starting at index 3 are multiplied together and scaled by N to get
// the number of GEMM rows. The result is a packed 2D descriptor of shape
// (N * output positions, K). The strides argument is not used by this overload.
template <typename CLayout,
          typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::GNWK> ||
                                      is_same_v<CLayout, tensor_layout::convolution::GNHWK> ||
                                      is_same_v<CLayout, tensor_layout::convolution::GNDHWK>,
                                  bool>::type = false>
__host__ __device__ static auto
MakeCDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */)
{
    const index_t batch       = c_g_n_k_wos_lengths[1];
    const index_t num_filters = c_g_n_k_wos_lengths[2];

    // Rows of the GEMM output: batch times the product of the spatial extents.
    const index_t gemm_rows =
        batch * mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);

    return make_naive_tensor_descriptor_packed(make_tuple(gemm_rows, num_filters));
}
// Builds the GEMM C-matrix descriptor for outputs in the G_NW_K, G_NHW_K,
// G_NDHW_K, NWGK, NHWGK or NDHWGK layouts.
//
// The result is a 2D descriptor of shape (N * output positions, K). Rows are
// strided by the stride of the last spatial dimension
// (strides[NDimSpatial + 2]) and columns use I1 as their stride.
template <
    typename CLayout,
    typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
                                is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
                                is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
                                is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
                                is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
                                is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
                            bool>::type = false>
__host__ __device__ static auto
MakeCDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
{
    // Extents.
    const index_t batch       = c_g_n_k_wos_lengths[1];
    const index_t num_filters = c_g_n_k_wos_lengths[2];
    const index_t gemm_rows =
        batch * mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);

    // Strides.
    const index_t row_stride = c_g_n_k_wos_strides[NDimSpatial + 2];
    const auto column_stride = I1;

    return make_naive_tensor_descriptor(make_tuple(gemm_rows, num_filters),
                                        make_tuple(row_stride, column_stride));
}
// Output-bias overload of the GEMM C-matrix descriptor (CLayout == G_K).
//
// The descriptor has the same (N * output positions, K) shape as the other
// overloads, but its row stride is I0 and its column stride is the K stride
// taken from strides[2].
template <typename CLayout,
          typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
                                  bool>::type = false>
__host__ __device__ static auto
MakeCDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                    const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
{
    const index_t batch       = c_g_n_k_wos_lengths[1];
    const index_t num_filters = c_g_n_k_wos_lengths[2];
    const index_t gemm_rows =
        batch * mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);

    const index_t column_stride = c_g_n_k_wos_strides[2];

    return make_naive_tensor_descriptor(make_tuple(gemm_rows, num_filters),
                                        make_tuple(I0, column_stride));
}
public:
// Public and non-const, unlike most of the members below.
// NOTE(review): presumably the batch size N, left writable so it can be
// adjusted after construction -- confirm against the constructor and callers.
index_t N_;
private:
// Every const member below has no default initializer, so it has to be set in
// a constructor's member-initializer list (the constructors are outside this
// excerpt).
// NOTE(review): the meanings below are inferred from the member names only;
// verify against the constructor.
// Presumably the input (i), output (o) and filter (Z/Y/X) spatial extents.
const index_t Di_, Hi_, Wi_;
const index_t Do_, Ho_, Wo_;
const index_t Z_, Y_, X_;
// Presumably the number of output channels (K) and input channels (C).
const index_t K_, C_;
// Presumably memory strides of the input, output and weight tensors.
const index_t DiStride_, HiStride_, WiStride_;
const index_t WoStride_;
const index_t XStride_;
const index_t CStrideTensorA_, CStrideTensorB_, KStrideTensorB_, KStrideTensorC_;
const index_t NStrideTensorA_;
const index_t GStrideTensorA_, GStrideTensorB_, GStrideTensorC_;
// Presumably the per-axis convolution strides, dilations and paddings.
const index_t ConvStrideD_, ConvStrideH_, ConvStrideW_;
const index_t ConvDilationD_, ConvDilationH_, ConvDilationW_;
const index_t InLeftPadD_, InLeftPadH_, InLeftPadW_;
const index_t InRightPadD_, InRightPadH_, InRightPadW_;
// Presumably the product Z * Y * X -- TODO confirm.
const index_t ZYX_;
// Not const, so it can be reassigned after construction.
// NOTE(review): presumably N * Do * Ho * Wo -- confirm.
index_t NDoHoWo_;
};
// wrapper class to call member functions on TransformConvToGemm struct at runtime
......@@ -1702,26 +1225,22 @@ struct TransformConv
template <index_t NDimSpatial,
device::ConvolutionForwardSpecialization ConvForwardSpecialization>
auto
transform_func(ck::Array<index_t, NDimSpatial + 3> out_lengths,
ck::Array<index_t, NDimSpatial + 3> out_strides,
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization> conv_fwd_to_gemm)
transform_func(TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization> conv_fwd_to_gemm)
{
if(NDimSpatial == 2)
{
return conv_fwd_to_gemm
.template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK>(out_lengths,
out_strides);
.template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK>();
}
else if(NDimSpatial == 3)
{
return conv_fwd_to_gemm
.template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK>(out_lengths,
out_strides);
.template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK>();
}
else if(NDimSpatial == 1)
{
return conv_fwd_to_gemm.template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK>(
out_lengths, out_strides);
return conv_fwd_to_gemm
.template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK>();
}
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment