Commit b9eb4de3 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents be3fbf7f d22713a7
...@@ -38,6 +38,41 @@ struct DeviceGemmV2 : public BaseOperator ...@@ -38,6 +38,41 @@ struct DeviceGemmV2 : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemmV2R1 : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
std::array<ck::index_t, NumDTensor> DsStrides,
ck::index_t StrideC,
ck::index_t KSplit,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataType,
typename DsDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename OutElementwiseOperation>
struct DeviceReduceMultiD : public BaseOperator
{
static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
const std::array<index_t, Rank> inStrides,
const std::array<std::array<index_t, NumOutDim>, NumDTensor> DsLengths,
const std::array<std::array<index_t, NumOutDim>, NumDTensor> DsStrides,
const std::array<index_t, NumOutDim> outLengths,
const std::array<index_t, NumOutDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
const void* in_dev,
const std::array<const void*, NumDTensor> ds_dev,
void* out_dev,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation out_elementwise_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename InDataType,
typename DsDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename OutElementwiseOperation>
using DeviceReduceMultiDPtr = std::unique_ptr<DeviceReduceMultiD<InDataType,
DsDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
OutElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
...@@ -95,16 +98,27 @@ auto transform_conv(ck::index_t num_dim, ...@@ -95,16 +98,27 @@ auto transform_conv(ck::index_t num_dim,
ck::Array<ck::index_t, 5> out_lengths, ck::Array<ck::index_t, 5> out_lengths,
ck::Array<ck::index_t, 5> out_strides) ck::Array<ck::index_t, 5> out_strides)
{ {
ck::Array<ck::index_t, 5> dummy_dims;
ck::Array<ck::index_t, 2> dummy_spatial_dims;
if(num_dim == 2 && if(num_dim == 2 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default) spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
{ {
ck::tensor_operation::TransformConvFwdToGemm< ck::tensor_operation::TransformConvFwdToGemm<
2, 2,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default> ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
conv_fwd; conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv(); auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd); return res.transform_func(conv_fwd);
} }
if(num_dim == 2 && if(num_dim == 2 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0) spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
...@@ -112,10 +126,19 @@ auto transform_conv(ck::index_t num_dim, ...@@ -112,10 +126,19 @@ auto transform_conv(ck::index_t num_dim,
ck::tensor_operation::TransformConvFwdToGemm< ck::tensor_operation::TransformConvFwdToGemm<
2, 2,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0> ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
conv_fwd; conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv(); auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd); return res.transform_func(conv_fwd);
} }
if(num_dim == 2 && if(num_dim == 2 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
...@@ -123,20 +146,38 @@ auto transform_conv(ck::index_t num_dim, ...@@ -123,20 +146,38 @@ auto transform_conv(ck::index_t num_dim,
ck::tensor_operation::TransformConvFwdToGemm< ck::tensor_operation::TransformConvFwdToGemm<
2, 2,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0> ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
conv_fwd; conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv(); auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd); return res.transform_func(conv_fwd);
} }
if(num_dim == 2 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC) if(num_dim == 2 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
{ {
ck::tensor_operation::TransformConvFwdToGemm< ck::tensor_operation::TransformConvFwdToGemm<
2, 2,
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC> ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
conv_fwd; conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv(); auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd); return res.transform_func(conv_fwd);
} }
throw std::runtime_error("Incorrect conv spec"); throw std::runtime_error("Incorrect conv spec");
} }
...@@ -146,16 +187,28 @@ auto transform_conv_3d(ck::index_t num_dim, ...@@ -146,16 +187,28 @@ auto transform_conv_3d(ck::index_t num_dim,
ck::Array<ck::index_t, 6> out_lengths, ck::Array<ck::index_t, 6> out_lengths,
ck::Array<ck::index_t, 6> out_strides) ck::Array<ck::index_t, 6> out_strides)
{ {
ck::Array<ck::index_t, 6> dummy_dims;
ck::Array<ck::index_t, 3> dummy_spatial_dims;
if(num_dim == 3 && if(num_dim == 3 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default) spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
{ {
ck::tensor_operation::TransformConvFwdToGemm< ck::tensor_operation::TransformConvFwdToGemm<
3, 3,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default> ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
conv_fwd; conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv(); auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd); return res.transform_func(conv_fwd);
} }
if(num_dim == 3 && if(num_dim == 3 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0) spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
...@@ -163,10 +216,19 @@ auto transform_conv_3d(ck::index_t num_dim, ...@@ -163,10 +216,19 @@ auto transform_conv_3d(ck::index_t num_dim,
ck::tensor_operation::TransformConvFwdToGemm< ck::tensor_operation::TransformConvFwdToGemm<
3, 3,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0> ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
conv_fwd; conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv(); auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd); return res.transform_func(conv_fwd);
} }
if(num_dim == 3 && if(num_dim == 3 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
...@@ -174,20 +236,38 @@ auto transform_conv_3d(ck::index_t num_dim, ...@@ -174,20 +236,38 @@ auto transform_conv_3d(ck::index_t num_dim,
ck::tensor_operation::TransformConvFwdToGemm< ck::tensor_operation::TransformConvFwdToGemm<
3, 3,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0> ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
conv_fwd; conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv(); auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd); return res.transform_func(conv_fwd);
} }
if(num_dim == 3 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC) if(num_dim == 3 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
{ {
ck::tensor_operation::TransformConvFwdToGemm< ck::tensor_operation::TransformConvFwdToGemm<
3, 3,
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC> ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
conv_fwd; conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv(); auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd); return res.transform_func(conv_fwd);
} }
throw std::runtime_error("Incorrect conv spec"); throw std::runtime_error("Incorrect conv spec");
} }
...@@ -197,16 +277,28 @@ auto transform_conv_1d(ck::index_t num_dim, ...@@ -197,16 +277,28 @@ auto transform_conv_1d(ck::index_t num_dim,
ck::Array<ck::index_t, 4> out_lengths, ck::Array<ck::index_t, 4> out_lengths,
ck::Array<ck::index_t, 4> out_strides) ck::Array<ck::index_t, 4> out_strides)
{ {
ck::Array<ck::index_t, 4> dummy_dims;
ck::Array<ck::index_t, 1> dummy_spatial_dims;
if(num_dim == 1 && if(num_dim == 1 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default) spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
{ {
ck::tensor_operation::TransformConvFwdToGemm< ck::tensor_operation::TransformConvFwdToGemm<
1, 1,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default> ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
conv_fwd; conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv(); auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd); return res.transform_func(conv_fwd);
} }
if(num_dim == 1 && if(num_dim == 1 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0) spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
...@@ -214,10 +306,19 @@ auto transform_conv_1d(ck::index_t num_dim, ...@@ -214,10 +306,19 @@ auto transform_conv_1d(ck::index_t num_dim,
ck::tensor_operation::TransformConvFwdToGemm< ck::tensor_operation::TransformConvFwdToGemm<
1, 1,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0> ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
conv_fwd; conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv(); auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd); return res.transform_func(conv_fwd);
} }
if(num_dim == 1 && if(num_dim == 1 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
...@@ -225,20 +326,38 @@ auto transform_conv_1d(ck::index_t num_dim, ...@@ -225,20 +326,38 @@ auto transform_conv_1d(ck::index_t num_dim,
ck::tensor_operation::TransformConvFwdToGemm< ck::tensor_operation::TransformConvFwdToGemm<
1, 1,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0> ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
conv_fwd; conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv(); auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd); return res.transform_func(conv_fwd);
} }
if(num_dim == 1 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC) if(num_dim == 1 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
{ {
ck::tensor_operation::TransformConvFwdToGemm< ck::tensor_operation::TransformConvFwdToGemm<
1, 1,
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC> ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
conv_fwd; conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv(); auto res = ck::tensor_operation::TransformConv();
return res.transform_func(out_lengths, out_strides, conv_fwd); return res.transform_func(conv_fwd);
} }
throw std::runtime_error("Incorrect dims or conv spec"); throw std::runtime_error("Incorrect dims or conv spec");
} }
......
...@@ -359,36 +359,17 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -359,36 +359,17 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer = using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay> template <typename ALay>
__host__ __device__ static auto __host__ __device__ static auto
MakeAGridDescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
const ck::Array<index_t, NDimSpatial>& input_left_pads,
const ck::Array<index_t, NDimSpatial>& input_right_pads)
{ {
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths, conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...@@ -398,12 +379,10 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -398,12 +379,10 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template <typename BLay> template <typename BLay>
__host__ __device__ static auto __host__ __device__ static auto
MakeBGridDescriptor_N_K(const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{ {
const auto wei_gemmnraw_gemmkraw_desc = const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths, conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
b_g_k_c_xs_strides);
const auto wei_gemmn_gemmk_desc = const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
...@@ -413,12 +392,10 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -413,12 +392,10 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template <typename ELay> template <typename ELay>
__host__ __device__ static auto __host__ __device__ static auto
MakeEGridDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths, conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
e_g_n_k_wos_strides);
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
...@@ -428,26 +405,27 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -428,26 +405,27 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// Shape of Ds and E must be aligned. Strides can be different. // Shape of Ds and E must be aligned. Strides can be different.
// Pass e_g_n_k_wos_lengths for logical broadcast. // Pass e_g_n_k_wos_lengths for logical broadcast.
__host__ __device__ static auto MakeDsGridDescriptor_M_N( static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(e_g_n_k_wos_lengths, return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
ds_g_n_k_wos_strides[i]);
}, },
Number<NumDTensor>{}); Number<NumDTensor>{});
} }
// desc for problem definition // desc for problem definition
using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>( constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; using AGridDesc_M_K =
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>; remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>; using BGridDesc_N_K =
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>; remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>(dummy_conv_to_gemm_transformer))>;
using DsGridDesc_M_N =
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))>;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
// If we are using multiAB and one of the template datatype parameters is not a tuple, convert // If we are using multiAB and one of the template datatype parameters is not a tuple, convert
// it to it // it to it
...@@ -533,21 +511,23 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -533,21 +511,23 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_ds_grid_{}, p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)}, p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]}, num_group_{a_g_n_c_wis_lengths[0]},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths, conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
e_g_n_k_wos_lengths, e_g_n_k_wos_lengths,
e_g_n_k_wos_strides, e_g_n_k_wos_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads)}, input_right_pads},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths, a_grid_desc_m_k_{
b_g_k_c_xs_strides)}, DeviceOp::MakeAGridDescriptor_M_K<ALayout>(conv_to_gemm_transformer_)},
b_grid_desc_n_k_{
DeviceOp::MakeBGridDescriptor_N_K<BLayout>(conv_to_gemm_transformer_)},
ds_grid_desc_m_n_{}, ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths, e_grid_desc_m_n_{
e_g_n_k_wos_strides)}, DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_)},
a_grid_desc_ak0_m_ak1_{ a_grid_desc_ak0_m_ak1_{
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
...@@ -637,9 +617,20 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -637,9 +617,20 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// D batch stride // D batch stride
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
ds_g_n_k_wos_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
// D desc // D desc
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>( ds_grid_desc_m_n_(i) =
e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]); DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
}); });
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
...@@ -694,6 +685,9 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -694,6 +685,9 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// tensor descriptors for problem definiton // tensor descriptors for problem definiton
index_t num_group_; index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
AGridDesc_M_K a_grid_desc_m_k_; AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_; BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_; DsGridDesc_M_N ds_grid_desc_m_n_;
......
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp" #include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
...@@ -65,8 +64,8 @@ struct DeviceColumnToImageImpl ...@@ -65,8 +64,8 @@ struct DeviceColumnToImageImpl
static constexpr auto spatial_offset = Number<3>{}; static constexpr auto spatial_offset = Number<3>{};
static constexpr auto conv_to_gemm_transformer = using GemmToConvFwdTransformer =
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>{}; TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>;
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpecialization::MKPadding, index_t, index_t, index_t>{ MatrixPadder<GemmSpecialization::MKPadding, index_t, index_t, index_t>{
MPerBlock, 0 /* NPerBlock*/, KPerBlock}; MPerBlock, 0 /* NPerBlock*/, KPerBlock};
...@@ -234,21 +233,21 @@ struct DeviceColumnToImageImpl ...@@ -234,21 +233,21 @@ struct DeviceColumnToImageImpl
: independent_filter_stride; : independent_filter_stride;
} }
GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
image_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
{}, // not needed for A Descriptor
c_g_n_k_wos_lengths,
{}, // not needed for A Descriptor
// conv_filter_strides,
independent_filter_strides,
conv_filter_dilations,
input_left_pads_with_offset,
input_right_pads};
// Calculate image form descriptor for the modified convolution problem // Calculate image form descriptor for the modified convolution problem
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>( conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>();
a_g_n_c_wis_lengths,
image_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
{}, // not needed for A Descriptor
c_g_n_k_wos_lengths,
{}, // not needed for A Descriptor
// conv_filter_strides,
independent_filter_strides,
conv_filter_dilations,
input_left_pads_with_offset,
input_right_pads,
N);
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -182,18 +182,6 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -182,18 +182,6 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{ {
#if 0
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
#endif
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
...@@ -206,121 +194,6 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -206,121 +194,6 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number could be One to Seven // Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{ {
#if 0
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
else
#endif
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{ {
...@@ -436,32 +309,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -436,32 +309,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number could be Odd or Even // Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{ {
#if 0
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
#endif
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{ {
...@@ -487,32 +335,6 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -487,32 +335,6 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
} }
else else
{ {
#if 0
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
#endif
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{ {
...@@ -542,18 +364,6 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -542,18 +364,6 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number always 1 // Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{ {
#if 0
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
#endif
{ {
const auto kernel = const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
......
...@@ -238,37 +238,17 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -238,37 +238,17 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer = using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
template <typename ALay> template <typename ALay>
static auto static auto
MakeAGridDescriptor_AK0_M_AK1(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{ {
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths, conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_g_n_c_wis_lengths[I1]);
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...@@ -286,12 +266,10 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -286,12 +266,10 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
template <typename BLay> template <typename BLay>
static auto static auto
MakeBGridDescriptor_BK0_N_BK1(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{ {
const auto wei_gemmnraw_gemmkraw_desc = const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths, conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
b_g_k_c_xs_strides);
const auto wei_gemmn_gemmk_desc = const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
...@@ -309,13 +287,10 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -309,13 +287,10 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
} }
template <typename ELay> template <typename ELay>
static auto static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>( conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
...@@ -323,27 +298,27 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -323,27 +298,27 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
return out_gemmm_gemmn_desc; return out_gemmm_gemmn_desc;
} }
static auto MakeDsGridDescriptor_M_N( static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(ds_g_n_k_wos_lengths[i], return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
ds_g_n_k_wos_strides[i]);
}, },
Number<NumDTensor>{}); Number<NumDTensor>{});
} }
// desc for problem definition // desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>( using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; dummy_conv_to_gemm_transformer))>;
using BGridDesc_BK0_N_BK1 = using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>; dummy_conv_to_gemm_transformer))>;
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>; using DsGridDesc_M_N =
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>; remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))>;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = using GridwiseGemm =
...@@ -426,21 +401,22 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -426,21 +401,22 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
p_ds_grid_{}, p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)}, p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]}, num_group_{a_g_n_c_wis_lengths[0]},
conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads},
a_grid_desc_ak0_m_ak1_{ a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths, DeviceOp::MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer_)},
a_g_n_c_wis_strides, b_grid_desc_bk0_n_bk1_{
b_g_k_c_xs_lengths, DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>(conv_to_gemm_transformer_)},
b_g_k_c_xs_strides, e_grid_desc_m_n_{
e_g_n_k_wos_lengths, DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_)},
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
a_grid_desc_k0_m0_m1_k1_{}, a_grid_desc_k0_m0_m1_k1_{},
b_grid_desc_k0_n0_n1_k1_{}, b_grid_desc_k0_n0_n1_k1_{},
ds_grid_desc_m0_m10_m11_n0_n10_n11_{}, ds_grid_desc_m0_m10_m11_n0_n10_n11_{},
...@@ -471,6 +447,17 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -471,6 +447,17 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>; using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths[i],
ds_g_n_k_wos_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
// D pointer // D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]); p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
...@@ -478,8 +465,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -478,8 +465,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
// D desc // D desc
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>( ds_grid_desc_m_n_(i) =
ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]); DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
}); });
// populate desc for Ds/E // populate desc for Ds/E
...@@ -523,6 +510,9 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK ...@@ -523,6 +510,9 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// tensor descriptors for problem definiton // tensor descriptors for problem definiton
index_t num_group_; index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
DsGridDesc_M_N ds_grid_desc_m_n_; DsGridDesc_M_N ds_grid_desc_m_n_;
......
...@@ -234,37 +234,17 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS ...@@ -234,37 +234,17 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer = using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
template <typename ALay> template <typename ALay>
static auto static auto
MakeAGridDescriptor_AK0_M_AK1(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{ {
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths, conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
c_g_n_k_wos_lengths,
c_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_g_n_c_wis_lengths[I1]);
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...@@ -283,12 +263,10 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS ...@@ -283,12 +263,10 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
template <typename BLay> template <typename BLay>
static auto static auto
MakeBGridDescriptor_BK0_N_BK1(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{ {
const auto wei_gemmnraw_gemmkraw_desc = const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths, conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
b_g_k_c_xs_strides);
const auto wei_gemmn_gemmk_desc = const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
...@@ -306,13 +284,10 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS ...@@ -306,13 +284,10 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
} }
template <typename CLay> template <typename CLay>
static auto static auto MakeCGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeCGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>( conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>();
c_g_n_k_wos_lengths, c_g_n_k_wos_strides, c_g_n_k_wos_lengths[I1]);
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
...@@ -321,11 +296,13 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS ...@@ -321,11 +296,13 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
} }
// desc for problem definition // desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>( using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; dummy_conv_to_gemm_transformer))>;
using BGridDesc_BK0_N_BK1 = using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>; dummy_conv_to_gemm_transformer))>;
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N<CLayout>({}, {}))>; using CGridDesc_M_N =
remove_cvref_t<decltype(MakeCGridDescriptor_M_N<CLayout>(dummy_conv_to_gemm_transformer))>;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = using GridwiseGemm =
...@@ -396,21 +373,22 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS ...@@ -396,21 +373,22 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
p_b_grid_{static_cast<const BDataType*>(p_b)}, p_b_grid_{static_cast<const BDataType*>(p_b)},
p_c_grid_{static_cast<CDataType*>(p_c)}, p_c_grid_{static_cast<CDataType*>(p_c)},
num_group_{a_g_n_c_wis_lengths[0]}, num_group_{a_g_n_c_wis_lengths[0]},
conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads},
a_grid_desc_ak0_m_ak1_{ a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths, DeviceOp::MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer_)},
a_g_n_c_wis_strides, b_grid_desc_bk0_n_bk1_{
b_g_k_c_xs_lengths, DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>(conv_to_gemm_transformer_)},
b_g_k_c_xs_strides, c_grid_desc_m_n_{
c_g_n_k_wos_lengths, DeviceOp::MakeCGridDescriptor_M_N<CLayout>(conv_to_gemm_transformer_)},
c_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N<CLayout>(c_g_n_k_wos_lengths,
c_g_n_k_wos_strides)},
a_grid_desc_k0_m0_m1_k1_{}, a_grid_desc_k0_m0_m1_k1_{},
b_grid_desc_k0_n0_n1_k1_{}, b_grid_desc_k0_n0_n1_k1_{},
c_grid_desc_m0_m10_m11_n0_n10_n11_{}, c_grid_desc_m0_m10_m11_n0_n10_n11_{},
...@@ -473,6 +451,9 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS ...@@ -473,6 +451,9 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
// tensor descriptors for problem definiton // tensor descriptors for problem definiton
index_t num_group_; index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment