"tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py" did not exist on "b3e5cd6b4d7fd5d03d75e78688dce52be02217b3"
Commit b9eb4de3 authored by Jun Liu

Merge branch 'develop' into amd-develop

parents be3fbf7f d22713a7
@@ -38,6 +38,41 @@ struct DeviceGemmV2 : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemmV2R1 : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
std::array<ck::index_t, NumDTensor> DsStrides,
ck::index_t StrideC,
ck::index_t KSplit,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
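For orientation, here is a minimal usage sketch of the DeviceGemmV2R1 interface added above (not part of the commit): the concrete operation would be some derived instance such as the DeviceGemm_Xdl_CShuffleV3R1 defined later in this diff, and PassThrough element-wise operations plus row-major-style strides are assumptions.

// Sketch only: drive any DeviceGemmV2R1 implementation with one D tensor through the
// abstract interface; the helper name and all stride choices here are illustrative.
template <typename GemmV2R1, typename ElementwiseOp>
float run_splitk_gemm_add(GemmV2R1& gemm,
                          const void* p_a, const void* p_b, const void* p_d0, void* p_c,
                          ck::index_t M, ck::index_t N, ck::index_t K,
                          ck::index_t StrideA, ck::index_t StrideB,
                          ck::index_t StrideD0, ck::index_t StrideC,
                          ck::index_t k_split, ElementwiseOp op)
{
    // NumDTensor == 1, so the pointer/stride arrays carry a single entry each.
    auto argument = gemm.MakeArgumentPointer(p_a, p_b, {p_d0}, p_c,
                                             M, N, K, StrideA, StrideB, {StrideD0}, StrideC,
                                             k_split, op, op, op);
    auto invoker  = gemm.MakeInvokerPointer();
    // BaseOperator::IsSupportedArgument guards against unsupported shapes/strides.
    if(!gemm.IsSupportedArgument(argument.get()))
        return -1.f;
    return invoker->Run(argument.get(), StreamConfig{});
}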
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataType,
typename DsDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename OutElementwiseOperation>
struct DeviceReduceMultiD : public BaseOperator
{
static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
const std::array<index_t, Rank> inStrides,
const std::array<std::array<index_t, NumOutDim>, NumDTensor> DsLengths,
const std::array<std::array<index_t, NumOutDim>, NumDTensor> DsStrides,
const std::array<index_t, NumOutDim> outLengths,
const std::array<index_t, NumOutDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
const void* in_dev,
const std::array<const void*, NumDTensor> ds_dev,
void* out_dev,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation out_elementwise_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename InDataType,
typename DsDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename OutElementwiseOperation>
using DeviceReduceMultiDPtr = std::unique_ptr<DeviceReduceMultiD<InDataType,
DsDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
OutElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
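To make the shape conventions concrete: with Rank = 3 and NumReduceDim = 1 this interface can fold the split-K dimension of a (KBatch, M, N) workspace into an (M, N) output while adding one D tensor, which is how the V3R1 GEMM later in this diff uses it. A hedged sketch follows; the ReduceInstance type, the row-major strides and the PassThrough input operation are placeholders, not part of the commit.

// Sketch only: ReduceInstance is assumed to be a DeviceReduceMultiD implementation
// with Rank = 3, NumReduceDim = 1 and a single D tensor (e.g. DeviceReduceThreadWiseMultiD).
template <typename ReduceInstance, typename OutElementwiseOp>
void reduce_kbatch_partials(ReduceInstance& reduce,
                            const void* p_partials, const void* p_d0, void* p_out,
                            ck::index_t KBatch, ck::index_t M, ck::index_t N,
                            OutElementwiseOp out_op)
{
    std::array<ck::index_t, 3> in_lengths  = {KBatch, M, N};
    std::array<ck::index_t, 3> in_strides  = {M * N, N, 1}; // row-major partial results
    std::array<ck::index_t, 2> out_lengths = {M, N};
    std::array<ck::index_t, 2> out_strides = {N, 1};
    std::array<std::array<ck::index_t, 2>, 1> ds_lengths = {{{M, N}}};
    std::array<std::array<ck::index_t, 2>, 1> ds_strides = {{{N, 1}}};
    std::array<int, 1> reduce_dims = {0}; // fold away the KBatch dimension

    auto argument = reduce.MakeArgumentPointer(in_lengths, in_strides,
                                               ds_lengths, ds_strides,
                                               out_lengths, out_strides, reduce_dims,
                                               p_partials, {p_d0}, p_out,
                                               ck::tensor_operation::element_wise::PassThrough{},
                                               out_op);
    auto invoker  = reduce.MakeInvokerPointer();
    if(reduce.IsSupportedArgument(argument.get()))
        invoker->Run(argument.get(), StreamConfig{});
}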
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
@@ -95,16 +98,27 @@ auto transform_conv(ck::index_t num_dim,
ck::Array<ck::index_t, 5> out_lengths,
ck::Array<ck::index_t, 5> out_strides)
{
ck::Array<ck::index_t, 5> dummy_dims;
ck::Array<ck::index_t, 2> dummy_spatial_dims;
if(num_dim == 2 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
{
ck::tensor_operation::TransformConvFwdToGemm<
2,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(conv_fwd);
}
if(num_dim == 2 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
@@ -112,10 +126,19 @@ auto transform_conv(ck::index_t num_dim,
ck::tensor_operation::TransformConvFwdToGemm<
2,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(conv_fwd);
}
if(num_dim == 2 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
@@ -123,20 +146,38 @@ auto transform_conv(ck::index_t num_dim,
ck::tensor_operation::TransformConvFwdToGemm<
2,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(conv_fwd);
}
if(num_dim == 2 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
{
ck::tensor_operation::TransformConvFwdToGemm<
2,
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(conv_fwd);
}
throw std::runtime_error("Incorrect conv spec");
}
@@ -146,16 +187,28 @@ auto transform_conv_3d(ck::index_t num_dim,
ck::Array<ck::index_t, 6> out_lengths,
ck::Array<ck::index_t, 6> out_strides)
{
ck::Array<ck::index_t, 6> dummy_dims;
ck::Array<ck::index_t, 3> dummy_spatial_dims;
if(num_dim == 3 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
{
ck::tensor_operation::TransformConvFwdToGemm<
3,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(conv_fwd);
}
if(num_dim == 3 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
@@ -163,10 +216,19 @@ auto transform_conv_3d(ck::index_t num_dim,
ck::tensor_operation::TransformConvFwdToGemm<
3,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(conv_fwd);
}
if(num_dim == 3 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
@@ -174,20 +236,38 @@ auto transform_conv_3d(ck::index_t num_dim,
ck::tensor_operation::TransformConvFwdToGemm<
3,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(conv_fwd);
}
if(num_dim == 3 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
{
ck::tensor_operation::TransformConvFwdToGemm<
3,
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(conv_fwd);
}
throw std::runtime_error("Incorrect conv spec");
}
@@ -197,16 +277,28 @@ auto transform_conv_1d(ck::index_t num_dim,
ck::Array<ck::index_t, 4> out_lengths,
ck::Array<ck::index_t, 4> out_strides)
{
ck::Array<ck::index_t, 4> dummy_dims;
ck::Array<ck::index_t, 1> dummy_spatial_dims;
if(num_dim == 1 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
{
ck::tensor_operation::TransformConvFwdToGemm<
1,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(conv_fwd);
}
if(num_dim == 1 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
@@ -214,10 +306,19 @@ auto transform_conv_1d(ck::index_t num_dim,
ck::tensor_operation::TransformConvFwdToGemm<
1,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(conv_fwd);
}
if(num_dim == 1 &&
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
@@ -225,20 +326,38 @@ auto transform_conv_1d(ck::index_t num_dim,
ck::tensor_operation::TransformConvFwdToGemm<
1,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(conv_fwd);
}
if(num_dim == 1 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
{
ck::tensor_operation::TransformConvFwdToGemm<
1,
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
conv_fwd{dummy_dims,
dummy_dims,
dummy_dims,
dummy_dims,
out_lengths,
out_strides,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims,
dummy_spatial_dims};
auto res = ck::tensor_operation::TransformConv();
return res.transform_func(conv_fwd);
}
throw std::runtime_error("Incorrect dims or conv spec");
}
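The pattern repeated in every branch above is the substance of this change: TransformConvFwdToGemm now carries the convolution problem description as state set at construction, and transform_func and the descriptor factories no longer take the lengths and strides as call arguments. A hedged sketch of the new calling convention, following the 2-D default-specialization case above; the helper name, array sizes and variable names are illustrative only.

// Sketch only: construct the transformer once with the full problem description
// (a/b/e lengths and strides, conv strides, dilations, pads), then build descriptors
// without passing shapes again. ALay stands for the input image layout.
template <typename ALay>
auto make_a_desc_after_refactor(ck::Array<ck::index_t, 5> a_lengths,
                                ck::Array<ck::index_t, 5> a_strides,
                                ck::Array<ck::index_t, 5> b_lengths,
                                ck::Array<ck::index_t, 5> b_strides,
                                ck::Array<ck::index_t, 5> e_lengths,
                                ck::Array<ck::index_t, 5> e_strides,
                                ck::Array<ck::index_t, 2> conv_strides,
                                ck::Array<ck::index_t, 2> dilations,
                                ck::Array<ck::index_t, 2> left_pads,
                                ck::Array<ck::index_t, 2> right_pads)
{
    ck::tensor_operation::TransformConvFwdToGemm<
        2,
        ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
        transform{a_lengths, a_strides, b_lengths, b_strides,
                  e_lengths, e_strides, conv_strides, dilations, left_pads, right_pads};
    // The descriptor factory reads the stored problem description.
    return transform.template MakeADescriptor_M_K<ALay>();
}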
......
@@ -359,36 +359,17 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
-static constexpr auto conv_to_gemm_transformer =
-    TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
+using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
__host__ __device__ static auto
-MakeAGridDescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
-                        const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
-                        const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
-                        const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
-                        const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
-                        const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
-                        const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
-                        const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
-                        const ck::Array<index_t, NDimSpatial>& input_left_pads,
-                        const ck::Array<index_t, NDimSpatial>& input_right_pads)
+MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
-    conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
-                                                                a_g_n_c_wis_strides,
-                                                                b_g_k_c_xs_lengths,
-                                                                b_g_k_c_xs_strides,
-                                                                e_g_n_k_wos_lengths,
-                                                                e_g_n_k_wos_strides,
-                                                                conv_filter_strides,
-                                                                conv_filter_dilations,
-                                                                input_left_pads,
-                                                                input_right_pads);
+    conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
@@ -398,12 +379,10 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template <typename BLay>
__host__ __device__ static auto
-MakeBGridDescriptor_N_K(const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
-                        const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
+MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
-    conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
-                                                                b_g_k_c_xs_strides);
+    conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
@@ -413,12 +392,10 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template <typename ELay>
__host__ __device__ static auto
-MakeEGridDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
-                        const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
+MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
-    conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
-                                                                e_g_n_k_wos_strides);
+    conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
@@ -428,26 +405,27 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// Shape of Ds and E must be aligned. Strides can be different.
// Pass e_g_n_k_wos_lengths for logical broadcast.
-__host__ __device__ static auto MakeDsGridDescriptor_M_N(
-    const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
-    const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
+static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
-        return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(e_g_n_k_wos_lengths,
-                                                          ds_g_n_k_wos_strides[i]);
+        return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
},
Number<NumDTensor>{});
}
// desc for problem definition
-using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
-    {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
-using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
-using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
-using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
+constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
+using AGridDesc_M_K =
+    remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
+using BGridDesc_N_K =
+    remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>(dummy_conv_to_gemm_transformer))>;
+using DsGridDesc_M_N =
+    remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))>;
+using EGridDesc_M_N =
+    remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
// If we are using multiAB and one of the template datatype parameters is not a tuple, convert
// it to it
@@ -533,7 +511,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]},
-a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
+conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
@@ -542,12 +520,14 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
-input_right_pads)},
-b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
-                 b_g_k_c_xs_strides)},
+input_right_pads},
+a_grid_desc_m_k_{
+    DeviceOp::MakeAGridDescriptor_M_K<ALayout>(conv_to_gemm_transformer_)},
+b_grid_desc_n_k_{
+    DeviceOp::MakeBGridDescriptor_N_K<BLayout>(conv_to_gemm_transformer_)},
ds_grid_desc_m_n_{},
-e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
-                 e_g_n_k_wos_strides)},
+e_grid_desc_m_n_{
+    DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
@@ -637,9 +617,20 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// D batch stride
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
+GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
+a_g_n_c_wis_strides,
+b_g_k_c_xs_lengths,
+b_g_k_c_xs_strides,
+e_g_n_k_wos_lengths,
+ds_g_n_k_wos_strides[i],
+conv_filter_strides,
+conv_filter_dilations,
+input_left_pads,
+input_right_pads};
// D desc
-ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
-    e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]);
+ds_grid_desc_m_n_(i) =
+    DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
});
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
@@ -694,6 +685,9 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// tensor descriptors for problem definiton
index_t num_group_;
+GemmToConvFwdTransformer conv_to_gemm_transformer_;
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
......
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp" #include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
...@@ -65,8 +64,8 @@ struct DeviceColumnToImageImpl ...@@ -65,8 +64,8 @@ struct DeviceColumnToImageImpl
static constexpr auto spatial_offset = Number<3>{}; static constexpr auto spatial_offset = Number<3>{};
static constexpr auto conv_to_gemm_transformer = using GemmToConvFwdTransformer =
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>{}; TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>;
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpecialization::MKPadding, index_t, index_t, index_t>{ MatrixPadder<GemmSpecialization::MKPadding, index_t, index_t, index_t>{
MPerBlock, 0 /* NPerBlock*/, KPerBlock}; MPerBlock, 0 /* NPerBlock*/, KPerBlock};
@@ -234,10 +233,7 @@ struct DeviceColumnToImageImpl
: independent_filter_stride;
}
-// Calculate image form descriptor for the modified convolution problem
-const auto in_gemmmraw_gemmkraw_desc =
-    conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>(
-        a_g_n_c_wis_lengths,
+GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
image_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
{}, // not needed for A Descriptor
@@ -247,8 +243,11 @@ struct DeviceColumnToImageImpl
independent_filter_strides,
conv_filter_dilations,
input_left_pads_with_offset,
-input_right_pads,
-N);
+input_right_pads};
+// Calculate image form descriptor for the modified convolution problem
+const auto in_gemmmraw_gemmkraw_desc =
+    conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -182,18 +182,6 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
#if 0
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
#endif
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
@@ -206,121 +194,6 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
#if 0
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
else
#endif
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
@@ -436,32 +309,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
#if 0
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
#endif
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
@@ -487,32 +335,6 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
}
else
{
#if 0
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
#endif
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
@@ -542,18 +364,6 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
#if 0
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
#endif
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename DsDataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t ScaleBlockM,
index_t ScaleBlockN,
index_t ScaleBlockK,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ComputeTypeA,
typename LDSTypeB = ComputeTypeB>
struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
: public DeviceGemmMultipleD_ABScale<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
DsDataType,
CDataType,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3<
ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
BDataType,
GemmAccDataType,
CShuffleDataType,
DsDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
LDSTypeA,
LDSTypeB>;
using Argument = typename GridwiseGemm::Argument;
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto Run = [&](const auto& kernel) {
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
};
constexpr index_t minimum_occupancy =
(BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave &&
MPerBlock * NPerBlock / BlockSize > 64)
? 1
: 2;
if(has_main_k_block_loop)
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK != KPerBlock)
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
const index_t M,
const index_t N,
const index_t K,
const index_t StrideA,
const index_t StrideB,
const std::array<index_t, NumDTensor> StrideDs,
const index_t StrideC,
const void* p_a_scale,
const void* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideC,
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BScaleDataType*>(p_b_scale),
1,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
const index_t M,
const index_t N,
const index_t K,
const index_t StrideA,
const index_t StrideB,
const std::array<ck::index_t, NumDTensor> StrideDs,
const index_t StrideC,
const void* p_a_scale,
const void* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideC,
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BScaleDataType*>(p_b_scale),
1,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"}};
// clang-format off
str << "DeviceGemmXdlUniversal"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerXDL<<"x"<<NPerXDL << ", "
<< "WaveMap: "
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
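A minimal call sketch for this AB-scale device op, assuming an instance with no D tensors and PassThrough element-wise operations; the helper and its parameter choices are illustrative, not part of the commit.

// Sketch only: build an argument with the block-scale pointers and time one launch.
template <typename DeviceOpInstance>
float run_ab_scale_gemm(DeviceOpInstance& gemm,
                        const void* p_a, const void* p_b, void* p_c,
                        const void* p_a_scale, const void* p_b_scale,
                        ck::index_t M, ck::index_t N, ck::index_t K,
                        ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC)
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    // Empty D-tensor arrays because NumDTensor == 0 in this sketch.
    auto argument = gemm.MakeArgument(p_a, p_b, {}, p_c,
                                      M, N, K, StrideA, StrideB, {}, StrideC,
                                      p_a_scale, p_b_scale,
                                      PassThrough{}, PassThrough{}, PassThrough{});
    if(!gemm.IsSupportedArgument(argument))
        return -1.f; // e.g. scale block sizes not matching the tile configuration
    auto invoker = gemm.MakeInvoker();
    return invoker.Run(argument, StreamConfig{});
}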
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <typeinfo>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ReduceDataType = CDataType,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGemm_Xdl_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
BDataType,
DsDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
GemmAccDataType,
CShuffleDataType,
ReduceDataType,
AElementwiseOperation,
BElementwiseOperation,
PassThrough,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
struct Argument : public GridwiseGemm::Argument
{
Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
const std::array<const void*, NumDTensor> p_ds_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<ck::index_t, NumDTensor> StrideDs_,
index_t StrideC_,
index_t k_batch_)
: GridwiseGemm::Argument(p_a_grid_,
p_b_grid_,
reinterpret_cast<ReduceDataType*>(p_c_grid_),
M_,
N_,
K_,
StrideA_,
StrideB_,
StrideC_,
k_batch_,
true),
p_ds(p_ds_),
StrideDs(StrideDs_)
{
}
const std::array<const void*, NumDTensor> p_ds;
std::array<ck::index_t, NumDTensor> StrideDs;
};
using ReduceAdd = ck::reduce::Add;
using OutElementwiseOperation = CElementwiseOperation;
static constexpr auto DsVectorLengthSequence = generate_sequence_v2(
[](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same<CLayout, DLayout>::value)
return Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{};
else
return Number<1>{};
},
Number<NumDTensor>{});
using DeviceReduceInstance = DeviceReduceThreadWiseMultiD<
ReduceDataType, // InDataType,
DsDataType, // DsDatatype
GemmAccDataType, // AccDataType,
CDataType, // OutDataType,
3, // Rank
1, // NumReduceDim
ReduceAdd,
PassThrough,
OutElementwiseOperation,
256, // BlockSize_,
CShuffleBlockTransferScalarPerVector_NPerBlock, // MThreadSliceSize_,
1, // KThreadSliceSize_,
0, // InSrcVectorDim_,
CShuffleBlockTransferScalarPerVector_NPerBlock, // InSrcVectorSize_,
CShuffleBlockTransferScalarPerVector_NPerBlock, // OutDstVectorSize_
decltype(DsVectorLengthSequence)>;
// Invoker
struct Invoker : public BaseInvoker
{
float RunReduce(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
static constexpr index_t NumInDim = 3;
static constexpr index_t NumOutDim = 2;
std::array<ck::index_t, NumInDim> in_lengths = {arg.KBatch, arg.M, arg.N};
std::array<ck::index_t, NumOutDim> out_lengths = {arg.M, arg.N};
std::array<ck::index_t, NumInDim> in_strides;
std::array<ck::index_t, NumOutDim> out_strides;
if constexpr(std::is_same<CLayout, ck::tensor_layout::gemm::RowMajor>::value)
{
in_strides = {arg.M * arg.N, arg.N, 1};
out_strides = {arg.N, 1};
}
else
{
in_strides = {arg.M * arg.N, 1, arg.M};
out_strides = {1, arg.M};
}
std::array<int, 1> reduce_dims{0};
std::array<std::array<index_t, NumOutDim>, NumDTensor> DsLengths;
std::array<std::array<index_t, NumOutDim>, NumDTensor> DsStrides;
static_for<0, NumDTensor, 1>{}([&](auto i) {
DsLengths[i] = out_lengths;
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same<DLayout, ck::tensor_layout::gemm::RowMajor>::value)
{
DsStrides[i] = {arg.StrideDs[i], 1};
}
else
{
DsStrides[i] = {1, arg.StrideDs[i]};
}
});
auto reduce = DeviceReduceInstance{};
auto argument_ptr = reduce.MakeArgumentPointer(in_lengths,
in_strides,
DsLengths,
DsStrides,
out_lengths,
out_strides,
reduce_dims,
arg.p_workspace_,
arg.p_ds,
arg.p_c_grid,
PassThrough{},
OutElementwiseOperation{});
auto invoker_ptr = reduce.MakeInvokerPointer();
float ave_time = 0;
if(reduce.IsSupportedArgument(argument_ptr.get()))
{
ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config);
}
else
{
throw std::runtime_error(
"The runtime parameters seems not supported by the device instance, exiting!");
}
return ave_time;
}
float Run(const Argument& arg_, const StreamConfig& stream_config = StreamConfig{})
{
auto arg = *dynamic_cast<const typename GridwiseGemm::Argument*>(&arg_);
if(!(!(arg.IsReduceAdd() || NumDTensor > 0) &&
std::is_same<CDataType, ReduceDataType>::value))
{
if(arg.p_workspace_ == nullptr)
{
throw std::runtime_error("using reduce , but empty workspace!");
}
arg.p_c_grid = reinterpret_cast<ReduceDataType*>(arg.p_workspace_);
}
if(stream_config.log_level_ > 0)
{
arg.Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
ck::utility::RotatingMemWrapper<typename GridwiseGemm::Argument> rotating_mem(
arg,
stream_config.rotating_count,
arg.M * arg.K * sizeof(ADataType),
arg.K * arg.N * sizeof(BDataType));
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg);
}
else
{
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
};
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
if(!(!(arg.IsReduceAdd() || NumDTensor > 0) &&
std::is_same<CDataType, ReduceDataType>::value))
{
// reduce the partial C tiles accumulated across the K-splits
ave_time += RunReduce(arg_, stream_config);
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
const std::array<const void*, NumDTensor> p_ds,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideC,
index_t KBatch,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation)
{
return Argument{p_a, p_b, p_ds, p_c, M, N, K, StrideA, StrideB, StrideDs, StrideC, KBatch};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideC,
index_t KBatch,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideC,
KBatch);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGemmXdlUniversalReduce"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerXDL<<"x"<<NPerXDL << ", "
<< "WaveMap: "
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
// clang-format on
return str.str();
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto arg = *dynamic_cast<const Argument*>(p_arg);
if(!(!(arg.IsReduceAdd() || NumDTensor > 0) &&
std::is_same<CDataType, ReduceDataType>::value))
{
std::cout << "using workspace" << std::endl;
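// The workspace stores one MxN partial-C tile per K-split in ReduceDataType
// precision; RunReduce accumulates these KBatch tiles into p_c_grid.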
return arg.M * arg.N * arg.KBatch * sizeof(ReduceDataType);
}
return 0;
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
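For orientation only, the following is a minimal usage sketch of a device op like the one above, driven through its polymorphic interface. It is not part of the commit: the concrete instantiation (DeviceOpT), the default-constructed element-wise operations, and the HIP allocation shown here are illustrative assumptions.

#include <array>
#include <cstddef>
#include <stdexcept>
#include <hip/hip_runtime.h>
// The ck headers providing StreamConfig, hip_check_error and the device-op
// definition are assumed to be on the include path.

// Hedged sketch: DeviceOpT stands for a concrete instantiation of the split-K
// device GEMM above; all problem sizes and pointers come from the caller.
template <typename DeviceOpT>
float run_splitk_gemm_example(const void* p_a,
                              const void* p_b,
                              std::array<const void*, DeviceOpT::NumDTensor> p_ds,
                              void* p_c,
                              ck::index_t M, ck::index_t N, ck::index_t K,
                              ck::index_t StrideA, ck::index_t StrideB,
                              std::array<ck::index_t, DeviceOpT::NumDTensor> StrideDs,
                              ck::index_t StrideC,
                              ck::index_t KBatch)
{
    auto op       = DeviceOpT{};
    auto argument = op.MakeArgumentPointer(p_a, p_b, p_ds, p_c,
                                           M, N, K,
                                           StrideA, StrideB, StrideDs, StrideC,
                                           KBatch, {}, {}, {});
    auto invoker  = op.MakeInvokerPointer();

    if(!op.IsSupportedArgument(argument.get()))
    {
        throw std::runtime_error(op.GetTypeString() + " does not support this problem");
    }

    // With KBatch > 1 (or fused D tensors) the op needs a workspace for the
    // partial C tiles that the trailing reduce accumulates afterwards.
    if(const std::size_t ws_size = op.GetWorkSpaceSize(argument.get()); ws_size > 0)
    {
        void* p_workspace = nullptr;
        hip_check_error(hipMalloc(&p_workspace, ws_size));
        op.SetWorkSpacePointer(argument.get(), p_workspace);
    }

    return invoker->Run(argument.get(), StreamConfig{nullptr, /*time_kernel=*/true});
}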
...@@ -238,37 +238,17 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer = using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
template <typename ALay> template <typename ALay>
static auto static auto
MakeAGridDescriptor_AK0_M_AK1(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{ {
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths, conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_g_n_c_wis_lengths[I1]);
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...@@ -286,12 +266,10 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
template <typename BLay> template <typename BLay>
static auto static auto
MakeBGridDescriptor_BK0_N_BK1(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{ {
const auto wei_gemmnraw_gemmkraw_desc = const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths, conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
b_g_k_c_xs_strides);
const auto wei_gemmn_gemmk_desc = const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
...@@ -309,13 +287,10 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
} }
template <typename ELay> template <typename ELay>
static auto static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>( conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
...@@ -323,27 +298,27 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
return out_gemmm_gemmn_desc; return out_gemmm_gemmn_desc;
} }
static auto MakeDsGridDescriptor_M_N( static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(ds_g_n_k_wos_lengths[i], return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
ds_g_n_k_wos_strides[i]);
}, },
Number<NumDTensor>{}); Number<NumDTensor>{});
} }
// desc for problem definition // desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>( using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; dummy_conv_to_gemm_transformer))>;
using BGridDesc_BK0_N_BK1 = using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>; dummy_conv_to_gemm_transformer))>;
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>; using DsGridDesc_M_N =
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>; remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))>;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = using GridwiseGemm =
...@@ -426,8 +401,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
p_ds_grid_{}, p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)}, p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]}, num_group_{a_g_n_c_wis_lengths[0]},
a_grid_desc_ak0_m_ak1_{ conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
DeviceOp::MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
...@@ -436,11 +410,13 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads)}, input_right_pads},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>( a_grid_desc_ak0_m_ak1_{
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, DeviceOp::MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer_)},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths, b_grid_desc_bk0_n_bk1_{
e_g_n_k_wos_strides)}, DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>(conv_to_gemm_transformer_)},
e_grid_desc_m_n_{
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_)},
a_grid_desc_k0_m0_m1_k1_{}, a_grid_desc_k0_m0_m1_k1_{},
b_grid_desc_k0_n0_n1_k1_{}, b_grid_desc_k0_n0_n1_k1_{},
ds_grid_desc_m0_m10_m11_n0_n10_n11_{}, ds_grid_desc_m0_m10_m11_n0_n10_n11_{},
...@@ -471,6 +447,17 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>; using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths[i],
ds_g_n_k_wos_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
// D pointer // D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]); p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
...@@ -478,8 +465,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
// D desc // D desc
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>( ds_grid_desc_m_n_(i) =
ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]); DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
}); });
// populate desc for Ds/E // populate desc for Ds/E
...@@ -523,6 +510,9 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// tensor descriptors for problem definiton // tensor descriptors for problem definiton
index_t num_group_; index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
DsGridDesc_M_N ds_grid_desc_m_n_; DsGridDesc_M_N ds_grid_desc_m_n_;
...
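All of the grouped-convolution hunks in this merge apply the same refactor: the descriptor factories no longer receive ten separate length/stride/padding arrays, the problem is captured once in a GemmToConvFwdTransformer and that object is passed around. The condensed sketch below mirrors the call pattern in the hunks above; the argument names follow the diff, but TransformConvFwdToGemm remains the authoritative reference for the constructor signature.

// Illustration of the refactor only (not a complete, compilable excerpt).
GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
                                                  a_g_n_c_wis_strides,
                                                  b_g_k_c_xs_lengths,
                                                  b_g_k_c_xs_strides,
                                                  e_g_n_k_wos_lengths,
                                                  e_g_n_k_wos_strides,
                                                  conv_filter_strides,
                                                  conv_filter_dilations,
                                                  input_left_pads,
                                                  input_right_pads};

// Before: every factory re-received the raw arrays (and re-derived N internally).
// auto a_desc = MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths, ..., input_right_pads);

// After: the transformer is the single source of truth for the problem shape.
auto a_desc  = MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer);
auto b_desc  = MakeBGridDescriptor_BK0_N_BK1<BLayout>(conv_to_gemm_transformer);
auto e_desc  = MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer);
auto ds_desc = MakeDsGridDescriptor_M_N(conv_to_gemm_transformer);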
...@@ -234,37 +234,17 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer = using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
template <typename ALay> template <typename ALay>
static auto static auto
MakeAGridDescriptor_AK0_M_AK1(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{ {
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths, conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
c_g_n_k_wos_lengths,
c_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_g_n_c_wis_lengths[I1]);
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...@@ -283,12 +263,10 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
template <typename BLay> template <typename BLay>
static auto static auto
MakeBGridDescriptor_BK0_N_BK1(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{ {
const auto wei_gemmnraw_gemmkraw_desc = const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths, conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
b_g_k_c_xs_strides);
const auto wei_gemmn_gemmk_desc = const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
...@@ -306,13 +284,10 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
} }
template <typename CLay> template <typename CLay>
static auto static auto MakeCGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeCGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>( conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>();
c_g_n_k_wos_lengths, c_g_n_k_wos_strides, c_g_n_k_wos_lengths[I1]);
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
...@@ -321,11 +296,13 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
} }
// desc for problem definition // desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>( using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; dummy_conv_to_gemm_transformer))>;
using BGridDesc_BK0_N_BK1 = using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>; dummy_conv_to_gemm_transformer))>;
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N<CLayout>({}, {}))>; using CGridDesc_M_N =
remove_cvref_t<decltype(MakeCGridDescriptor_M_N<CLayout>(dummy_conv_to_gemm_transformer))>;
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = using GridwiseGemm =
...@@ -396,21 +373,22 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
p_b_grid_{static_cast<const BDataType*>(p_b)}, p_b_grid_{static_cast<const BDataType*>(p_b)},
p_c_grid_{static_cast<CDataType*>(p_c)}, p_c_grid_{static_cast<CDataType*>(p_c)},
num_group_{a_g_n_c_wis_lengths[0]}, num_group_{a_g_n_c_wis_lengths[0]},
a_grid_desc_ak0_m_ak1_{ conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
DeviceOp::MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
c_g_n_k_wos_lengths, e_g_n_k_wos_lengths,
c_g_n_k_wos_strides, e_g_n_k_wos_strides,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads)}, input_right_pads},
b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>( a_grid_desc_ak0_m_ak1_{
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, DeviceOp::MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer_)},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N<CLayout>(c_g_n_k_wos_lengths, b_grid_desc_bk0_n_bk1_{
c_g_n_k_wos_strides)}, DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>(conv_to_gemm_transformer_)},
c_grid_desc_m_n_{
DeviceOp::MakeCGridDescriptor_M_N<CLayout>(conv_to_gemm_transformer_)},
a_grid_desc_k0_m0_m1_k1_{}, a_grid_desc_k0_m0_m1_k1_{},
b_grid_desc_k0_n0_n1_k1_{}, b_grid_desc_k0_n0_n1_k1_{},
c_grid_desc_m0_m10_m11_n0_n10_n11_{}, c_grid_desc_m0_m10_m11_n0_n10_n11_{},
...@@ -473,6 +451,9 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
// tensor descriptors for problem definiton // tensor descriptors for problem definiton
index_t num_group_; index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
...
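A detail shared by these hunks is the new dummy_conv_to_gemm_transformer constant: it exists only so the descriptor aliases can be spelled as decltype of an unevaluated factory call, instead of repeating the old brace-initialized placeholder arguments. Below is a small self-contained illustration of the idiom using made-up types; none of these names exist in composable_kernel.

#include <type_traits>
#include <utility>

// Stand-in for a problem description object such as the conv-to-GEMM transformer.
struct ToyTransformer
{
    int m = 0;
    int n = 0;
};

template <bool RowMajor>
auto MakeToyDescriptor(const ToyTransformer& t)
{
    if constexpr(RowMajor)
        return std::pair<int, int>{t.m, t.n};   // stand-in for a row-major descriptor
    else
        return std::pair<long, long>{t.n, t.m}; // stand-in for a column-major descriptor
}

// Never evaluated at run time; it only feeds the decltype expressions below.
inline constexpr ToyTransformer dummy_transformer{};

using RowDesc =
    std::remove_cv_t<std::remove_reference_t<decltype(MakeToyDescriptor<true>(dummy_transformer))>>;
using ColDesc =
    std::remove_cv_t<std::remove_reference_t<decltype(MakeToyDescriptor<false>(dummy_transformer))>>;

static_assert(std::is_same_v<RowDesc, std::pair<int, int>>);
static_assert(std::is_same_v<ColDesc, std::pair<long, long>>);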
...@@ -86,6 +86,7 @@ __global__ void
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op, const CDEElementwiseOperation cde_element_op,
const index_t groups_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...@@ -100,11 +101,14 @@ __global__ void
defined(__gfx94__)) defined(__gfx94__))
// offset base pointer for each work-group // offset base pointer for each work-group
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); const index_t& num_blocks_per_n = groups_count;
const long_index_t e_group_offset = const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); const auto& ds_batch_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
const long_index_t e_n_offset = const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
...@@ -117,14 +121,14 @@ __global__ void
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
static_for<0, NumDTensor, 1>{}( static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; }); [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
if constexpr(isMultiA || isMultiB) if constexpr(isMultiA || isMultiB)
{ {
AsPointer p_as_grid_grp; AsPointer p_as_grid_grp;
BsPointer p_bs_grid_grp; BsPointer p_bs_grid_grp;
const auto& as_group_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx); const auto& as_batch_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx);
// compute_ptr_offset_of_n_ not need BatchStrideB so // compute_ptr_offset_of_n_ not need BatchStrideB so
// in case of MultiA is false but isMultiB is true // in case of MultiA is false but isMultiB is true
...@@ -135,27 +139,27 @@ __global__ void
static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
static_for<0, NumATensor, 1>{}([&](auto i) { static_for<0, NumATensor, 1>{}([&](auto i) {
p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + as_n_offset[i]; p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + as_n_offset[i];
}); });
} }
else else
{ {
const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx); const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
static_for<0, 1, 1>{}( static_for<0, 1, 1>{}(
[&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + a_n_offset; }); [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + a_n_offset; });
} }
const auto& bs_group_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx); const auto& bs_batch_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx);
static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
static_for<0, NumBTensor, 1>{}( static_for<0, NumBTensor, 1>{}(
[&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_group_offset[i]; }); [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; });
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop>(
p_as_grid_grp, p_as_grid_grp,
p_bs_grid_grp, p_bs_grid_grp,
p_ds_grid_grp, p_ds_grid_grp,
p_e_grid + e_group_offset + e_n_offset, p_e_grid + e_batch_offset + e_n_offset,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -168,19 +172,19 @@ __global__ void
} }
else else
{ {
const long_index_t a_group_offset = const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
const long_index_t b_group_offset = const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
const long_index_t a_n_offset = const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop>(
p_as_grid + a_group_offset + a_n_offset, p_as_grid + a_batch_offset + a_n_offset,
p_bs_grid + b_group_offset, p_bs_grid + b_batch_offset,
p_ds_grid_grp, p_ds_grid_grp,
p_e_grid + e_group_offset + e_n_offset, p_e_grid + e_batch_offset + e_n_offset,
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -196,6 +200,7 @@ __global__ void
ignore = p_bs_grid; ignore = p_bs_grid;
ignore = p_ds_grid; ignore = p_ds_grid;
ignore = p_e_grid; ignore = p_e_grid;
ignore = groups_count;
ignore = a_grid_desc_k0_m_k1; ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1; ignore = b_grid_desc_k0_n_k1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
...@@ -282,8 +287,7 @@ template <index_t NDimSpatial,
// in tuple for MultiAB), unpack if tuple was // in tuple for MultiAB), unpack if tuple was
// passed // passed
typename BComputeDataType = AComputeDataType, typename BComputeDataType = AComputeDataType,
LoopScheduler LoopSched = make_default_loop_scheduler(), LoopScheduler LoopSched = make_default_loop_scheduler()>
index_t NumGroupsToMerge = 1>
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
: public DeviceGroupedConvFwdMultipleABD<NDimSpatial, : public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
ALayout, ALayout,
...@@ -302,8 +306,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{ {
using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
static_assert(NumGroupsToMerge >= 1);
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value; static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value; static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
...@@ -316,38 +318,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer = using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization, NumGroupsToMerge>{}; ConvForwardSpecialization,
true /*SplitN*/,
ALayout,
ELayout>;
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay> template <typename ALay>
static auto static auto MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t Conv_N)
{ {
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths, conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
Conv_N);
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...@@ -356,13 +340,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
} }
template <typename BLay> template <typename BLay>
static auto static auto MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{ {
const auto wei_gemmnraw_gemmkraw_desc = const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths, conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
b_g_k_c_xs_strides);
const auto wei_gemmn_gemmk_desc = const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
...@@ -371,14 +352,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
} }
template <typename ELay> template <typename ELay>
static auto static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const index_t Conv_N)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>( conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N);
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
...@@ -388,27 +365,27 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// Shape of Ds and E must be aligned. Strides can be different. // Shape of Ds and E must be aligned. Strides can be different.
// Pass e_g_n_k_wos_lengths for logical broadcast. // Pass e_g_n_k_wos_lengths for logical broadcast.
static auto MakeDsGridDescriptor_M_N( static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const index_t Conv_N)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>( return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], Conv_N);
}, },
Number<NumDTensor>{}); Number<NumDTensor>{});
} }
// desc for problem definition // desc for problem definition
using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>( constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>; using AGridDesc_M_K =
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>; remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, 1))>; using BGridDesc_N_K =
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}, 1))>; remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>(dummy_conv_to_gemm_transformer))>;
using DsGridDesc_M_N =
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))>;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
// If we are using multiAB and one of the template datatype parameters is not a tuple, convert // If we are using multiAB and one of the template datatype parameters is not a tuple, convert
// it to it // it to it
...@@ -496,13 +473,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_ds_grid_{}, p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)}, p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]}, num_group_{a_g_n_c_wis_lengths[0]},
conv_N_per_block_{ conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
conv_to_gemm_transformer.template GetSplitedNSize<ADataType, EDataType>(
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
...@@ -511,13 +482,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads},
conv_N_per_block_)}, conv_N_per_block_{conv_to_gemm_transformer_.N_},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths, a_grid_desc_m_k_{
b_g_k_c_xs_strides)}, DeviceOp::MakeAGridDescriptor_M_K<ALayout>(conv_to_gemm_transformer_)},
b_grid_desc_n_k_{
DeviceOp::MakeBGridDescriptor_N_K<BLayout>(conv_to_gemm_transformer_)},
ds_grid_desc_m_n_{}, ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>( e_grid_desc_m_n_{
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)}, DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_)},
a_grid_desc_ak0_m_ak1_{ a_grid_desc_ak0_m_ak1_{
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
...@@ -548,8 +521,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{ {
static_for<0, NumATensor, 1>{}([&](auto i) { static_for<0, NumATensor, 1>{}([&](auto i) {
// Init compute_ptr_offset_of_groups_ for multiple AB // Init compute_ptr_offset_of_groups_ for multiple AB
compute_ptr_offset_of_groups_.BatchStrideA_(i) = compute_ptr_offset_of_groups_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
a_g_n_c_wis_strides[0] * NumGroupsToMerge;
// Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
// type is not tuple) // type is not tuple)
...@@ -577,8 +549,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}); });
static_for<0, NumBTensor, 1>{}([&](auto i) { static_for<0, NumBTensor, 1>{}([&](auto i) {
// Init compute_ptr_offset_of_groups_ for multiple AB // Init compute_ptr_offset_of_groups_ for multiple AB
compute_ptr_offset_of_groups_.BatchStrideB_(i) = compute_ptr_offset_of_groups_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
b_g_k_c_xs_strides[0] * NumGroupsToMerge;
using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>; using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
// It is possible that one of the AB is a pointer and one is a tuple. // It is possible that one of the AB is a pointer and one is a tuple.
...@@ -598,10 +569,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
} }
else else
{ {
compute_ptr_offset_of_groups_.BatchStrideA_ = compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
a_g_n_c_wis_strides[0] * NumGroupsToMerge; compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_groups_.BatchStrideB_ =
b_g_k_c_xs_strides[0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
// p_as and p_bs are pointers // p_as and p_bs are pointers
...@@ -618,16 +587,26 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]); p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
// D batch stride // D batch stride
compute_ptr_offset_of_groups_.BatchStrideDs_(i) = compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
ds_g_n_k_wos_strides[i][0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideDs_(i) = compute_ptr_offset_of_n_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides[i][1] * conv_N_per_block_; ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
ds_g_n_k_wos_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
// D desc // D desc
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>( ds_grid_desc_m_n_(i) =
e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], conv_N_per_block_); DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
}); });
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0] * NumGroupsToMerge; compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
// populate desc for Ds/E // populate desc for Ds/E
...@@ -690,6 +669,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// tensor descriptors for problem definiton // tensor descriptors for problem definiton
index_t num_group_; index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
index_t conv_N_per_block_; index_t conv_N_per_block_;
AGridDesc_M_K a_grid_desc_m_k_; AGridDesc_M_K a_grid_desc_m_k_;
...@@ -748,8 +730,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const index_t gdy = arg.num_group_ / NumGroupsToMerge; const index_t gdy = arg.num_group_ * num_workgroups_per_Conv_N;
const index_t gdz = num_workgroups_per_Conv_N; const index_t gdz = 1;
const auto K = const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
...@@ -798,6 +780,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.cde_element_op_, arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
as_grid_desc_ak0_m_ak1, as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1, bs_grid_desc_bk0_n_bk1,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
...@@ -841,6 +824,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.cde_element_op_, arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
arg.a_grid_desc_ak0_m_ak1_, arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
...@@ -872,10 +856,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{ {
namespace ctc = tensor_layout::convolution; namespace ctc = tensor_layout::convolution;
const index_t G = arg.b_g_k_c_xs_lengths_[I0];
const index_t K = arg.b_g_k_c_xs_lengths_[I1];
const index_t C = arg.b_g_k_c_xs_lengths_[I2];
// check device // check device
if(get_device_name() == "gfx908") if(get_device_name() == "gfx908")
{ {
...@@ -924,42 +904,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
} }
} }
} }
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3)
{
if(C != 1)
{
return false;
}
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3];
if(filter_spatial_dim != I3)
{
return false;
}
}
if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
{
return false;
}
}
if constexpr(NumGroupsToMerge > 1)
{
if(!(C == 1))
{
return false;
}
if(G % NumGroupsToMerge != 0)
{
return false;
}
if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
{
return false;
}
}
// check vector access of A // check vector access of A
// FIXME: layout // FIXME: layout
...@@ -969,18 +913,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> || is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
is_same_v<ALayout, ctc::NDHWGC>) is_same_v<ALayout, ctc::NDHWGC>)
{ {
// Check access per C const index_t C = arg.a_g_n_c_wis_lengths_[2];
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
{
// If not possible, check access per G
if(!(ABlockTransferSrcVectorDim == 1 && C == 1 &&
is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>() &&
G % ABlockTransferSrcScalarPerVector == 0))
{ {
return false; return false;
} }
} }
}
else else
{ {
return false; return false;
...@@ -995,6 +934,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<BLayout, ctc::KZYXGC>) is_same_v<BLayout, ctc::KZYXGC>)
{ {
const index_t C = arg.b_g_k_c_xs_lengths_[2];
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
{ {
return false; return false;
...@@ -1018,6 +959,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> || is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> ||
is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::G_K>) is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::G_K>)
{ {
const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{ {
valid = false; valid = false;
...@@ -1062,6 +1005,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> || is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
is_same_v<ELayout, ctc::NDHWGK>) is_same_v<ELayout, ctc::NDHWGK>)
{ {
const index_t K = arg.e_g_n_k_wos_lengths_[2];
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{ {
return false; return false;
...@@ -1212,8 +1157,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<< BBlockTransferSrcScalarPerVector << ", " << BBlockTransferSrcScalarPerVector << ", "
<< CDEBlockTransferScalarPerVector_NPerBlock << ", " << CDEBlockTransferScalarPerVector_NPerBlock << ", "
<< CShuffleMXdlPerWavePerShuffle << ", " << CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", " << CShuffleNXdlPerWavePerShuffle
<< NumGroupsToMerge
<< ">"; << ">";
// clang-format on // clang-format on
...
...@@ -293,39 +293,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer = using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{}; ConvForwardSpecialization,
true /*SplitN*/,
ADataType,
EDataType>;
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay> template <typename ALay>
static auto static auto
MakeAGridDescriptor_AK0_M_AK1(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t Conv_N)
{ {
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths, conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
Conv_N);
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...@@ -344,12 +327,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
template <typename BLay> template <typename BLay>
static auto static auto
MakeBGridDescriptor_BK0_N_BK1(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{ {
const auto wei_gemmnraw_gemmkraw_desc = const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths, conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
b_g_k_c_xs_strides);
const auto wei_gemmn_gemmk_desc = const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
...@@ -367,15 +348,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
} }
template <typename ELay> template <typename ELay>
static auto static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const index_t Conv_N)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>( conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N);
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
...@@ -384,7 +361,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
} }
// desc for problem definition // desc for problem definition
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}, 1))>; constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
#define GridwiseGemmV3TemplateParams \ #define GridwiseGemmV3TemplateParams \
tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \ tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \
...@@ -417,9 +396,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// desc for blockwise copy // desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>( using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>; dummy_conv_to_gemm_transformer))>;
using BGridDesc_BK0_N_BK1 = using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>; dummy_conv_to_gemm_transformer))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>; EGridDesc_M_N{}))>;
...@@ -450,13 +429,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
p_b_grid_{}, p_b_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)}, p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]}, num_group_{a_g_n_c_wis_lengths[0]},
conv_N_per_block_{ conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
conv_to_gemm_transformer.template GetSplitedNSize<ADataType, EDataType>(
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
a_grid_desc_ak0_m_ak1_{MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
...@@ -465,12 +438,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -465,12 +438,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads},
conv_N_per_block_)}, conv_N_per_block_{conv_to_gemm_transformer_.N_},
a_grid_desc_ak0_m_ak1_{
MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer_)},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1_{
MakeBGridDescriptor_BK0_N_BK1<BLayout>(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, MakeBGridDescriptor_BK0_N_BK1<BLayout>(conv_to_gemm_transformer_)},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>( e_grid_desc_m_n_{
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)}, DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_)},
e_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{},
compute_ptr_offset_of_groups_{}, compute_ptr_offset_of_groups_{},
compute_ptr_offset_of_n_{}, compute_ptr_offset_of_n_{},
...@@ -519,6 +494,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 ...@@ -519,6 +494,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// tensor descriptors for problem definition // tensor descriptors for problem definition
index_t num_group_; index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
index_t conv_N_per_block_; index_t conv_N_per_block_;
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
......
...@@ -309,37 +309,16 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -309,37 +309,16 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{}; static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer = using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay> template <typename ALay>
static auto static auto MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{ {
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths, conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_g_n_c_wis_lengths[I1]);
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...@@ -348,13 +327,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -348,13 +327,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
} }
template <typename BLay> template <typename BLay>
static auto static auto MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{ {
const auto wei_gemmnraw_gemmkraw_desc = const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths, conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
b_g_k_c_xs_strides);
const auto wei_gemmn_gemmk_desc = const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
...@@ -363,13 +339,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -363,13 +339,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
} }
template <typename ELay> template <typename ELay>
static auto static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>( conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
...@@ -447,10 +420,13 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -447,10 +420,13 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo); return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo);
} }
using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>( constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; using AGridDesc_M_K =
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>; remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<DELayout>({}, {}))>; using BGridDesc_N_K =
remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>(dummy_conv_to_gemm_transformer))>;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<DELayout>(dummy_conv_to_gemm_transformer))>;
using RGridDesc_M = remove_cvref_t<decltype(MakeRGridDescriptor_M<RLayout>({}, {}))>; using RGridDesc_M = remove_cvref_t<decltype(MakeRGridDescriptor_M<RLayout>({}, {}))>;
// GridwiseGemm // GridwiseGemm
...@@ -551,7 +527,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -551,7 +527,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
p_ds_grid_{}, p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)}, p_e_grid_{static_cast<EDataType*>(p_e)},
p_rs_grid_{}, // FIXME p_rs_grid_{}, // FIXME
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths, conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
...@@ -560,12 +536,14 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -560,12 +536,14 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads)}, input_right_pads},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths, a_grid_desc_m_k_{
b_g_k_c_xs_strides)}, DeviceOp::MakeAGridDescriptor_M_K<ALayout>(conv_to_gemm_transformer_)},
b_grid_desc_n_k_{
DeviceOp::MakeBGridDescriptor_N_K<BLayout>(conv_to_gemm_transformer_)},
ds_grid_desc_m_n_{}, ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<DELayout>(e_g_n_k_wos_lengths, e_grid_desc_m_n_{
e_g_n_k_wos_strides)}, DeviceOp::MakeEGridDescriptor_M_N<DELayout>(conv_to_gemm_transformer_)},
r_grid_desc_m_{ r_grid_desc_m_{
DeviceOp::MakeRGridDescriptor_M<RLayout>(r_g_n_wos_lengths, r_g_n_wos_strides)}, DeviceOp::MakeRGridDescriptor_M<RLayout>(r_g_n_wos_lengths, r_g_n_wos_strides)},
a_grid_desc_ak0_m_ak1_{ a_grid_desc_ak0_m_ak1_{
...@@ -621,9 +599,20 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -621,9 +599,20 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
// D batch stride // D batch stride
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths[i],
ds_g_n_k_wos_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
// D desc // D desc
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DELayout>( ds_grid_desc_m_n_(i) =
ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]); DeviceOp::MakeEGridDescriptor_M_N<DELayout>(conv_to_gemm_transformer_d);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) = ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
...@@ -660,6 +649,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -660,6 +649,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
EDataType* p_e_grid_; EDataType* p_e_grid_;
typename GridwiseGemm::RsGridPointer p_rs_grid_; typename GridwiseGemm::RsGridPointer p_rs_grid_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
// tensor descriptors for problem definition // tensor descriptors for problem definition
AGridDesc_M_K a_grid_desc_m_k_; AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_; BGridDesc_N_K b_grid_desc_n_k_;
......
...@@ -135,36 +135,16 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -135,36 +135,16 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
static constexpr auto BEnableLds = static constexpr auto BEnableLds =
BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1); BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1);
static constexpr auto conv_to_gemm_transformer = using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock}; MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay> template <typename ALay>
static auto MakeAGridDescriptor(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, static auto MakeAGridDescriptor(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
{ {
const auto in_gemmmraw_gemmkraw_desc = const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths, conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_g_n_c_wis_lengths[I1]);
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
...@@ -205,12 +185,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -205,12 +185,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
} }
template <typename BLay> template <typename BLay>
static auto MakeBGridDescriptor(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, static auto MakeBGridDescriptor(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
{ {
const auto wei_gemmnraw_gemmkraw_desc = const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths, conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
b_g_k_c_xs_strides);
const auto wei_gemmn_gemmk_desc = const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
...@@ -251,13 +229,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -251,13 +229,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
} }
template <typename ELay> template <typename ELay>
static auto static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{ {
const auto out_gemmmraw_gemmnraw_desc = const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>( conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
const auto out_gemmm_gemmn_desc = const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
...@@ -265,26 +240,27 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -265,26 +240,27 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
return out_gemmm_gemmn_desc; return out_gemmm_gemmn_desc;
} }
static auto MakeDsGridDescriptor_M_N( static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(ds_g_n_k_wos_lengths[i], return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
ds_g_n_k_wos_strides[i]);
}, },
Number<NumDTensor>{}); Number<NumDTensor>{});
} }
// desc for problem definition // desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
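// Note: the default-constructed transformer above is never used at run time; it only serves to
// deduce the grid-descriptor types in the decltype expressions below.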
using AGridDesc = using AGridDesc =
decltype(DeviceOp::MakeAGridDescriptor<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {})); decltype(DeviceOp::MakeAGridDescriptor<ALayout>(dummy_conv_to_gemm_transformer));
using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor<BLayout>({}, {})); using BGridDesc =
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>; decltype(DeviceOp::MakeBGridDescriptor<BLayout>(dummy_conv_to_gemm_transformer));
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>; using DsGridDesc_M_N =
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))>;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
// GridwiseOp // GridwiseOp
using GridwiseOp = GridwiseGemmMultipleD_Wmma< using GridwiseOp = GridwiseGemmMultipleD_Wmma<
...@@ -373,10 +349,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -373,10 +349,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
p_ds_grid_{}, p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)}, p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]}, num_group_{a_g_n_c_wis_lengths[0]},
ds_grid_desc_m_n_{}, conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
a_grid_desc_{DeviceOp::MakeAGridDescriptor<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides, a_g_n_c_wis_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
b_g_k_c_xs_strides, b_g_k_c_xs_strides,
...@@ -385,9 +358,12 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -385,9 +358,12 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads)}, input_right_pads},
b_grid_desc_{ ds_grid_desc_m_n_{},
DeviceOp::MakeBGridDescriptor<BLayout>(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, e_grid_desc_m_n_{
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_)},
a_grid_desc_{DeviceOp::MakeAGridDescriptor<ALayout>(conv_to_gemm_transformer_)},
b_grid_desc_{DeviceOp::MakeBGridDescriptor<BLayout>(conv_to_gemm_transformer_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)}, block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)},
...@@ -426,8 +402,24 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -426,8 +402,24 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}); });
// D desc // D desc
ds_grid_desc_m_n_ = ds_grid_desc_m_n_ = generate_tuple(
DeviceOp::MakeDsGridDescriptor_M_N(ds_g_n_k_wos_lengths, ds_g_n_k_wos_strides); [&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths[i],
ds_g_n_k_wos_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
},
Number<NumDTensor>{});
// populate desc for Ds/E // populate desc for Ds/E
e_grid_desc_mblock_mperblock_nblock_nperblock_ = e_grid_desc_mblock_mperblock_nblock_nperblock_ =
...@@ -455,6 +447,9 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -455,6 +447,9 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// tensor descriptors for problem definition // tensor descriptors for problem definition
index_t num_group_; index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
DsGridDesc_M_N ds_grid_desc_m_n_; DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_;
......
...@@ -57,8 +57,8 @@ struct DeviceImageToColumnImpl ...@@ -57,8 +57,8 @@ struct DeviceImageToColumnImpl
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto conv_to_gemm_transformer = using GemmToConvFwdTransformer =
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>{}; TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>;
static constexpr auto matrix_padder = static constexpr auto matrix_padder =
MatrixPadder<GemmSpecialization::MKPadding, index_t, index_t, index_t>{ MatrixPadder<GemmSpecialization::MKPadding, index_t, index_t, index_t>{
...@@ -97,9 +97,7 @@ struct DeviceImageToColumnImpl ...@@ -97,9 +97,7 @@ struct DeviceImageToColumnImpl
b_g_k_c_xs_lengths[I2] = C; b_g_k_c_xs_lengths[I2] = C;
c_g_n_k_wos_lengths[I1] = N; c_g_n_k_wos_lengths[I1] = N;
const auto in_gemmmraw_gemmkraw_desc = GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>(
a_g_n_c_wis_lengths,
image_g_n_c_wis_strides, image_g_n_c_wis_strides,
b_g_k_c_xs_lengths, b_g_k_c_xs_lengths,
{}, // not needed for A Descriptor {}, // not needed for A Descriptor
...@@ -108,8 +106,10 @@ struct DeviceImageToColumnImpl ...@@ -108,8 +106,10 @@ struct DeviceImageToColumnImpl
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads};
N);
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>();
const auto in_gemmm_gemmk_desc = const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <array>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataType,
typename DsDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename OutElementwiseOperation,
index_t BlockSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize,
typename DsVectorSizeSequence>
struct DeviceReduceThreadWiseMultiD : public DeviceReduceMultiD<InDataType,
DsDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
OutElementwiseOperation>
{
static_assert(Rank <= 6, "Rank larger than 6 is not supported!");
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using IndexDataType = int32_t;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr index_t NumSrcDim = Rank;
static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;
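// Tiling scheme (thread-wise reduction): each thread block owns M_BlockTileSize
// (= BlockSize * MThreadSliceSize) consecutive invariant elements, and every thread reduces its
// own MThreadSliceSize rows independently, consuming KThreadSliceSize elements of the reduce
// dimension per iteration; there is no cross-thread accumulation.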
static auto MakeSrc2dDescriptor(const std::array<index_t, Rank>& inLengths,
const std::array<index_t, Rank>& inStrides)
{
const auto tupleSrcLengths =
generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
const auto tupleSrcStrides =
generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDim)
{
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
make_tuple(make_unmerge_transform(make_tuple(
1, one_dim_inDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
}
else
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths = generate_tuple(
[&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
const auto invariantDimLengths =
generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K =
math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
static auto MakeDst1dDescriptor(const std::array<index_t, NumDstDim>& outLengths,
const std::array<index_t, NumDstDim>& outStrides)
{
const auto tupleDstLengths =
generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
const auto tupleDstStrides =
generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
const auto outPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto out_grid_desc_m_padded = transform_tensor_descriptor(
out_grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, outPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded);
};
static auto
MakeDsDescriptor(const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths,
std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides)
{
return generate_tuple(
[&](auto i) {
return DeviceReduceThreadWiseMultiD::MakeDst1dDescriptor(DsLengths[i],
DsStrides[i]);
},
Number<NumDTensor>{});
}
using InGridDesc_M_K = decltype(MakeSrc2dDescriptor({}, {}));
using OutGridDesc_M = decltype(MakeDst1dDescriptor({}, {}));
using DsGridDesc_M = decltype(MakeDsDescriptor({}, {}));
using GridwiseReduce =
GridwiseReduction_mk_to_m_threadwise_multi_d<InDataType,
DsDataType,
OutDataType,
AccDataType,
InGridDesc_M_K,
DsGridDesc_M,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
OutElementwiseOperation,
InMemoryDataOperationEnum::Set,
BlockSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize,
DsVectorSizeSequence>;
using DsGridPointer = typename GridwiseReduce::DsGridPointer;
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, Rank> inLengths,
const std::array<index_t, Rank> inStrides,
const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths,
const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides,
const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
const InDataType* in_dev,
const std::array<const void*, NumDTensor> ds_dev,
OutDataType* out_dev,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation out_elementwise_op)
: DsLengths_{DsLengths},
DsStrides_{DsStrides},
outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
in_elementwise_op_{in_elementwise_op},
out_elementwise_op_{out_elementwise_op}
{
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(NumInvariantDim == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths_[NumInvariantDim - 1];
reduce_lowest_length = inLengths_[Rank - 1];
numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize;
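// Each block covers M_BlockTileSize invariant elements, so
// gridSize = ceil(invariant_total_length / M_BlockTileSize), and the reduce dimension is
// consumed in numBlockTileIteration chunks of K_BlockTileSize.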
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid_(i) = static_cast<const DDataType*>(ds_dev[i]);
});
ds_grid_desc_m_ = MakeDsDescriptor(DsLengths, DsStrides);
}
std::array<index_t, Rank> inLengths_;
std::array<index_t, Rank> inStrides_;
std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths_;
std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides_;
std::array<index_t, NumDstDim> outLengths_;
std::array<index_t, NumDstDim> outStrides_;
const InDataType* in_dev_;
OutDataType* out_dev_;
DsGridPointer p_ds_grid_;
InElementwiseOperation in_elementwise_op_;
OutElementwiseOperation out_elementwise_op_;
DsGridDesc_M ds_grid_desc_m_;
index_t invariant_lowest_length;
index_t reduce_lowest_length;
long_index_t invariant_total_length;
long_index_t reduce_total_length;
int numBlockTileIteration;
size_t gridSize;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k =
DeviceReduceThreadWiseMultiD::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);
const auto out_grid_desc_m =
DeviceReduceThreadWiseMultiD::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_);
float avg_time = 0;
const auto kernel = kernel_reduce_threadwise_multi_d<GridwiseReduce,
InDataType,
OutDataType,
AccDataType,
InGridDesc_M_K,
DsGridDesc_M,
OutGridDesc_M,
InElementwiseOperation,
OutElementwiseOperation,
DsGridPointer>;
avg_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
arg.ds_grid_desc_m_,
out_grid_desc_m,
arg.in_elementwise_op_,
arg.out_elementwise_op_,
arg.in_dev_,
arg.p_ds_grid_,
arg.out_dev_);
return (avg_time);
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if constexpr(InSrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
{
return (false);
}
else
{
if(pArg->inStrides_[NumInvariantDim - 1] != 1)
return (false);
if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
return (false);
};
}
else
{
if(pArg->inStrides_[Rank - 1] != 1)
return (false);
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
return (false);
};
// To improve: vectorized output stores currently require the invariant lowest length to be a
// multiple of OutDstVectorSize
if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
return (false);
std::cerr << "reduce_total_length = " << pArg->reduce_total_length
<< " KThreadSliceSize = " << KThreadSliceSize << std::endl;
// cases with big reduce_total_length should be handled by Blockwise kernel
if(pArg->reduce_total_length / KThreadSliceSize >= 32)
return (false);
return (true);
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
const std::array<index_t, Rank> inStrides,
const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths,
const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides,
const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
const void* in_dev,
const std::array<const void*, NumDTensor> ds_dev,
void* out_dev,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation out_elementwise_op) override
{
return std::make_unique<Argument>(inLengths,
inStrides,
DsLengths,
DsStrides,
outLengths,
outStrides,
reduceDims,
static_cast<const InDataType*>(in_dev),
ds_dev,
static_cast<OutDataType*>(out_dev),
in_elementwise_op,
out_elementwise_op);
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceReduceThreadWiseMultiD<" << BlockSize << ",";
str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << 1 << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
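// Example usage (illustrative sketch only; the template/tile parameters below are assumptions
// chosen for readability, not a tuned or validated configuration):
//
//   using Op = DeviceReduceThreadWiseMultiD<float,              // InDataType
//                                           ck::Tuple<float>,   // DsDataType (one D tensor)
//                                           float,              // AccDataType
//                                           float,              // OutDataType
//                                           2,                  // Rank
//                                           1,                  // NumReduceDim
//                                           ck::reduce::Add,
//                                           ck::tensor_operation::element_wise::PassThrough,
//                                           ck::tensor_operation::element_wise::PassThrough,
//                                           256, 1, 1, 0, 1, 1, // BlockSize, slice/vector sizes
//                                           ck::Sequence<1>>;   // DsVectorSizeSequence
//   Op op;
//   auto argument = op.MakeArgumentPointer(inLengths, inStrides, DsLengths, DsStrides,
//                                          outLengths, outStrides, reduceDims,
//                                          p_in, {p_d0}, p_out,
//                                          {}, {}); // in/out elementwise ops
//   if(op.IsSupportedArgument(argument.get()))
//       op.MakeInvokerPointer()->Run(argument.get(), StreamConfig{});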
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -249,6 +249,31 @@ struct MultiplyAdd ...@@ -249,6 +249,31 @@ struct MultiplyAdd
} }
}; };
struct MultiplyMultiply
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, float, float, float>(
ck::half_t& e, const float& c, const float& d0, const float& d1) const
{
const float x0_f = c * d0 * d1;
e = ck::type_convert<ck::half_t>(x0_f);
}
template <>
__host__ __device__ constexpr void operator()<ck::bhalf_t, float, float, float>(
ck::bhalf_t& e, const float& c, const float& d0, const float& d1) const
{
const float x0_f = c * d0 * d1;
e = ck::type_convert<ck::bhalf_t>(x0_f);
}
};
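// Example: with E = half_t and C/D0/D1 = float, MultiplyMultiply{} computes
//   e = type_convert<half_t>(c * d0 * d1),
// e.g. applying two per-output scale tensors to a GEMM accumulator in the CDE epilogue.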
struct MultiplyAddFastGelu struct MultiplyAddFastGelu
{ {
template <typename E, typename C, typename D0, typename D1> template <typename E, typename C, typename D0, typename D1>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/tuple_helper.hpp"
namespace ck {
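// Thin launch wrapper: each thread works on an independent slice of the invariant (M) dimension,
// so the kernel simply forwards the descriptors, element-wise ops and pointers to the
// GridwiseReduction policy's Run.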
template <typename GridwiseReduction,
typename InDataType,
typename OutDataType,
typename AccDataType,
typename InGridDesc_M_K,
typename DsGridDesc_M,
typename OutGridDesc_M,
typename InElementwiseOperation,
typename OutElementwiseOperation,
typename DsGridPointer>
__global__ void
kernel_reduce_threadwise_multi_d(const InGridDesc_M_K in_grid_desc_m_k,
const DsGridDesc_M ds_grid_desc_m,
const OutGridDesc_M out_grid_desc_m,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation out_elementwise_op,
const InDataType* const __restrict__ p_in_value_global,
const DsGridPointer p_ds_value_global,
OutDataType* const __restrict__ p_out_value_global)
{
GridwiseReduction::Run(in_grid_desc_m_k,
ds_grid_desc_m,
out_grid_desc_m,
in_elementwise_op,
out_elementwise_op,
p_in_value_global,
p_ds_value_global,
p_out_value_global);
}
template <typename InDataType,
typename DsDataType,
typename OutDataType,
typename AccDataType,
typename InGridDesc_M_K,
typename DsGridDesc_M,
typename OutGridDesc_M,
typename ReduceOperation,
typename InElementwiseOperation,
typename OutElementwiseOperation,
InMemoryDataOperationEnum OutMemoryDataOperation,
index_t BlockSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize,
typename DsVectorSize>
struct GridwiseReduction_mk_to_m_threadwise_multi_d
{
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using ThreadBufferDimAccessOrder =
typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;
using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using PassThrough = tensor_operation::element_wise::PassThrough;
static constexpr auto I0 = Number<0>{};
static constexpr index_t NumDTensor = DsDataType::Size();
// ck::Tuple<const D0DataType*, const D1DataType*, ...>
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
using DsGridPointer = decltype(MakeDsGridPointer());
__device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
const DsGridDesc_M& ds_grid_desc_m,
const OutGridDesc_M& out_grid_desc_m,
const InElementwiseOperation& in_elementwise_op,
const OutElementwiseOperation& out_elementwise_op,
const InDataType* const __restrict__ p_in_value_global,
const DsGridPointer p_ds_grid,
OutDataType* const __restrict__ p_out_value_global)
{
using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
ReduceOperation,
false>;
const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_value_global,
in_grid_desc_m_k.GetElementSpaceSize(),
ReduceOperation::template GetIdentityValue<InDataType>());
auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
in_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });
const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
auto threadwise_src_val_load =
ThreadwiseTensorSliceTransfer_v2<InDataType,
AccDataType,
InGridDesc_M_K,
decltype(thread_buffer_desc),
ThreadBufferLengths,
ThreadBufferDimAccessOrder,
InSrcVectorDim,
InSrcVectorSize,
1,
false>(
in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);
index_t reducedLength = 0;
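// Main reduction loop: load an MThreadSliceSize x KThreadSliceSize tile, apply in_elementwise_op
// to every element, accumulate into accu_value_buf with ReduceOperation, then advance the source
// window by KThreadSliceSize along K until the whole reduce length (including padding, which
// reads back the reduce identity value) is consumed.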
do
{
threadwise_src_val_load.Run(in_grid_desc_m_k,
in_global_val_buf,
thread_buffer_desc,
make_tuple(I0, I0),
in_thread_buf);
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
// do element-wise pre-reduction operation
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
in_elementwise_op(in_thread_buf(Number<offset>{}),
in_thread_buf(Number<offset>{}));
});
});
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
reducedLength += KThreadSliceSize;
} while(reducedLength < toReduceLength);
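// Epilogue: load the matching MThreadSliceSize slice of every D tensor, combine each accumulated
// value with its D values through out_elementwise_op(out, acc, d0, d1, ...), and store the result
// through a thread-wise vectorized transfer.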
constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};
auto ds_thread_buf = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(DsGridPointer{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MThreadSliceSize, true>{};
},
Number<NumDTensor>{});
auto ds_global_buf = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[I], ds_grid_desc_m[I].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto ds_global_load = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(DsGridPointer{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return ThreadwiseTensorSliceTransfer_v2<DataType,
DataType,
decltype(ds_grid_desc_m[I]),
decltype(reduced_data_desc),
Sequence<MThreadSliceSize>, // SliceLengths
Sequence<0>, // DimAccessOrder
InSrcVectorDim, // SrcVectorDim
DsVectorSize{}[I],
1, // SrcScalarStrideInVector
true>{
ds_grid_desc_m[I], make_multi_index(thread_global_1d_id * MThreadSliceSize)};
},
Number<NumDTensor>{});
static_for<0, NumDTensor, 1>{}([&](auto I) {
ds_global_load(I).Run(ds_grid_desc_m[I],
ds_global_buf[I],
reduced_data_desc,
make_tuple(I0),
ds_thread_buf(I));
});
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true> out_value_buf;
// if constexpr(NumDTensor > 0)
{
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(accu_value_buf[I]),
generate_tie(
[&](auto Id) -> const auto& { return ds_thread_buf[Id][I]; },
Number<NumDTensor>{}));
unpack2(out_elementwise_op, tie(out_value_buf(I)), c_ds_buf_refs);
});
}
auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3<OutDataType,
OutDataType,
decltype(reduced_data_desc),
OutGridDesc_M,
PassThrough,
Sequence<MThreadSliceSize>,
Sequence<0>,
0,
OutDstVectorSize,
OutMemoryDataOperation,
1,
false>(
out_grid_desc_m,
make_multi_index(thread_global_1d_id * MThreadSliceSize),
PassThrough{});
threadwise_dst_store.Run(
reduced_data_desc, make_tuple(I0), out_value_buf, out_grid_desc_m, dst_global_buf);
}
};
} // namespace ck
...@@ -42,7 +42,7 @@ __global__ void ...@@ -42,7 +42,7 @@ __global__ void
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset, karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset, karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_c_grid, karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
p_shared, p_shared,
karg); karg);
#else #else
...@@ -73,7 +73,7 @@ __global__ void ...@@ -73,7 +73,7 @@ __global__ void
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset, karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset, karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_c_grid, karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
p_shared_0, p_shared_0,
p_shared_1, p_shared_1,
karg); karg);
...@@ -531,21 +531,35 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -531,21 +531,35 @@ struct GridwiseGemm_xdl_cshuffle_v3
index_t StrideA_, index_t StrideA_,
index_t StrideB_, index_t StrideB_,
index_t StrideC_, index_t StrideC_,
index_t k_batch_) index_t k_batch_,
bool is_reduce_ = false)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_}, : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
p_a_grid{p_a_grid_}, p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_}, p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_} p_c_grid{p_c_grid_},
is_reduce(is_reduce_)
{ {
} }
__host__ __device__ inline bool IsReduceAdd() const
{
return (Problem::KBatch > 1) && is_reduce;
}
__host__ __device__ inline bool IsAtomicAdd() const
{
return (Problem::KBatch > 1) && (!is_reduce);
}
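// Split-K handling: with KBatch > 1, is_reduce == true means every K-batch writes its partial
// result to its own M x N slice of the C buffer (see c_reduce_offset) so a separate reduction can
// combine the partials; otherwise (IsAtomicAdd) the partial results are expected to be
// accumulated atomically into C.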
const ADataType* p_a_grid; const ADataType* p_a_grid;
const BDataType* p_b_grid; const BDataType* p_b_grid;
CDataType* p_c_grid; CDataType* p_c_grid;
bool is_reduce;
}; };
struct SplitKBatchOffset struct SplitKBatchOffset
{ {
__device__ SplitKBatchOffset(Argument& karg) __device__ SplitKBatchOffset(Argument& karg)
{ {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>) if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
...@@ -574,10 +588,20 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -574,10 +588,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
{ {
karg.K = karg.K - karg.KRead * (karg.KBatch - 1); karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
} }
if(karg.IsReduceAdd())
{
c_reduce_offset = blockIdx.z * karg.M * karg.N;
}
else
{
c_reduce_offset = 0;
}
} }
index_t a_k_split_offset; index_t a_k_split_offset;
index_t b_k_split_offset; index_t b_k_split_offset;
index_t c_reduce_offset;
}; };
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
...@@ -1080,7 +1104,10 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1080,7 +1104,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
} }
} }
if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value) if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
is_same<remove_cvref_t<CDataType>, float>::value))
{
if(!karg.IsReduceAdd())
{ {
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
...@@ -1092,6 +1119,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1092,6 +1119,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
return false; return false;
} }
} }
}
// check gridwise gemm pipeline // check gridwise gemm pipeline
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_ab_scale_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
#define DEBUG_LOG 0
namespace ck {
// Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
// pipelines in the same kernel function. Blockers:
// 1. Two separate declarations of the __shared__ pointer are needed to make sure data accesses
//    operate on two distinct LDS chunks.
// 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may
//    not share the same LDS buffer if __shared__ is declared inside the block GEMM pipeline.
template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid,
karg.p_b_grid,
karg.p_ds_grid,
karg.p_c_grid,
karg.p_a_scale_grid,
karg.p_b_scale_grid,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
#else
ignore = karg;
#endif // end of gfx908/gfx90a/gfx94x guard
}
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t BlockSize,
index_t ScaleBlockM,
index_t ScaleBlockN,
index_t ScaleBlockK,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ADataType,
typename LDSTypeB = BDataType>
struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
{
using AScaleType = float;
using BScaleType = float;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock =
CDEShuffleBlockTransferScalarPerVectors{}[I0];
// K1 should be Number<...>
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
using DsGridPointer = decltype(MakeDsGridPointer());
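// KPack: contiguous K elements fed to the MFMA pipeline per issue; it is at least the lcm of the
// A/B load vector lengths (AK1, BK1) and at least the selected MFMA's k_per_blk.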
static constexpr index_t KPack = math::max(
math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
{
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
}
__host__ static auto CalculateMPadded(index_t M)
{
return math::integer_least_multiple(M, MPerBlock);
}
__host__ static auto CalculateNPadded(index_t N)
{
return math::integer_least_multiple(N, NPerBlock);
}
__host__ static auto CalculateKPadded(index_t K)
{
return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
}
__host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
{
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
}
__host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
{
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
}
__host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
{
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * KPerBlock;
}
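// CalculateKRead: number of K elements each K-batch actually reads, i.e. K split over K_Batch and
// rounded up to a multiple of lcm(AK1, BK1) so both A and B loads stay vector-aligned.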
__host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
{
constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
auto K_t = K_Batch * KReadVec;
return (K + K_t - 1) / K_t * KReadVec;
}
__host__ static auto CalculateMBlock(index_t M)
{
return math::integer_divide_ceil(M, MPerBlock);
}
__host__ static auto CalculateNBlock(index_t N)
{
return math::integer_divide_ceil(N, NPerBlock);
}
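// MakeGemmMmaTileDescriptor: reshapes a (K0, MN, K1) block tile into the
// (MNXdlPerWave, MNWaves, MNPerXdl, K) view consumed by the xdlops blockwise GEMM, merging K0 and
// K1 back into a single K dimension.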
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
__host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
{
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
return transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
make_unmerge_transform(make_tuple(
Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
__device__ static auto MakeAGridDescriptor_AK0_M_AK1(
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
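// Build the B grid descriptor in (BK0, N, BK1) form, mirroring the A-side construction with
// N/K padding selected by GemmSpec.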
__device__ static auto MakeBGridDescriptor_BK0_N_BK1(
index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(N, NPad - N),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
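// Derive the per-wave MMA tile descriptors from the A/B LDS block descriptors; the wave
// counts follow from the block tile sizes (MPerBlock, NPerBlock) and the XDL shape.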
template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto
MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
}
template <typename BBlockDesc_BK0_N_BK1>
__host__ __device__ static constexpr auto
MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
{
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
}
template <typename ELayout>
__host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
}
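// Each D tensor reuses the C descriptor construction with its own layout and stride.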
__host__ __device__ static auto MakeDsGridDescriptor_M_N(
index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return MakeCGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
},
Number<NumDTensor>{});
}
template <typename DsGridDesc>
__device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
return generate_tuple(
[&](auto i) {
return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n[i], MBlock, NBlock);
},
Number<NumDTensor>{});
}
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))>;
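// Problem carries the runtime GEMM shape plus the derived quantities (padded sizes, split-K
// read length, AK0/BK0 and block counts) computed by the Calculate* helpers above.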
struct Problem
{
__host__ Problem(index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideC_,
index_t KBatch_)
: M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideDs{StrideDs_},
StrideC{StrideC_},
KBatch{KBatch_},
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},
KRead{CalculateKRead(K_, KBatch_)},
KPadded{CalculateKPadded(K_, KBatch_)},
AK0{CalculateAK0Padded(K_, KBatch_)},
BK0{CalculateBK0Padded(K_, KBatch_)},
MBlock{CalculateMBlock(M_)},
NBlock{CalculateNBlock(N_)}
{
}
__host__ void Print() const
{
std::cout << "problem {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << ", "
<< "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", "
<< "KRead:" << KRead << ", "
<< "KP:" << KPadded << ", "
<< "AK0:" << AK0 << ", "
<< "BK0:" << BK0 << ", "
<< "MBlock: " << MBlock << ", "
<< "NBlock: " << NBlock << "}" << std::endl;
}
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideC;
index_t KBatch;
index_t MPadded;
index_t NPadded;
index_t KRead;
index_t KPadded;
index_t AK0;
index_t BK0;
index_t MBlock;
index_t NBlock;
};
// Argument
struct Argument : public tensor_operation::device::BaseArgument, public Problem
{
__host__ Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideC_,
const AScaleType* p_a_scale_grid_,
const BScaleType* p_b_scale_grid_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_ds_grid{},
p_c_grid{p_c_grid_},
p_a_scale_grid{p_a_scale_grid_},
p_b_scale_grid{p_b_scale_grid_},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
c_element_op{c_element_op_}
{
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
// D pointer
p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
});
}
const ADataType* p_a_grid;
const BDataType* p_b_grid;
DsGridPointer p_ds_grid;
CDataType* p_c_grid;
const AScaleType* p_a_scale_grid;
const BScaleType* p_b_scale_grid;
const AElementwiseOperation a_element_op;
const BElementwiseOperation b_element_op;
const CElementwiseOperation c_element_op;
};
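// Split-K bookkeeping: each z-block (blockIdx.z) consumes its own K slice. The offsets below
// translate the slice index into element offsets into A and B, and the last batch takes
// whatever K remains.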
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(Argument& karg)
{
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead * karg.M;
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead * karg.N;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead;
}
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
{
karg.K = karg.KRead;
}
else
{
karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
}
}
index_t a_k_split_offset;
index_t b_k_split_offset;
};
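// LDS block descriptor for A: either a simple padded layout (when ABlockLdsExtraM is set) or
// an XOR-permuted layout intended to avoid LDS bank conflicts for the given source layout.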
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
make_tuple(AK1Number, Number<KPerBlock + ABlockLdsExtraM>{}, I1));
}
// The XOR tensor transformation requires additional, unnecessary VGPRs and can cause register
// spills in some cases.
else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1
? 1
: 32 * 4 / KPerBlock / sizeof(LDSTypeA);
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
make_tuple(AK1Number, Number<KPerBlock * MLdsLayer>{}, I1));
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc,
make_tuple(make_xor_with_modulo_transform(make_tuple(
Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number<MLdsLayer>{})),
make_pass_through_transform(Number<MPerBlock / MLdsLayer>{}),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_ak0_mldslayer_m_ak1,
make_tuple(make_pass_through_transform(AK0Number),
make_merge_transform_v3_division_mod(
make_tuple(Number<MPerBlock / MLdsLayer>{}, Number<MLdsLayer>{})),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_lds_block_desc_ak0_m_ak1;
}
else // ColumnMajor A
{
// The kfold and mpair dimensions are not always required.
// More dimensions in merge_transform make it harder for the compiler to generate immediate
// (immarg) offsets.
constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
constexpr auto M1 = MPerBlock / M0;
constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
constexpr auto KThreadRead = 64 / MPerXdl;
constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
? 1
: 128 / (AK1Number * M0 * sizeof(LDSTypeA));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=mpair<=m0
constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
? 1
: ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
? M0
: 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<K0PerThreadWrite>{},
Number<KThreadReadPerm * M1>{},
Number<kfold * M0 / mpair>{},
Number<mpair>{},
AK1Number));
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_xor_with_modulo_transform(
make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
make_pass_through_transform(Number<mpair>{}),
make_pass_through_transform(AK1Number)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));
constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<M1>{})),
make_unmerge_transform(make_tuple(Number<kfold>{}, Number<M0 / mpair>{})),
make_pass_through_transform(Number<mpair>{}),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<1>{},
Sequence<2>{},
Sequence<0, 3>{},
Sequence<4, 5>{},
Sequence<6>{},
Sequence<7>{}));
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(Number<KThreadReadPerm>{},
Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<kfold>{},
Number<K0PerThreadWrite>{})),
make_merge_transform_v3_division_mod(
make_tuple(Number<M0 / mpair>{}, Number<mpair>{}, Number<M1>{})),
make_pass_through_transform(AK1Number)),
make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return a_lds_block_desc_ak0_m_ak1;
}
}
__device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(BK1Number, Number<KPerBlock + BBlockLdsExtraN>{}, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
// NLdsLayer * K0 as logical Bank
constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1
? 1
: 32 * 4 / KPerBlock / sizeof(LDSTypeB);
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(
BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
make_tuple(BK1Number, Number<KPerBlock * NLdsLayer>{}, I1));
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc,
make_tuple(make_xor_with_modulo_transform(make_tuple(
Number<NPerBlock / NLdsLayer>{}, Number<BK0Number * NLdsLayer>{})),
make_pass_through_transform(BK1Number)),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number<NLdsLayer>{})),
make_pass_through_transform(Number<NPerBlock / NLdsLayer>{}),
make_pass_through_transform(BK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));
constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_lds_block_desc_bk0_nldslayer_n_bk1,
make_tuple(make_pass_through_transform(BK0Number),
make_merge_transform_v3_division_mod(
make_tuple(Number<NPerBlock / NLdsLayer>{}, Number<NLdsLayer>{})),
make_pass_through_transform(BK1Number)),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return b_lds_block_desc_bk0_n_bk1;
}
else // RowMajor B
{
constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
constexpr auto N1 = NPerBlock / N0;
constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
constexpr auto KThreadRead = 64 / NPerXdl;
constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128)
? 1
: 128 / (BK1Number * N0 * sizeof(LDSTypeB));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=npair<=n0
constexpr auto npair = (BK1Number * NPerXdl * sizeof(LDSTypeB) > 128)
? 1
: ((128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))) > N0
? N0
: 128 / (BK1Number * NPerXdl * sizeof(LDSTypeB)));
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<K0PerThreadWrite>{},
Number<KThreadReadPerm * N1>{},
Number<kfold * N0 / npair>{},
Number<npair>{},
BK1Number));
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_xor_with_modulo_transform(
make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
make_pass_through_transform(Number<npair>{}),
make_pass_through_transform(BK1Number)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));
constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<N1>{})),
make_unmerge_transform(make_tuple(Number<kfold>{}, Number<N0 / npair>{})),
make_pass_through_transform(Number<npair>{}),
make_pass_through_transform(BK1Number)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<1>{},
Sequence<2>{},
Sequence<0, 3>{},
Sequence<4, 5>{},
Sequence<6>{},
Sequence<7>{}));
constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(Number<KThreadReadPerm>{},
Number<KThreadWrite / kfold / KThreadReadPerm>{},
Number<kfold>{},
Number<K0PerThreadWrite>{})),
make_merge_transform_v3_division_mod(
make_tuple(Number<N0 / npair>{}, Number<npair>{}, Number<N1>{})),
make_pass_through_transform(BK1Number)),
make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return b_lds_block_desc_bk0_n_bk1;
}
}
__device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
using BlockwiseGemmPipe =
remove_cvref_t<decltype(BlockGemmABScalePipeline_Selector<
BlkGemmPipelineVer,
BlkGemmPipeSched,
BlockSize,
LDSTypeA,
LDSTypeB,
ComputeTypeA,
AccDataType,
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())),
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())),
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack>())>;
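// A/B staging buffers and the C-shuffle buffer reuse the same LDS region, so the allocation
// only needs the larger of the two footprints.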
__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle in LDS
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) +
b_block_space_size_aligned * sizeof(LDSTypeB)),
c_block_size * sizeof(CShuffleDataType));
}
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
__host__ static constexpr bool CheckValidity(const Argument& karg)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(karg.M % MPerBlock == 0))
{
#if DEBUG_LOG
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
#endif // DEBUG_LOG
return false;
}
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(karg.N % NPerBlock == 0))
{
#if DEBUG_LOG
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
#endif // DEBUG_LOG
return false;
}
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
auto K_t = karg.KBatch * KPerBlock;
if(!(karg.K % K_t == 0))
{
#if DEBUG_LOG
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
<< karg.K << " " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
else
{
constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
auto K_t = karg.KBatch * KReadVec;
auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
std::cout << "Arg K (" << karg.K
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
else
{
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
else
{
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
{
#if DEBUG_LOG
std::cout << "Arg K (" << karg.K
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
#if DEBUG_LOG
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
else
{
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
#if DEBUG_LOG
std::cout << "Arg M (" << karg.M
<< ") value is not a multiple of "
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
return false;
}
}
// check gridwise gemm pipeline
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
{
if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
{
return false;
}
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
}
__host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
{
const index_t num_loop = K / KPerBlock;
return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
}
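// Split the (M, N) C descriptor into (MBlock, MPerBlock, NBlock, NPerBlock) so each
// workgroup can index its own output tile.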
template <typename CGridDesc>
__device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
{
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
// return block_id to C matrix tile idx (m0, n0) mapping
// if arch = gfx942
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
DsGridPointer& p_ds_grid,
CDataType* p_c_grid,
const AScaleType* p_a_scale_grid,
const BScaleType* p_b_scale_grid,
void* p_shared,
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
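// Block-scale descriptors: one scale element per (ScaleBlockM x ScaleBlockK) tile of A and
// per (ScaleBlockN x ScaleBlockK) tile of B, stored with K as the fastest-varying dimension.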
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.M, ScaleBlockM),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
// HACK: this forces m/n_block_data_idx_on_grid into SGPRs
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
LDSTypeA,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
LDSTypeB,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
// Cast the raw shared-memory pointer to the LDS element types for A and B
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeB*>(p_shared) +
a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
// Blockwise GEMM pipeline
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
constexpr index_t ScaleSliceSizeM = 1;
constexpr index_t ScaleSliceSizeN = 1;
constexpr index_t ScaleSliceSizeK = 1;
constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
auto a_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<AScaleType,
AScaleType,
decltype(a_scale_grid_desc_am_ak),
decltype(a_scale_thread_desc),
Sequence<ScaleSliceSizeM, ScaleSliceSizeK>,
Sequence<0, 1>,
1,
1,
1,
false>(
a_scale_grid_desc_am_ak, make_multi_index(block_m_id * MPerBlock / ScaleBlockM, 0));
auto b_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<BScaleType,
BScaleType,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc),
Sequence<ScaleSliceSizeN, ScaleSliceSizeK>,
Sequence<0, 1>,
1,
1,
1,
false>(
b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, 1);
const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
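// Main loop: the pipeline streams A/B tiles through LDS, applies the A/B block scales to the
// partial products, and advances the scale windows once every num_k_block_per_scale K-blocks.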
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,
a_scale_grid_desc_am_ak,
a_scale_thread_desc,
a_scale_thread_copy,
a_scale_grid_buf,
a_scale_thread_slice_copy_step,
b_scale_grid_desc_bn_ak,
b_scale_thread_desc,
b_scale_thread_copy,
b_scale_grid_buf,
b_scale_thread_slice_copy_step,
num_k_block_main_loop,
num_k_block_per_scale);
// shuffle C and write out
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
// TODO: hacky, fix it!
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
using EDataType = CDataType;
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// tuple of reference to C/Ds tensor descriptors
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// tuple of starting index of C/Ds blockwise copy
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumDTensor>{}));
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
c_grid_desc_mblock_mperblock_nblock_nperblock;
using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitrary types
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
3, // index_t SrcVectorDim,
3, // index_t DstVectorDim,
CDEShuffleBlockTransferScalarPerVectors,
CShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
c_element_op};
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
// space filling curve for shuffled blockwise C/D/E
constexpr auto sfc_cde_block =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// make sure it's safe to read from LDS
block_sync_lds();
// each block copy its data from LDS to global
cde_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(c_grid_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto cde_lds_and_global_step =
sfc_cde_block.GetForwardStep(access_id);
// move on Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_lds_and_global_step);
});
// move on E
cde_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
I0,
cde_lds_and_global_step);
}
});
}
}
};
} // namespace ck