Commit a781d078 authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Merge branch 'develop' into bnorm_bwd_pr

parents fd76c787 4c4c7328
...@@ -14,39 +14,38 @@ namespace device { ...@@ -14,39 +14,38 @@ namespace device {
// Convolution Forward: // Convolution Forward:
// input : input image A[G, N, C, Hi, Wi], // input : input image A[G, N, C, Hi, Wi],
// input : weight B[G, K, C, Y, X], // input : weight B[G, K, C, Y, X],
// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
// output : output image E[G, N, K, Ho, Wo] // output : output image E[G, N, K, Ho, Wo]
// C = a_op(A) * b_op(B) // C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...) // E = cde_op(C, D0, D1, ...)
template <index_t NDimSpatial, template <index_t NDimSpatial,
typename ALayout, typename InLayout,
typename BLayout, typename WeiLayout,
typename CLayout, typename OutLayout,
typename ADataType, typename InDataType,
typename BDataType, typename WeiDataType,
typename CDataType, typename OutDataType,
typename AElementwiseOperation, typename InElementwiseOperation,
typename BElementwiseOperation, typename WeiElementwiseOperation,
typename CElementwiseOperation> typename OutElementwiseOperation>
struct DeviceGroupedConvFwd : public BaseOperator struct DeviceGroupedConvFwd : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, // input image MakeArgumentPointer(const void* p_in, // input image
const void* p_b, // weight const void* p_wei, // weight
void* p_c, // output image void* p_out, // output image
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths, const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides, const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides, const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations, const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads, const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads, const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op, const InElementwiseOperation& in_element_op,
const BElementwiseOperation& b_element_op, const WeiElementwiseOperation& wei_element_op,
const CElementwiseOperation& c_element_op) = 0; const OutElementwiseOperation& out_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
...@@ -33,6 +33,8 @@ struct DeviceNormalization : public BaseOperator ...@@ -33,6 +33,8 @@ struct DeviceNormalization : public BaseOperator
const void* p_gamma, const void* p_gamma,
const void* p_beta, const void* p_beta,
void* p_y, void* p_y,
void* p_savedMean,
void* p_savedInvVar,
AccElementwiseOperation acc_elementwise_op) = 0; AccElementwiseOperation acc_elementwise_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck { namespace ck {
......
...@@ -150,7 +150,10 @@ template <typename ADataType, ...@@ -150,7 +150,10 @@ template <typename ADataType,
ck::index_t BBlockTransferDstScalarPerVector_K1, ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN, bool BBlockLdsAddExtraN,
ck::index_t CThreadTransferSrcDstVectorDim, ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector> ck::index_t CThreadTransferDstScalarPerVector,
ck::index_t NumGemmKPrefetchStage = 1,
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
BLayout, BLayout,
CLayout, CLayout,
...@@ -323,7 +326,10 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -323,7 +326,10 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
BBlockLdsAddExtraN, BBlockLdsAddExtraN,
Sequence<2, 3, 0, 1, 7, 5, 4, 6>, Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>; CThreadTransferDstScalarPerVector,
NumGemmKPrefetchStage,
LoopSched,
PipelineVer>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
...@@ -622,6 +628,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -622,6 +628,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
{ {
auto str = std::stringstream(); auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off // clang-format off
str << "DeviceBatchedGemmXdl" str << "DeviceBatchedGemmXdl"
<< "<" << "<"
...@@ -629,7 +641,13 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -629,7 +641,13 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << K0PerBlock
<< ">"; << ">"
<< " NumGemmKPrefetchStage: "
<< NumGemmKPrefetchStage << ", "
<< "LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -67,6 +67,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -67,6 +67,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation> OutElementwiseOperation>
{ {
static constexpr ck::index_t NDimSpatial = 2;
using DeviceOp = using DeviceOp =
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
...@@ -107,18 +109,18 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -107,18 +109,18 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock;
static constexpr auto BBlockLdsN1Padding = 4; static constexpr auto BBlockLdsN1Padding = 4;
static auto static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths, std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths, std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths, std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::array<ck::index_t, NDimSpatial> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::array<ck::index_t, NDimSpatial> input_right_pads,
ck::index_t batch_k) ck::index_t batch_k)
{ {
using namespace ck; using namespace ck;
...@@ -390,13 +392,13 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -390,13 +392,13 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths, std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths, std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths, std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::array<ck::index_t, NDimSpatial> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::array<ck::index_t, NDimSpatial> input_right_pads,
ck::index_t M01, ck::index_t M01,
ck::index_t N01, ck::index_t N01,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
...@@ -473,11 +475,11 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -473,11 +475,11 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
index_t Conv_N_; index_t Conv_N_;
index_t Conv_K_; index_t Conv_K_;
index_t Conv_C_; index_t Conv_C_;
std::vector<index_t> output_spatial_lengths_; std::array<index_t, NDimSpatial> output_spatial_lengths_;
std::vector<index_t> filter_spatial_lengths_; std::array<index_t, NDimSpatial> filter_spatial_lengths_;
std::vector<index_t> conv_filter_strides_; std::array<index_t, NDimSpatial> conv_filter_strides_;
std::vector<index_t> input_left_pads_; std::array<index_t, NDimSpatial> input_left_pads_;
std::vector<index_t> input_right_pads_; std::array<index_t, NDimSpatial> input_right_pads_;
index_t k_batch_; index_t k_batch_;
}; };
...@@ -682,13 +684,13 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -682,13 +684,13 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths, std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths, std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths, std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::array<ck::index_t, NDimSpatial> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::array<ck::index_t, NDimSpatial> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
...@@ -724,13 +726,13 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ...@@ -724,13 +726,13 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths, std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths, std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths, std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::array<ck::index_t, NDimSpatial> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::array<ck::index_t, NDimSpatial> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename AccDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t K0PerBlock,
ck::index_t K1,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
typename M1N1ThreadClusterM1Xs,
typename M1N1ThreadClusterN1Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector>
struct DeviceConvNdBwdDataNwcKxcNwk_Dl
: public DeviceConvBwdData<
NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWC,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::NDHWC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::KXC,
ck::tensor_layout::convolution::KYXC,
ck::tensor_layout::convolution::KZYXC>>,
ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK,
ck::tensor_layout::convolution::NHWK,
ck::tensor_layout::convolution::NDHWK>>,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{
// Shorthand for this device operator type.
using DeviceOp = DeviceConvNdBwdDataNwcKxcNwk_Dl;
// GEMM operand naming used by this operator. The descriptor builders below
// label the output tensor as "A", the weight tensor as "B" and the input
// tensor as "C", and the aliases follow that convention.
using ADataType = OutDataType;
using BDataType = WeiDataType;
using CDataType = InDataType;
// TODO make A/B datatype different
// NOTE(review): presumably the single element type shared by the A and B
// operands until the TODO above is addressed; it is bound to InDataType here,
// not to OutDataType/WeiDataType -- confirm against the users of this alias.
using ABDataType = InDataType;
// Number<> constants for the indices 0..7, available to the members of this struct.
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
// Builds the three grid descriptors used by the 1D (NDim == 1) backward-data
// GEMM; this overload is only enabled when NDim == 1.
//
// Parameters:
//   N, K, C                - batch size, output-channel count and input-channel count
//                            used to shape the packed tensor descriptors.
//   input_spatial_lengths  - only element [0] (Wi) is read.
//   filter_spatial_lengths - only element [0] (X) is read.
//   output_spatial_lengths - only element [0] (Wo) is read.
//   conv_filter_strides    - only element [0] (ConvStrideW) is read.
//   conv_filter_dilations  - only element [0] (ConvDilationW) is read.
//   input_left_pads        - only element [0] (InLeftPadW) is read.
//   input_right_pads       - only element [0] (InRightPadW) is read.
//   tildes                 - only element [0] (i_xtilde) is read; it selects which
//                            XTilde index is frozen in the weight and input descriptors.
//
// Returns a tuple of, in order:
//   out_gemmk0_gemmm_gemmk1_grid_desc  (A: output tensor),
//   wei_gemmk0_gemmn_gemmk1_grid_desc  (B: weight tensor),
//   in_gemmm_gemmn_grid_desc           (C: input tensor).
//
// All vectors are indexed without a size check, so each must hold at least one element.
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> tildes)
{
using namespace ck;
// Index along the XTilde dimension handled by this set of descriptors.
index_t i_xtilde = tildes[0];
const index_t Wi = input_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[0];
const index_t X = filter_spatial_lengths[0];
const index_t InLeftPadW = input_left_pads[0];
const index_t InRightPadW = input_right_pads[0];
const index_t ConvStrideW = conv_filter_strides[0];
const index_t ConvDilationW = conv_filter_dilations[0];
// K is split into (K0, K1) by the unmerge transforms below.
// NOTE(review): integer division -- assumes K is a multiple of K1; confirm the caller checks this.
const auto K0 = K / K1;
// Packed descriptor of the input tensor with lengths (N, Wi, C); shared by both branches.
const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
// Specialized path: X, the pads, the dilation and i_xtilde are not used here.
// A: output tensor
// Viewed as a packed (N * Wo, K) matrix; K is unmerged into (K0, K1) and the
// result is ordered as (K0, N * Wo, K1).
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)),
make_tuple(make_pass_through_transform(N * Wo),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
// B: weight tensor
// Viewed as a packed (K, C) matrix; K is unmerged into (K0, K1) and the
// result is ordered as (K0, C, K1).
const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: input tensor
// The Wi dimension is embedded as (1, Wo) with coefficients (1, ConvStrideW).
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// The length-1 dimension is frozen at index 0 and (N, Wo) is merged into one
// dimension, giving a 2D (N * Wo, C) view.
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
else
{
// General path: packed descriptors of the output (N, Wo, K) and weight (K, X, C) tensors.
const auto out_n_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Wo, K));
const auto wei_k_x_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, X, C));
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
// XTilde = stride / gcd(stride, dilation); XDot = ceil(X / XTilde).
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto XDot = math::integer_divide_ceil(X, XTilde);
// WTilde = Wo + ceil(ConvDilationW * (X - 1) / ConvStrideW).
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on the WTilde range that contributes to the non-padding area of the input tensor
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM
// XDotSlice depends on i_xtilde, so the merged (XDotSlice, K0) length varies with it.
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// A: output tensor
// Pad transform on Wo with zero left and right padding.
// NOTE(review): presumably kept so that out-of-range coordinates produced by the
// embed below are treated as padding -- confirm against make_pad_transform semantics.
const auto out_n_wop_k_grid_desc = transform_tensor_descriptor(
out_n_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// Wo is embedded as (XDot, WTilde) with coefficients
// (-ConvDilationW / GcdStrideDilationW, 1); note the negative XDot coefficient.
const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
out_n_wop_k_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// XDot is sliced to [0, XDotSlice), WTilde is sliced starting at
// IWTildeSliceBegin with WTildeSlice as the third argument, and K is unmerged
// into (K0, K1).
const auto out_n_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor(
out_n_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}));
// Merge (XDotSlice, K0) into dimension 0 and (N, WTildeSlice) into dimension 1;
// K1 passes through as dimension 2.
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)),
make_merge_transform(make_tuple(N, WTildeSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// B weight tensor
// X is embedded as (XDot, XTilde) with coefficients (ConvStrideW / GcdStrideDilationW, 1).
const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_k_x_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// K is unmerged into (K0, K1), XDot is sliced to [0, XDotSlice), and the
// XTilde dimension is frozen at i_xtilde (it disappears from the result).
const auto wei_k0_k1_xdotslice_c_grid_desc = transform_tensor_descriptor(
wei_k_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_xtilde),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<>{}, Sequence<3>{}));
// Merge (XDotSlice, K0) into dimension 0, matching the merge order used for
// the output tensor above; C becomes dimension 1 and K1 dimension 2.
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)),
make_pass_through_transform(C),
make_pass_through_transform(K1)),
make_tuple(Sequence<2, 0>{}, Sequence<3>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// C: input tensor
// Wi is padded with InLeftPadW on the left and InRightPadW on the right.
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
in_n_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// The padded width is embedded as (XTilde, WTilde) with coefficients
// (ConvDilationW, ConvStrideW).
const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
in_n_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
// XTilde is frozen at i_xtilde and WTilde gets the same slice as the output tensor.
const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1>{}, Sequence<2>{}));
// Merge (N, WTildeSlice) into dimension 0; C passes through as dimension 1.
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_wtildeslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, WTildeSlice)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
} // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> tildes)
{
using namespace ck;
index_t i_ytilde = tildes[0];
index_t i_xtilde = tildes[1];
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[1];
const index_t Y = filter_spatial_lengths[0];
const index_t X = filter_spatial_lengths[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const auto K0 = K / K1;
const auto out_n_ho_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
const auto wei_k_y_x_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
// B: weight tensor
const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: input tensor
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
else
{
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_n_ho_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
transform_tensor_descriptor(
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6>{}));
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// B weight tensor
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_k_y_x_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilde),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0, 1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<>{},
Sequence<>{},
Sequence<4>{}));
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_pass_through_transform(C),
make_pass_through_transform(K1)),
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// C: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_ytilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<>{},
Sequence<1>{},
Sequence<>{},
Sequence<2>{},
Sequence<3>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_htildeslice_wtildeslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
} // function end
// Builds the three GEMM-view tensor descriptors of a 3D (NDim == 3) backward-data convolution
// for the single sub-GEMM selected by `tildes`.
//
// Parameters:
//   N, K, C                 - leading/trailing lengths of the packed tensors created below:
//                             output [N, Do, Ho, Wo, K], weight [K, Z, Y, X, C],
//                             input [N, Di, Hi, Wi, C].
//   input_spatial_lengths   - {Di, Hi, Wi}
//   filter_spatial_lengths  - {Z, Y, X}
//   output_spatial_lengths  - {Do, Ho, Wo}
//   conv_filter_strides     - {ConvStrideD, ConvStrideH, ConvStrideW}
//   conv_filter_dilations   - {ConvDilationD, ConvDilationH, ConvDilationW}
//   input_left_pads         - {InLeftPadD, InLeftPadH, InLeftPadW}
//   input_right_pads        - {InRightPadD, InRightPadH, InRightPadW}
//   tildes                  - {i_ztilde, i_ytilde, i_xtilde}; only read in the general (else) branch.
//
// Returns a tuple (A, B, C):
//   A = view of the output tensor as [GemmK0, GemmM, GemmK1]
//   B = view of the weight tensor as [GemmK0, GemmN, GemmK1]
//   C = view of the input tensor  as [GemmM, GemmN]
//
// Every vector is indexed at [0..2] without a size check, so each must hold at least 3 entries.
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> tildes)
{
using namespace ck;
// Unpack the per-dimension parameters; index 0 is depth (D/Z), 1 is height (H/Y), 2 is width (W/X).
const index_t i_ztilde = tildes[0];
const index_t i_ytilde = tildes[1];
const index_t i_xtilde = tildes[2];
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Do = output_spatial_lengths[0];
const index_t Ho = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2];
const index_t Z = filter_spatial_lengths[0];
const index_t Y = filter_spatial_lengths[1];
const index_t X = filter_spatial_lengths[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
// K1 is declared outside this function; K is split as K0 x K1 by the unmerge transforms below.
// NOTE(review): this is a plain division, so K is presumably expected to be a multiple of K1 -
// confirm against IsSupportedArgument().
const auto K0 = K / K1;
// Packed (contiguous) descriptors of the three convolution tensors.
const auto out_n_do_ho_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K));
const auto wei_k_z_y_x_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C));
const auto in_n_di_hi_wi_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
// Specialized path: the tilde indices, filter lengths, dilations and pads are not used here.
if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
// Viewed as a packed [N*Do*Ho*Wo, K] matrix; K is unmerged into (K0, K1) and the
// dimensions are emitted in the order [K0, M, K1].
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Do * Ho * Wo),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
// B: weight tensor
// Viewed as a packed [K, C] matrix; K is unmerged into (K0, K1), giving [K0, N = C, K1].
const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: input tensor
// Each input spatial dimension is embedded as (1, out_length) with coefficients (1, stride).
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
// The three length-1 dimensions (1, 3, 5) are frozen at index 0 and dropped;
// (N, Do, Ho, Wo) are merged into GemmM and C passes through as GemmN.
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0),
make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<0, 2, 4, 6>{},
Sequence<7>{}),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
else
{
// General path: one sub-GEMM per (i_ztilde, i_ytilde, i_xtilde) combination.
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
// *Tilde = stride / gcd(stride, dilation); this is the exclusive upper bound of the
// matching tilde index (see the callers' loops over i_ztilde / i_ytilde / i_xtilde).
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
// *Dot = ceil(filter_length / *Tilde)
const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
// D/H/WTilde = out_length + ceil(dilation * (filter_length - 1) / stride)
const auto DTilde =
Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on the DTilde, HTilde and WTilde ranges that contribute to the non-padding
// area of the input tensor
const auto IDTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
// Exclusive slice ends, clamped to the full tilde lengths.
const auto IDTildeSliceEnd = math::min(
DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM
// *DotSlice = ceil((filter_length - i_tilde) / *Tilde), so it depends on the tilde index.
const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// A: output tensor
// Step 1: pad transforms with zero left/right padding on Do, Ho, Wo.
const auto out_n_dop_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_n_do_ho_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Do, I0, I0),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
// Step 2: embed each spatial dimension as (*Dot, *Tilde) with coefficients
// (-dilation / gcd, 1).
const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
transform_tensor_descriptor(
out_n_dop_hop_wop_k_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(ZDot, DTilde),
make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
make_embed_transform(make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
// Step 3: slice the *Dot dimensions to [0, *DotSlice) and the *Tilde dimensions to
// [I*TildeSliceBegin, +*TildeSlice); unmerge K into (K0, K1).
const auto
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
transform_tensor_descriptor(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7, 8>{}));
// Step 4: GemmK0 = merge(ZDotSlice, YDotSlice, XDotSlice, K0),
//         GemmM  = merge(N, DTildeSlice, HTildeSlice, WTildeSlice), GemmK1 = K1.
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
make_tuple(
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// B weight tensor
// Step 1: embed each filter dimension as (*Dot, *Tilde) with coefficients (stride / gcd, 1).
const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
transform_tensor_descriptor(
wei_k_z_y_x_c_grid_desc,
make_tuple(
make_pass_through_transform(K),
make_embed_transform(make_tuple(ZDot, ZTilde),
make_tuple(ConvStrideD / GcdStrideDilationD, I1)),
make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
// Step 2: unmerge K into (K0, K1), slice the *Dot dimensions (lower dims 1, 3, 5) and
// freeze the *Tilde dimensions (lower dims 2, 4, 6) at the requested tilde indices;
// frozen dimensions disappear from the result (empty upper sequences).
const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ztilde),
make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilde),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<2>{},
Sequence<4>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0, 1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<>{},
Sequence<>{},
Sequence<>{},
Sequence<5>{}));
// Step 3: GemmK0 = merge(ZDotSlice, YDotSlice, XDotSlice, K0), GemmN = C, GemmK1 = K1.
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
make_pass_through_transform(C),
make_pass_through_transform(K1)),
make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// C: input tensor
// Step 1: apply the left/right input padding to Di, Hi, Wi.
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_di_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
// Step 2: embed each padded spatial dimension as (filter tilde, spatial tilde) with
// coefficients (dilation, stride).
const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
transform_tensor_descriptor(
in_n_dip_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(ZTilde, DTilde),
make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
// Step 3: freeze the filter-tilde dimensions at the requested indices and slice the
// spatial-tilde dimensions to the same ranges used for the output tensor above.
const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
transform_tensor_descriptor(
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_ztilde),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
make_freeze_transform(i_ytilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<>{},
Sequence<1>{},
Sequence<>{},
Sequence<2>{},
Sequence<>{},
Sequence<3>{},
Sequence<4>{}));
// Step 4: GemmM = merge(N, DTildeSlice, HTildeSlice, WTildeSlice), GemmN = C.
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
make_tuple(
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
} // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto GetABCGridDesc()
{
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1,
1,
1,
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{1, 1, 1},
{0, 0, 0});
}
// Tuple type returned by GetABCGridDesc for this operator's NDimSpatial; the call only appears
// inside decltype, so it is used purely to name the descriptor types.
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
// Element I0: A descriptor [K0, M, K1] (built from the output tensor).
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
// Element I1: B descriptor [K0, N, K1] (built from the weight tensor).
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
// Element I2: C descriptor [M, N] (built from the input tensor).
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
// GridwiseGemm
// Instantiation of GridwiseGemmDl_km_kn_mn_v1r3 used by Invoker::Run() and the validity checks.
// It is parameterized with the A/B/C descriptor types deduced above, the block/thread tiling
// parameters and the A/B block-transfer and C thread-transfer settings that are declared
// outside this section, and InMemoryDataOperationEnum::Set as the memory data operation.
using GridwiseGemm =
GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
ADataType,
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
// Descriptor and tile-map types that GridwiseGemm derives from the A/B/C descriptors.
// These are the types stored per sub-GEMM in Argument and passed to the kernel in Invoker::Run().
using AGridDesc_K0_M0_M1_K1 =
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
using BGridDesc_K0_N0_N1_K1 =
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
// Type of the default block-to-C-tile mapping created from a C [M, N] descriptor.
using DefaultBlock2CTileMap =
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
// Argument
struct Argument : public BaseArgument
{
Argument(InDataType* p_in_grid,
const WeiDataType* p_wei_grid,
const OutDataType* p_out_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
: p_a_grid_{p_out_grid},
p_b_grid_{p_wei_grid},
p_c_grid_{p_in_grid},
a_element_op_{out_element_op},
b_element_op_{wei_element_op},
c_element_op_{in_element_op},
Conv_N_{N},
Conv_K_{K},
Conv_C_{C},
input_spatial_lengths_{input_spatial_lengths},
filter_spatial_lengths_{filter_spatial_lengths},
output_spatial_lengths_{output_spatial_lengths},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
CreateABCDesc<NDimSpatial>();
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
void CreateABCDesc()
{
const index_t ConvStrideW = conv_filter_strides_[0];
const index_t ConvDilationW = conv_filter_dilations_[0];
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const index_t X = filter_spatial_lengths_[0];
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
{
// check slice is valid
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
if(XDotSlice <= 0)
{
continue;
}
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
{i_xtilde});
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2]))
{
a_grid_desc_k0_m0_m1_k1_container_.push_back(
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(descs[I0]));
b_grid_desc_k0_n0_n1_k1_container_.push_back(
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(descs[I1]));
c_grid_desc_m0_m10_m11_n0_n10_n11_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2]));
}
}
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
void CreateABCDesc()
{
const index_t ConvStrideH = conv_filter_strides_[0];
const index_t ConvStrideW = conv_filter_strides_[1];
const index_t ConvDilationH = conv_filter_dilations_[0];
const index_t ConvDilationW = conv_filter_dilations_[1];
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const index_t Y = filter_spatial_lengths_[0];
const index_t X = filter_spatial_lengths_[1];
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
{
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
{
// check slice is valid
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
if(YDotSlice * XDotSlice <= 0)
{
continue;
}
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
{i_ytilde, i_xtilde});
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2]))
{
a_grid_desc_k0_m0_m1_k1_container_.push_back(
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(descs[I0]));
b_grid_desc_k0_n0_n1_k1_container_.push_back(
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(descs[I1]));
c_grid_desc_m0_m10_m11_n0_n10_n11_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2]));
}
}
}
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
void CreateABCDesc()
{
const index_t ConvStrideD = conv_filter_strides_[0];
const index_t ConvStrideH = conv_filter_strides_[1];
const index_t ConvStrideW = conv_filter_strides_[2];
const index_t ConvDilationD = conv_filter_dilations_[0];
const index_t ConvDilationH = conv_filter_dilations_[1];
const index_t ConvDilationW = conv_filter_dilations_[2];
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const index_t Z = filter_spatial_lengths_[0];
const index_t Y = filter_spatial_lengths_[1];
const index_t X = filter_spatial_lengths_[2];
for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
{
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
{
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
{
// check slice is valid
const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
if(ZDotSlice * YDotSlice * XDotSlice <= 0)
{
continue;
}
const auto descs =
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
Conv_N_,
Conv_K_,
Conv_C_,
input_spatial_lengths_,
filter_spatial_lengths_,
output_spatial_lengths_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_,
{i_ztilde, i_ytilde, i_xtilde});
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2]))
{
a_grid_desc_k0_m0_m1_k1_container_.push_back(
GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(descs[I0]));
b_grid_desc_k0_n0_n1_k1_container_.push_back(
GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(descs[I1]));
c_grid_desc_m0_m10_m11_n0_n10_n11_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2]));
}
}
}
}
}
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_;
std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_;
std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_;
std::vector<AGridDesc_K0_M0_M1_K1> a_grid_desc_k0_m0_m1_k1_container_;
std::vector<BGridDesc_K0_N0_N1_K1> b_grid_desc_k0_n0_n1_k1_container_;
std::vector<CGridDesc_M0_M10_M11_N0_N10_N11> c_grid_desc_m0_m10_m11_n0_n10_n11_container_;
std::vector<DefaultBlock2CTileMap> block_2_ctile_map_container_;
// element-wise op
OutElementwiseOperation a_element_op_;
WeiElementwiseOperation b_element_op_;
InElementwiseOperation c_element_op_;
// for checking IsSupportedArgument()
index_t Conv_N_;
index_t Conv_K_;
index_t Conv_C_;
std::vector<ck::index_t> input_spatial_lengths_;
std::vector<ck::index_t> filter_spatial_lengths_;
std::vector<ck::index_t> output_spatial_lengths_;
std::vector<ck::index_t> conv_filter_strides_;
std::vector<ck::index_t> conv_filter_dilations_;
std::vector<ck::index_t> input_left_pads_;
std::vector<ck::index_t> input_right_pads_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
{
{
std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}"
<< std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_container_{"
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}"
<< std::endl;
std::cout << "arg.c_grid_desc_m_n_container_{ "
<< arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
<< std::endl;
std::cout << "arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_( "
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I0)
<< ", "
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I1)
<< ", "
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I2)
<< ", "
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I3)
<< ", "
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I4)
<< ", "
<< arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I5)
<< " ) " << std::endl;
}
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
}
const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize(
arg.c_grid_desc_m_n_container_[i]);
auto launch_kernel = [&](auto has_main_k_block_loop,
auto has_double_tail_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
constexpr bool has_double_loop = has_double_tail_k_block_loop;
const auto kernel = kernel_gemm_dl_v1r3<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M0_M1_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N0_N1_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DeviceOp::DefaultBlock2CTileMap>,
has_main_loop,
has_double_loop>;
ave_time +=
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_container_[i],
arg.b_grid_desc_k0_n0_n1_k1_container_[i],
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i],
arg.block_2_ctile_map_container_[i]);
};
const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_container_[i].GetLength(I0);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop =
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
launch_kernel(integral_constant<bool, true>{}, integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
launch_kernel(integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
launch_kernel(integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
return ave_time;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
// Compile-time check of the template parameter combination.
// Currently a placeholder that accepts every combination.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool accepts_all_parameters = true;
    return accepts_all_parameters;
}
static bool IsSupportedArgument(const Argument& arg)
{
// check device
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030"))
{
return false;
}
if constexpr(ConvBackwardDataSpecialization ==
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 pad = 0 conv
for(int i = 0; i < NDimSpatial; i++)
{
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
{
return false;
}
}
}
// matrix A
{
auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{};
if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1)
{
return false;
}
if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0)
{
return false;
}
const index_t K = arg.Conv_K_;
if(K % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0)
{
return false;
}
}
// matrix B
{
auto srcLoadLenghts = BBlockTransferThreadSliceLengths_K0_N0_N1_K1{};
auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{};
if(srcVectorLengths[I0] != 1 || srcVectorLengths[I3] != 1)
{
return false;
}
if(srcLoadLenghts[I1] % srcVectorLengths[I1] != 0 ||
srcLoadLenghts[I2] % srcVectorLengths[I2] != 0)
{
return false;
}
const index_t C = arg.Conv_K_;
if(C % (srcVectorLengths[I1] * srcVectorLengths[I2]) != 0)
{
return false;
}
}
// vector store C matrix into global memory
if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0))
{
std::cout << "Not surpport,because: arg.Conv_C_ % CThreadTransferDstScalarPerVector = "
<< arg.Conv_C_ % CThreadTransferDstScalarPerVector << std::endl;
return false;
}
// Gridwise GEMM size
for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]))
{
return false;
}
}
return true;
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
// Builds an Argument by value from typed grid pointers and the convolution
// problem description. Every parameter is forwarded unchanged, in declaration
// order, to Argument's brace-initialized constructor.
// Note that p_in_grid is the only non-const pointer; presumably the input tensor
// is the one written by backward-data (confirm against the Argument definition).
static auto MakeArgument(InDataType* p_in_grid,
const WeiDataType* p_wei_grid,
const OutDataType* p_out_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
{
// Returned as a prvalue, so no copy/move of Argument is required here.
return Argument{p_in_grid,
p_wei_grid,
p_out_grid,
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op};
}
// Creates a value-initialized Invoker for this operator.
static auto MakeInvoker()
{
    return Invoker{};
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(void* p_in_grid,
const void* p_wei_grid,
const void* p_out_grid,
ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) override
{
return std::make_unique<Argument>(static_cast<InDataType*>(p_in_grid),
static_cast<const WeiDataType*>(p_wei_grid),
static_cast<const OutDataType*>(p_out_grid),
N,
K,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// Returns a human-readable identifier of this instance in the form
// "DeviceConvNdBwdDataNwcKxcNwk_Dl<BlockSize, MPerBlock, NPerBlock, K0PerBlock>",
// followed by " Filter1x1Stride1Pad0" when that specialization is selected.
std::string GetTypeString() const override
{
    std::stringstream str;

    // clang-format off
    str << "DeviceConvNdBwdDataNwcKxcNwk_Dl";
    str << "<";
    str << BlockSize << ", ";
    str << MPerBlock << ", ";
    str << NPerBlock << ", ";
    str << K0PerBlock;
    str << ">";
    // clang-format on

    constexpr bool is_filter1x1_stride1_pad0 =
        (ConvBackwardDataSpecialization ==
         ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0);

    if constexpr(is_filter1x1_stride1_pad0)
    {
        str << " Filter1x1Stride1Pad0";
    }

    return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
// X = Elementwise(input1, input2, input3, ...)
// Y = Normalization(X, beta, gamma)
namespace ck {
// Kernel entry point for the fused "elementwise + layernorm" operation.
//
// The kernel itself holds no logic: it declares the dynamically sized shared
// memory buffer used for X and forwards every argument, in order, to
// GridwiseElementwiseReduction::Run.
//
// Template parameters:
//   GridwiseElementwiseReduction  gridwise implementation providing the static Run()
//   InDataTypePointerTuple        tuple of input pointer types
//   XDataType / GammaDataType / BetaDataType / YDataType / AccDataType
//                                 element types of X, gamma, beta, Y and the accumulator
//   XElementwiseOperation         operation applied to the inputs
//   YElementwiseOperation         operation applied to the normalization output
//   InGrid2dDescTuple             tuple of input descriptors
//   GridDesc_M_K                  descriptor type shared by X, gamma, beta and Y
template <typename GridwiseElementwiseReduction,
          typename InDataTypePointerTuple,
          typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename AccDataType,
          typename XElementwiseOperation,
          typename YElementwiseOperation,
          typename InGrid2dDescTuple,
          typename GridDesc_M_K>
__global__ void kernel_elementwise_layernorm(
    const InGrid2dDescTuple in_grid_2d_desc_tuple,          // descriptors of the inputs
    const GridDesc_M_K x_grid_desc_m_k,                     // descriptor of X
    const GridDesc_M_K gamma_grid_desc_m_k,                 // descriptor of gamma
    const GridDesc_M_K beta_grid_desc_m_k,                  // descriptor of beta
    const GridDesc_M_K y_grid_desc_m_k,                     // descriptor of Y
    index_t num_k_block_tile_iteration,                     // K tile iterations per block
    AccDataType epsilon,                                    // epsilon passed to the gridwise op
    const InDataTypePointerTuple p_in_global_tuple,         // pointers to the input matrices
    const GammaDataType* const __restrict__ p_gamma_global, // pointer to gamma
    const BetaDataType* const __restrict__ p_beta_global,   // pointer to beta
    YDataType* const __restrict__ p_y_global,               // pointer to Y
    const XElementwiseOperation x_elementwise_op,           // operation on the inputs
    const YElementwiseOperation y_elementwise_op)           // operation on the normalized output
{
    // Shared-memory staging buffer for X; its size is supplied at launch time.
    extern __shared__ XDataType p_x_lds[];

    GridwiseElementwiseReduction::Run(in_grid_2d_desc_tuple,
                                      x_grid_desc_m_k,
                                      gamma_grid_desc_m_k,
                                      beta_grid_desc_m_k,
                                      y_grid_desc_m_k,
                                      num_k_block_tile_iteration,
                                      epsilon,
                                      p_in_global_tuple,
                                      p_x_lds,
                                      p_gamma_global,
                                      p_beta_global,
                                      p_y_global,
                                      x_elementwise_op,
                                      y_elementwise_op);
}
} // namespace ck
namespace ck {
namespace tensor_operation {
namespace device {
// Y = LayerNorm(A + B, Beta, Gamma)
template <typename InDataTypeTuple, // Datatype of inputs
typename GammaDataType, // Datatype of gamma
typename BetaDataType, // Datatype of beta
typename AccDataType, //
typename YDataType, //
typename XElementwiseOperation, //
typename YElementwiseOperation, //
index_t Rank, //
index_t NumReduceDim, //
index_t BlockSize, //
index_t MThreadClusterSize, // Num of threads in a block on M direction
index_t KThreadClusterSize, // Num of threads in a block on K direction
index_t MThreadSliceSize, // Each thread calculate rows
index_t KThreadSliceSize, // Each thread calculate columns
index_t XYSrcVectorDim, // Dimension to do reduce
index_t XSrcVectorSize, // Size to fetch source x
index_t GammaSrcVectorDim, // Dimension for gamma to do reduce
index_t GammaSrcVectorSize, // Size to fetch source gamma
index_t BetaSrcVectorDim, // Dimension for beta to do reduce
index_t BetaSrcVectorSize, // Size to fetch source beta
index_t YDstVectorSize> // Size to write destination Y
struct DeviceElementwiseNormalizationImpl
: public DeviceElementwiseNormalization<InDataTypeTuple,
GammaDataType,
BetaDataType,
AccDataType,
YDataType,
XElementwiseOperation,
YElementwiseOperation,
Rank,
NumReduceDim>
{
// Number of input tensors, taken from the size of the input data type tuple.
static constexpr int NumInput = InDataTypeTuple::Size();
// The intermediate X uses the same element type as the output Y.
using XDataType = YDataType;
// Each thread's K slice must be a whole number of gamma vectors.
static_assert(
(KThreadSliceSize % GammaSrcVectorSize == 0),
"Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
// Each thread's K slice must be a whole number of beta vectors.
static_assert(
(KThreadSliceSize % BetaSrcVectorSize == 0),
"Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
static constexpr index_t M_BlockTileSize =
MThreadClusterSize * MThreadSliceSize; // num of rows calculated in a block
static constexpr index_t K_BlockTileSize =
KThreadClusterSize * KThreadSliceSize; // num of columns calculated in a block
static auto GenerateInDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(nullptr);
},
Number<NumInput>{});
};
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<index_t>& inStrides,
int blkGroupSize,
int numBlockTileIteration)
{
constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t numSrcDim = Rank;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDim)
{
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
make_tuple(make_unmerge_transform(make_tuple(
1, one_dim_inDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
}
else
{
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
// Produces a tuple of TupleSize placeholder descriptors. In the visible code it
// is only used through decltype, to name the descriptor tuple type below.
template <index_t TupleSize>
static auto GenerateSrcGrid2dDescTuple(Number<TupleSize>)
{
    const auto make_placeholder_desc = [](auto) {
        return MakeSrc2dDescriptor({1}, {1}, 1, 1);
    };

    return generate_tuple(make_placeholder_desc, Number<TupleSize>{});
}

// Tuple of (M, K) descriptors, one per input.
using InGrid2dDescTuple = decltype(GenerateSrcGrid2dDescTuple(Number<NumInput>{}));
// Type of a single padded (M, K) descriptor.
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
// Gridwise implementation instantiated with the trailing flag set to false.
// The Invoker selects it when Argument::sweep_once_ is false.
// NOTE(review): XYSrcVectorDim is passed twice (once after KThreadSliceSize and
// once before YDstVectorSize); presumably X and Y share the vector dimension.
using GridwiseReduceLayernormGeneric =
GridwiseElementwiseLayernormWelfordVariance_mk_to_mk<InDataTypePointerTuple,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
AccDataType,
XElementwiseOperation,
YElementwiseOperation,
InGrid2dDescTuple,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XYSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
XYSrcVectorDim,
YDstVectorSize,
false>;
// Same instantiation with the trailing flag set to true.
// The Invoker selects it when Argument::sweep_once_ is true.
using GridwiseReduceLayernormSweepOnce =
GridwiseElementwiseLayernormWelfordVariance_mk_to_mk<InDataTypePointerTuple,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
AccDataType,
XElementwiseOperation,
YElementwiseOperation,
InGrid2dDescTuple,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
XYSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
BetaSrcVectorDim,
BetaSrcVectorSize,
XYSrcVectorDim,
YDstVectorSize,
true>;
struct Argument : public BaseArgument
{
Argument(const std::vector<index_t> lengths,
const std::array<std::vector<index_t>, NumInput> inStridesArray,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims,
XElementwiseOperation x_elementwise_op,
YElementwiseOperation y_elementwise_op,
AccDataType epsilon,
const std::array<const void*, NumInput> in_dev_buffers,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
YDataType* p_y)
: epsilon_(epsilon),
p_gamma_(p_gamma),
p_beta_(p_beta),
p_y_(p_y),
x_elementwise_op_(x_elementwise_op),
y_elementwise_op_(y_elementwise_op)
{
Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
for(int i = 0; i < NumInput; i++)
{
inStridesArray_[i] =
shuffle_tensor_dimensions<Rank, NumReduceDim>(inStridesArray[i], reduceDims);
}
yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
in_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(in_dev_buffers[I.value]);
},
Number<NumInput>{});
long_index_t invariant_total_length;
long_index_t reduce_total_length;
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(Lengths_);
blkGroupSize_ = 1;
numBlockTileIteration_ = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize * blkGroupSize_;
in_grid_2d_desc_tuple_ = generate_tuple(
[&](auto I) {
return MakeSrc2dDescriptor(
Lengths_, inStridesArray_[I.value], blkGroupSize_, numBlockTileIteration_);
},
Number<NumInput>{});
x_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, xStrides_, blkGroupSize_, numBlockTileIteration_);
gamma_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, gammaStrides_, blkGroupSize_, numBlockTileIteration_);
beta_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, betaStrides_, blkGroupSize_, numBlockTileIteration_);
y_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, yStrides_, blkGroupSize_, numBlockTileIteration_);
sweep_once_ =
x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
if(!sweep_once_) // if not sweep once, compute memory size for matrix X in lds for
// store Intermediate results
{
int block_TileSize = M_BlockTileSize * reduce_total_length;
x_lds_size_ = block_TileSize * sizeof(XDataType);
}
else
x_lds_size_ = 0;
}
AccDataType epsilon_;
InDataTypePointerTuple in_dev_buffers_;
const GammaDataType* p_gamma_;
const BetaDataType* p_beta_;
YDataType* p_y_;
std::vector<index_t> Lengths_;
std::array<std::vector<index_t>, NumInput> inStridesArray_;
std::vector<index_t> xStrides_;
std::vector<index_t> gammaStrides_;
std::vector<index_t> betaStrides_;
std::vector<index_t> yStrides_;
XElementwiseOperation x_elementwise_op_;
YElementwiseOperation y_elementwise_op_;
int blkGroupSize_;
int numBlockTileIteration_;
size_t gridSize_;
InGrid2dDescTuple in_grid_2d_desc_tuple_;
GridDesc_M_K x_grid_desc_m_k_;
GridDesc_M_K gamma_grid_desc_m_k_;
GridDesc_M_K beta_grid_desc_m_k_;
GridDesc_M_K y_grid_desc_m_k_;
bool sweep_once_;
int x_lds_size_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto kernel_main =
arg.sweep_once_ ? kernel_elementwise_layernorm<GridwiseReduceLayernormSweepOnce,
InDataTypePointerTuple,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
AccDataType,
XElementwiseOperation,
YElementwiseOperation,
InGrid2dDescTuple,
GridDesc_M_K>
: kernel_elementwise_layernorm<GridwiseReduceLayernormGeneric,
InDataTypePointerTuple,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
AccDataType,
XElementwiseOperation,
YElementwiseOperation,
InGrid2dDescTuple,
GridDesc_M_K>;
float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config,
kernel_main,
dim3(arg.gridSize_),
dim3(BlockSize),
arg.x_lds_size_,
arg.in_grid_2d_desc_tuple_,
arg.x_grid_desc_m_k_,
arg.gamma_grid_desc_m_k_,
arg.beta_grid_desc_m_k_,
arg.y_grid_desc_m_k_,
arg.numBlockTileIteration_,
arg.epsilon_,
arg.in_dev_buffers_,
arg.p_gamma_,
arg.p_beta_,
arg.p_y_,
arg.x_elementwise_op_,
arg.y_elementwise_op_);
return (avg_time);
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
constexpr index_t NumInvariantDim = Rank - NumReduceDim;
if constexpr(XYSrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
{
return false;
}
else
{
for(int i = 0; i < NumInput; i++)
{
if(p_arg_->inStridesArray_[i][NumInvariantDim - 1] != 1)
return false;
}
if(p_arg_->inStridesArray_[0][NumInvariantDim - 1] != 1 &&
p_arg_->inStridesArray_[1][NumInvariantDim - 1] != 1)
return false;
if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0)
return false;
};
}
else
{
for(int i = 0; i < NumInput; i++)
{
if(p_arg_->inStridesArray_[i][Rank - 1] != 1)
return false;
}
if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0)
return false;
};
if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0)
{
return false;
}
auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
bool ret = true;
if(!isLastDimensionCoalesced)
ret = scalarPerVector == 1;
else
ret = KThreadSliceSize % scalarPerVector == 0;
return ret;
};
if(!IsScalarPerVectorValid(p_arg_->gammaStrides_.back() == 1, GammaSrcVectorSize))
return false;
if(!IsScalarPerVectorValid(p_arg_->betaStrides_.back() == 1, BetaSrcVectorSize))
return false;
// if fastest dim is not reduced
if constexpr(XYSrcVectorDim == 0) //
{
if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1)
return (false);
if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
return (false);
}
else // if fastest dim is reduced
{
if(p_arg_->gammaStrides_[Rank - 1] != 1)
return (false);
if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
return (false);
}
// if fastest dim is not reduced
if constexpr(XYSrcVectorDim == 0)
{
if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
return (false);
if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0)
return (false);
}
else // if fastest dim is reduced
{
if(p_arg_->betaStrides_[Rank - 1] != 1)
return (false);
if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0)
return (false);
}
return true;
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::vector<index_t> lengths,
const std::array<std::vector<index_t>, NumInput> inStridesArray,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> betaStrides,
const std::vector<index_t> yStrides,
const std::vector<index_t> reduceDims,
AccDataType epsilon,
const std::array<const void*, NumInput> in_dev_buffers,
const void* p_gamma,
const void* p_beta,
void* p_y,
XElementwiseOperation x_elementwise_op,
YElementwiseOperation y_elementwise_op) override
{
return std::make_unique<Argument>(lengths,
inStridesArray,
gammaStrides,
betaStrides,
yStrides,
reduceDims,
x_elementwise_op,
y_elementwise_op,
epsilon,
in_dev_buffers,
static_cast<const GammaDataType*>(p_gamma),
static_cast<const BetaDataType*>(p_beta),
static_cast<YDataType*>(p_y));
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
// Returns a human-readable identifier of this instance that encodes the block
// size, the M / K thread cluster and slice sizes, the X/Y vector dimension and
// the X, gamma, beta and Y vector sizes.
std::string GetTypeString() const override
{
    std::stringstream str;

    // clang-format off
    str << "DeviceElementwiseNormalizationImpl<" << BlockSize << ","
        << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","
        << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","
        << "XYSrcVectorDim_" << XYSrcVectorDim << ","
        << "VectorSize_X" << XSrcVectorSize
        << "_Gamma" << GammaSrcVectorSize
        << "_Beta" << BetaSrcVectorSize
        << "_Y" << YDstVectorSize << ">";
    // clang-format on

    return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -141,7 +141,8 @@ template <typename ALayout, ...@@ -141,7 +141,8 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock, index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()> LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
BLayout, BLayout,
DsLayout, DsLayout,
...@@ -282,7 +283,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -282,7 +283,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched,
PipelineVer>;
// desc for blockwise copy // desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
...@@ -664,6 +666,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -664,6 +666,12 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
{ {
auto str = std::stringstream(); auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off // clang-format off
str << "DeviceGemmMultipleD_Xdl_CShuffle" str << "DeviceGemmMultipleD_Xdl_CShuffle"
<< "<" << "<"
...@@ -674,7 +682,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -674,7 +682,11 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
<< AK1 << ", " << AK1 << ", "
<< BK1 << ", " << BK1 << ", "
<< getGemmSpecializationString(GemmSpec) << getGemmSpecializationString(GemmSpec)
<< ">"; << ">"
<< " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -56,7 +56,9 @@ template <typename ADataType, ...@@ -56,7 +56,9 @@ template <typename ADataType,
bool BBlockLdsAddExtraN, bool BBlockLdsAddExtraN,
ck::index_t CThreadTransferSrcDstVectorDim, ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector, ck::index_t CThreadTransferDstScalarPerVector,
ck::index_t NumPrefetch = 1> ck::index_t NumPrefetch = 1,
ck::LoopScheduler LoopSched = make_default_loop_scheduler(),
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGemmXdl : public DeviceGemm<ALayout, struct DeviceGemmXdl : public DeviceGemm<ALayout,
BLayout, BLayout,
CLayout, CLayout,
...@@ -230,7 +232,9 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -230,7 +232,9 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
NumPrefetch>; NumPrefetch,
LoopSched,
PipelineVer>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -523,6 +527,12 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -523,6 +527,12 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
{ {
auto str = std::stringstream(); auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off // clang-format off
str << "DeviceGemmXdl" str << "DeviceGemmXdl"
<< "<" << "<"
...@@ -535,7 +545,13 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -535,7 +545,13 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
<< NPerXDL << ", " << NPerXDL << ", "
<< MXdlPerWave << ", " << MXdlPerWave << ", "
<< NXdlPerWave << NXdlPerWave
<< ">"; << ">"
<< " NumPrefetch: "
<< NumPrefetch << ", "
<< "LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -64,7 +64,8 @@ template <typename ALayout, ...@@ -64,7 +64,8 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()> LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
BLayout, BLayout,
CLayout, CLayout,
...@@ -393,7 +394,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -393,7 +394,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
CShuffleNXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched,
PipelineVer>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -656,6 +658,12 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -656,6 +658,12 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
{ {
auto str = std::stringstream(); auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off // clang-format off
str << "DeviceGemm_Xdl_CShuffle" str << "DeviceGemm_Xdl_CShuffle"
<< "<" << "<"
...@@ -665,7 +673,11 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -665,7 +673,11 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
<< KPerBlock << ", " << KPerBlock << ", "
<< AK1 << ", " << AK1 << ", "
<< BK1 << BK1
<< ">"; << ">"
<< " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];;
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -4,13 +4,14 @@ ...@@ -4,13 +4,14 @@
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <numeric>
#include <sstream> #include <sstream>
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
...@@ -20,6 +21,108 @@ namespace ck { ...@@ -20,6 +21,108 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace {
// Maps a batch (group) index to the offset of that batch inside the A, B and C
// tensors. Each offset is g_idx times the corresponding batch stride, computed
// in long_index_t.
struct ComputePtrOffsetOfStridedBatch
{
    // Offset of batch g_idx within A.
    __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
    {
        const long_index_t stride_a = BatchStrideA_;
        return g_idx * stride_a;
    }

    // Offset of batch g_idx within B.
    __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
    {
        const long_index_t stride_b = BatchStrideB_;
        return g_idx * stride_b;
    }

    // Offset of batch g_idx within C.
    __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
    {
        const long_index_t stride_c = BatchStrideC_;
        return g_idx * stride_c;
    }

    index_t BatchStrideA_; // stride between consecutive batches of A
    index_t BatchStrideB_; // stride between consecutive batches of B
    index_t BatchStrideC_; // stride between consecutive batches of C
};
} // namespace
// Batched GEMM kernel for the backward-weight convolution.
//
// Each block derives its batch index from its 1-D block id, offsets the A, B and
// C base pointers by the per-batch offsets reported by compute_ptr_offset_of_batch
// and then runs GridwiseGemm::Run<HasMainKBlockLoop> on that batch.
//
// The body is only compiled for the host pass and for gfx908 / gfx90a device
// targets; for any other device target the arguments are merely marked as used.
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename AGridDesc_B_K0_M_K1,
typename BGridDesc_B_K0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2CTileMap,
typename ComputePtrOffsetOfBatch,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_xdlops_bwd_weight(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op,
const index_t batch_count,
const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
// Blocks are split evenly between batches: consecutive runs of
// num_blocks_per_batch block ids belong to the same batch.
// NOTE(review): assumes the grid size is a multiple of batch_count - confirm at the launch site.
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
// NOTE(review): the offsets are long_index_t but pass through
// __builtin_amdgcn_readfirstlane; if that builtin only handles 32-bit operands,
// offsets beyond the 32-bit range would be truncated - verify against the compiler docs.
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
// Statically sized shared-memory workspace, sized in FloatAB elements from the
// byte count requested by GridwiseGemm.
__shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
p_b_grid + b_batch_offset,
p_c_grid + c_batch_offset,
p_shared,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
#else
// Unsupported device target: reference every parameter so that none is
// reported as unused.
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_b_k0_m_k1_grid_desc;
ignore = b_b_k0_n_k1_grid_desc;
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = batch_count;
ignore = block_2_ctile_map;
ignore = compute_ptr_offset_of_batch;
// The offset getters are still called here; their results are discarded.
compute_ptr_offset_of_batch.GetAPtrOffset(0);
compute_ptr_offset_of_batch.GetBPtrOffset(0);
compute_ptr_offset_of_batch.GetCPtrOffset(0);
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] // out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
typename InDataType, typename InDataType,
...@@ -57,21 +160,21 @@ template <ck::index_t NDimSpatial, ...@@ -57,21 +160,21 @@ template <ck::index_t NDimSpatial,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl> index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
: public DeviceConvBwdWeight< : public DeviceGroupedConvBwdWeight<
NDimSpatial, NDimSpatial,
ck::tuple_element_t<NDimSpatial - 1, ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWC, ck::Tuple<ck::tensor_layout::convolution::GNWC,
ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::GNHWC,
ck::tensor_layout::convolution::NDHWC>>, ck::tensor_layout::convolution::GNDHWC>>,
ck::tuple_element_t<NDimSpatial - 1, ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::KXC, ck::Tuple<ck::tensor_layout::convolution::GKXC,
ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::GKYXC,
ck::tensor_layout::convolution::KZYXC>>, ck::tensor_layout::convolution::GKZYXC>>,
ck::tuple_element_t<NDimSpatial - 1, ck::tuple_element_t<NDimSpatial - 1,
ck::Tuple<ck::tensor_layout::convolution::NWK, ck::Tuple<ck::tensor_layout::convolution::GNWK,
ck::tensor_layout::convolution::NHWK, ck::tensor_layout::convolution::GNHWK,
ck::tensor_layout::convolution::NDHWK>>, ck::tensor_layout::convolution::GNDHWK>>,
InDataType, InDataType,
WeiDataType, WeiDataType,
OutDataType, OutDataType,
...@@ -79,7 +182,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle ...@@ -79,7 +182,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation> OutElementwiseOperation>
{ {
using DeviceOp = DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle; using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle;
using ADataType = OutDataType; using ADataType = OutDataType;
using BDataType = InDataType; using BDataType = InDataType;
...@@ -117,18 +220,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle ...@@ -117,18 +220,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
static constexpr auto BBlockLdsN1Padding = 4; static constexpr auto BBlockLdsN1Padding = 4;
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths, std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths, std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths, std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::array<ck::index_t, NDimSpatial> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::array<ck::index_t, NDimSpatial> input_right_pads,
ck::index_t batch_k) ck::index_t batch_k)
{ {
using namespace ck; using namespace ck;
...@@ -269,18 +372,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle ...@@ -269,18 +372,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths, std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths, std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths, std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::array<ck::index_t, NDimSpatial> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::array<ck::index_t, NDimSpatial> input_right_pads,
ck::index_t batch_k) ck::index_t batch_k)
{ {
using namespace ck; using namespace ck;
...@@ -436,18 +539,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle ...@@ -436,18 +539,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths, std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths, std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths, std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::array<ck::index_t, NDimSpatial> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::array<ck::index_t, NDimSpatial> input_right_pads,
ck::index_t batch_k) ck::index_t batch_k)
{ {
using namespace ck; using namespace ck;
...@@ -664,8 +767,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle ...@@ -664,8 +767,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
} }
template <index_t Dim> template <index_t Dim>
static auto MakeDescriptor_M0(const std::vector<index_t>& shape, static auto MakeDescriptor_M0(const std::array<index_t, Dim>& shape,
const std::vector<index_t>& stride, const std::array<index_t, Dim>& stride,
index_t gridSize, index_t gridSize,
index_t blockSize) index_t blockSize)
{ {
...@@ -759,16 +862,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle ...@@ -759,16 +862,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
Argument(const InDataType* p_in_grid, Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid, WeiDataType* p_wei_grid,
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
ck::index_t G,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths, std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths, std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths, std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::array<ck::index_t, NDimSpatial> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::array<ck::index_t, NDimSpatial> input_right_pads,
ck::index_t M01, ck::index_t M01,
ck::index_t N01, ck::index_t N01,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
...@@ -783,11 +887,13 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle ...@@ -783,11 +887,13 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
c_grid_desc_m_n_{}, c_grid_desc_m_n_{},
c_grid_desc_mblock_mperblock_nblock_nperblock_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_ctile_map_{}, block_2_ctile_map_{},
compute_ptr_offset_of_batch_{},
M01_{M01}, M01_{M01},
N01_{N01}, N01_{N01},
a_element_op_{out_element_op}, a_element_op_{out_element_op},
b_element_op_{in_element_op}, b_element_op_{in_element_op},
c_element_op_{wei_element_op}, c_element_op_{wei_element_op},
Conv_G_{G},
Conv_N_{N}, Conv_N_{N},
Conv_K_{K}, Conv_K_{K},
Conv_C_{C}, Conv_C_{C},
...@@ -819,6 +925,26 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle ...@@ -819,6 +925,26 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
block_2_ctile_map_ = block_2_ctile_map_ =
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ =
N * K *
std::accumulate(begin(output_spatial_lengths),
end(output_spatial_lengths),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideB_ =
N * C *
std::accumulate(begin(input_spatial_lengths),
end(input_spatial_lengths),
index_t{1},
std::multiplies<>{});
compute_ptr_offset_of_batch_.BatchStrideC_ =
K * C *
std::accumulate(begin(filter_spatial_lengths),
end(filter_spatial_lengths),
index_t{1},
std::multiplies<>{});
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
b_grid_desc_kbatch_k0_n_k1_, b_grid_desc_kbatch_k0_n_k1_,
c_grid_desc_m_n_, c_grid_desc_m_n_,
...@@ -836,21 +962,29 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle ...@@ -836,21 +962,29 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
Block2CTileMap block_2_ctile_map_; Block2CTileMap block_2_ctile_map_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
InElementwiseOperation a_element_op_; InElementwiseOperation a_element_op_;
OutElementwiseOperation b_element_op_; OutElementwiseOperation b_element_op_;
WeiElementwiseOperation c_element_op_; WeiElementwiseOperation c_element_op_;
// for checking IsSupportedArgument() // for checking IsSupportedArgument()
index_t Conv_G_;
index_t Conv_N_; index_t Conv_N_;
index_t Conv_K_; index_t Conv_K_;
index_t Conv_C_; index_t Conv_C_;
std::vector<index_t> output_spatial_lengths_; std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
std::vector<index_t> filter_spatial_lengths_; std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
std::vector<index_t> conv_filter_strides_; std::array<ck::index_t, NDimSpatial> conv_filter_strides_;
std::vector<index_t> input_left_pads_; std::array<ck::index_t, NDimSpatial> input_left_pads_;
std::vector<index_t> input_right_pads_; std::array<ck::index_t, NDimSpatial> input_right_pads_;
index_t k_batch_; index_t k_batch_;
}; };
...@@ -873,14 +1007,12 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle ...@@ -873,14 +1007,12 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
ShowInfo(arg);
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_, arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
...@@ -891,7 +1023,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle ...@@ -891,7 +1023,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
} }
const index_t grid_size = const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_;
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
...@@ -900,17 +1032,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle ...@@ -900,17 +1032,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
auto launch_kernel = [&](auto has_main_k_block_loop) { auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value; constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_gemm_xdlops_bwd_weight< const auto kernel = kernel_batched_gemm_xdlops_bwd_weight<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
OutElementwiseOperation, OutElementwiseOperation,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<DeviceOp::Block2CTileMap>, remove_reference_t<DeviceOp::Block2CTileMap>,
ComputePtrOffsetOfStridedBatch,
has_main_loop>; has_main_loop>;
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
...@@ -921,13 +1054,15 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle ...@@ -921,13 +1054,15 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_, arg.c_element_op_,
arg.block_2_ctile_map_); arg.Conv_G_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_,
arg.compute_ptr_offset_of_batch_);
}; };
if(has_main_k0_block_loop) if(has_main_k0_block_loop)
...@@ -998,16 +1133,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle ...@@ -998,16 +1133,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
static auto MakeArgument(const InDataType* p_in_grid, static auto MakeArgument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid, WeiDataType* p_wei_grid,
const OutDataType* p_out_grid, const OutDataType* p_out_grid,
ck::index_t G,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths, std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths, std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths, std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::array<ck::index_t, NDimSpatial> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::array<ck::index_t, NDimSpatial> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
...@@ -1016,6 +1152,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle ...@@ -1016,6 +1152,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
return Argument{p_in_grid, return Argument{p_in_grid,
p_wei_grid, p_wei_grid,
p_out_grid, p_out_grid,
G,
N, N,
K, K,
C, C,
...@@ -1040,16 +1177,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle ...@@ -1040,16 +1177,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
MakeArgumentPointer(const void* p_in_grid, MakeArgumentPointer(const void* p_in_grid,
void* p_wei_grid, void* p_wei_grid,
const void* p_out_grid, const void* p_out_grid,
ck::index_t G,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t C, ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths, std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths, std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths, std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::array<ck::index_t, NDimSpatial> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::array<ck::index_t, NDimSpatial> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::array<ck::index_t, NDimSpatial> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
...@@ -1058,6 +1196,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle ...@@ -1058,6 +1196,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid), return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
static_cast<WeiDataType*>(p_wei_grid), static_cast<WeiDataType*>(p_wei_grid),
static_cast<const OutDataType*>(p_out_grid), static_cast<const OutDataType*>(p_out_grid),
G,
N, N,
K, K,
C, C,
...@@ -1086,7 +1225,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle ...@@ -1086,7 +1225,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle" str << "DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp" #include "ck/host_utility/io.hpp"
#include "ck/library/utility/numeric.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -410,10 +411,9 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -410,10 +411,9 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
{ {
const index_t N = r_g_n_wos_lengths[1]; const index_t N = r_g_n_wos_lengths[1];
const index_t NHoWo = N * std::accumulate(r_g_n_wos_lengths.begin() + 2, const index_t NHoWo =
r_g_n_wos_lengths.begin() + 2 + NDimSpatial, N * ck::accumulate_n<index_t>(
index_t{1}, r_g_n_wos_lengths.begin() + 2, NDimSpatial, 1, std::multiplies<>());
std::multiplies<index_t>());
const auto r_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(NHoWo)); const auto r_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(NHoWo));
...@@ -435,10 +435,9 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -435,10 +435,9 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
const index_t WoStride = r_g_n_wos_strides[NDimSpatial + 2]; const index_t WoStride = r_g_n_wos_strides[NDimSpatial + 2];
const index_t NHoWo = N * std::accumulate(r_g_n_wos_lengths.begin() + 2, const index_t NHoWo =
r_g_n_wos_lengths.begin() + 2 + NDimSpatial, N * ck::accumulate_n<index_t>(
index_t{1}, r_g_n_wos_lengths.begin() + 2, NDimSpatial, 1, std::multiplies<>());
std::multiplies<index_t>());
const auto r_grid_desc_mraw = const auto r_grid_desc_mraw =
make_naive_tensor_descriptor(make_tuple(NHoWo), make_tuple(WoStride)); make_naive_tensor_descriptor(make_tuple(NHoWo), make_tuple(WoStride));
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include "ck/tensor_operation/gpu/device/device_normalization.hpp" #include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp" #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -24,17 +24,17 @@ template <typename GridwiseReduction, ...@@ -24,17 +24,17 @@ template <typename GridwiseReduction,
typename AccDataType, typename AccDataType,
typename AccElementwiseOperation, typename AccElementwiseOperation,
typename GridDesc_M_K> typename GridDesc_M_K>
__global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k, __global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_M_K gamma_grid_desc_m_k, const GridDesc_M_K gamma_grid_desc_m_k,
const GridDesc_M_K beta_grid_desc_m_k, const GridDesc_M_K beta_grid_desc_m_k,
const GridDesc_M_K y_grid_desc_m_k, const GridDesc_M_K y_grid_desc_m_k,
index_t num_k_block_tile_iteration, index_t num_k_block_tile_iteration,
AccDataType epsilon, AccDataType epsilon,
const XDataType* const __restrict__ p_x_global, const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global, const GammaDataType* const __restrict__ p_gamma_global,
const BetaDataType* const __restrict__ p_beta_global, const BetaDataType* const __restrict__ p_beta_global,
YDataType* const __restrict__ p_y_global, YDataType* const __restrict__ p_y_global,
const AccElementwiseOperation acc_elementwise_op) const AccElementwiseOperation acc_elementwise_op)
{ {
GridwiseReduction::Run(x_grid_desc_m_k, GridwiseReduction::Run(x_grid_desc_m_k,
gamma_grid_desc_m_k, gamma_grid_desc_m_k,
...@@ -54,7 +54,7 @@ namespace ck { ...@@ -54,7 +54,7 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
// Y = LayerNorm(X, Beta, Gamma) // Y = Normalization(X, Beta, Gamma)
template <typename XDataType, template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
...@@ -168,49 +168,49 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -168,49 +168,49 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)); using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
using GridwiseReduceLayernormGeneric = using GridwiseReduceLayernormGeneric =
GridwiseLayernormWelfordVariance_mk_to_mk<XDataType, GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
AccElementwiseOperation, AccElementwiseOperation,
GridDesc_M_K, GridDesc_M_K,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
MThreadSliceSize, MThreadSliceSize,
KThreadSliceSize, KThreadSliceSize,
XYSrcVectorDim, XYSrcVectorDim,
XSrcVectorSize, XSrcVectorSize,
GammaSrcVectorDim, GammaSrcVectorDim,
GammaSrcVectorSize, GammaSrcVectorSize,
BetaSrcVectorDim, BetaSrcVectorDim,
BetaSrcVectorSize, BetaSrcVectorSize,
XYSrcVectorDim, XYSrcVectorDim,
YDstVectorSize, YDstVectorSize,
false>; false>;
using GridwiseReduceLayernormSweepOnce = using GridwiseNormalizationSweepOnce =
GridwiseLayernormWelfordVariance_mk_to_mk<XDataType, GridwiseNormalizationWelfordVariance_mk_to_mk<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
AccElementwiseOperation, AccElementwiseOperation,
GridDesc_M_K, GridDesc_M_K,
BlockSize, BlockSize,
MThreadClusterSize, MThreadClusterSize,
KThreadClusterSize, KThreadClusterSize,
MThreadSliceSize, MThreadSliceSize,
KThreadSliceSize, KThreadSliceSize,
XYSrcVectorDim, XYSrcVectorDim,
XSrcVectorSize, XSrcVectorSize,
GammaSrcVectorDim, GammaSrcVectorDim,
GammaSrcVectorSize, GammaSrcVectorSize,
BetaSrcVectorDim, BetaSrcVectorDim,
BetaSrcVectorSize, BetaSrcVectorSize,
XYSrcVectorDim, XYSrcVectorDim,
YDstVectorSize, YDstVectorSize,
true>; true>;
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -295,22 +295,22 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -295,22 +295,22 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
const auto kernel_main = arg.isSweeponce_ const auto kernel_main = arg.isSweeponce_
? kernel_layernorm<GridwiseReduceLayernormSweepOnce, ? kernel_normalization<GridwiseNormalizationSweepOnce,
XDataType, XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
AccElementwiseOperation, AccElementwiseOperation,
GridDesc_M_K> GridDesc_M_K>
: kernel_layernorm<GridwiseReduceLayernormGeneric, : kernel_normalization<GridwiseReduceLayernormGeneric,
XDataType, XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
YDataType, YDataType,
AccDataType, AccDataType,
AccElementwiseOperation, AccElementwiseOperation,
GridDesc_M_K>; GridDesc_M_K>;
float avg_time = 0; float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config, avg_time += launch_and_time_kernel(stream_config,
...@@ -426,8 +426,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -426,8 +426,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const void* p_gamma, const void* p_gamma,
const void* p_beta, const void* p_beta,
void* p_y, void* p_y,
void* p_saveMean,
void* p_saveInvVar,
AccElementwiseOperation acc_elementwise_op) override AccElementwiseOperation acc_elementwise_op) override
{ {
// TODO
// Optional cache of the intermediate results (mean and InvVariance) during the
// forward pass could speedup in the backward
ignore = p_saveMean;
ignore = p_saveInvVar;
return std::make_unique<Argument>(lengths, return std::make_unique<Argument>(lengths,
xStrides, xStrides,
gammaStrides, gammaStrides,
......
...@@ -226,6 +226,30 @@ struct DeviceReduceMultiBlock ...@@ -226,6 +226,30 @@ struct DeviceReduceMultiBlock
in_elementwise_op_{in_elementwise_op}, in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op} acc_elementwise_op_{acc_elementwise_op}
{ {
if(Rank != inLengths.size() || Rank != inStrides.size() ||
NumReduceDim != reduceDims.size())
{
throw std::runtime_error(
"One of inLengths/inStrides/reduceDims has invalid size!"
"\nExpected size inLengths: " +
std::to_string(Rank) + ", inStrides: " + std::to_string(Rank) +
", reduceDims: " + std::to_string(NumReduceDim) +
"\nBut have inLengths: " + std::to_string(inLengths.size()) +
", inStrides: " + std::to_string(inStrides.size()) +
", reduceDims: " + std::to_string(reduceDims.size()));
}
for(std::size_t i = 0; i < reduceDims.size(); ++i)
{
if(reduceDims[i] < 0 || reduceDims[i] >= Rank)
{
throw std::runtime_error("Provided reduce dimension exceed input tensor Rank!"
"\nHave reduceDims[" +
std::to_string(i) +
"]: " + std::to_string(reduceDims[i]));
}
}
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims); inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims); inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
......
...@@ -40,8 +40,9 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -40,8 +40,9 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
AccElementwiseOp, AccElementwiseOp,
Rank> Rank>
{ {
static constexpr index_t kRank = Rank; static constexpr index_t kRank = Rank;
static constexpr index_t kNumReduceDim = NumReduceDim; static constexpr index_t kNumReduceDim = NumReduceDim;
static constexpr index_t kNumInvariantDim = Rank - NumReduceDim;
virtual index_t GetRank() const override { return kRank; } virtual index_t GetRank() const override { return kRank; }
...@@ -168,6 +169,30 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -168,6 +169,30 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
in_elementwise_op_{in_elementwise_op}, in_elementwise_op_{in_elementwise_op},
acc_elementwise_op_{acc_elementwise_op} acc_elementwise_op_{acc_elementwise_op}
{ {
if(Rank != inLengths.size() || Rank != inStrides.size() ||
NumReduceDim != reduceDims.size())
{
throw std::runtime_error(
"One of inLengths/inStrides/reduceDims has invalid size!"
"\nExpected size inLengths: " +
std::to_string(Rank) + ", inStrides: " + std::to_string(Rank) +
", reduceDims: " + std::to_string(NumReduceDim) +
"\nBut have inLengths: " + std::to_string(inLengths.size()) +
", inStrides: " + std::to_string(inStrides.size()) +
", reduceDims: " + std::to_string(reduceDims.size()));
}
for(std::size_t i = 0; i < reduceDims.size(); ++i)
{
if(reduceDims[i] < 0 || reduceDims[i] >= Rank)
{
throw std::runtime_error("Provided reduce dimension exceed input tensor Rank!"
"\nHave reduceDims[" +
std::to_string(i) +
"]: " + std::to_string(reduceDims[i]));
}
}
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims); inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims); inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
...@@ -257,40 +282,78 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -257,40 +282,78 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
}; };
}; };
bool IsSupportedArgument(const BaseArgument* p_arg) override static bool IsSupportedArgument(const Argument& arg)
{ {
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
if constexpr(InSrcVectorDim == 0) if constexpr(InSrcVectorDim == 0)
{ {
if constexpr(NumInvariantDim == 0) if constexpr(kNumInvariantDim == 0)
{ {
return false; return false;
} }
else else
{ {
if(p_arg_->inStrides_[NumInvariantDim - 1] != 1) if(arg.inStrides_[kNumInvariantDim - 1] != 1 && InSrcVectorSize != 1)
{
return false; return false;
}
if(p_arg_->invariant_lowest_length_ % InSrcVectorSize != 0) if(arg.invariant_lowest_length_ % InSrcVectorSize != 0)
{
return false; return false;
}; }
}
} }
else else
{ {
if(p_arg_->inStrides_[Rank - 1] != 1) if(arg.inStrides_[Rank - 1] != 1 && InSrcVectorSize != 1)
{
return false; return false;
}
if(p_arg_->inLengths_[Rank - 1] % InSrcVectorSize != 0) if(arg.inLengths_[Rank - 1] % InSrcVectorSize != 0)
{
return false; return false;
}; }
}
// To improve
if(kNumInvariantDim > 0 && arg.invariant_lowest_length_ % OutDstVectorSize != 0)
{
return false;
}
if(p_arg_->invariant_lowest_length_ % OutDstVectorSize != 0) if(arg.inLengths_[Rank - 1] % OutDstVectorSize != 0)
{
return false; return false;
}
return true; return true;
}; };
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const std::vector<index_t> inLengths,
const std::vector<index_t> inStrides,
const std::vector<int> reduceDims,
const AccDataType alpha,
const AccDataType beta,
const InDataType* in_dev,
OutDataType* out_dev,
InElementwiseOp in_elementwise_op,
AccElementwiseOp acc_elementwise_op)
{
return Argument{inLengths,
inStrides,
reduceDims,
alpha,
beta,
in_dev,
out_dev,
in_elementwise_op,
acc_elementwise_op};
};
// //
// @brief Makes a pointer to Argument class. // @brief Makes a pointer to Argument class.
// //
...@@ -330,6 +393,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -330,6 +393,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
acc_elementwise_op); acc_elementwise_op);
}; };
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{ {
return std::make_unique<Invoker>(); return std::make_unique<Invoker>();
...@@ -340,10 +405,13 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType, ...@@ -340,10 +405,13 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceReduceSoftmax<" << BlockSize << ","; str << "DeviceReduceSoftmax<"
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; << Rank << "," << NumReduceDim << "," << BlockSize << ","
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","
<< "InSrcVectorDim_" << InSrcVectorDim
<< "_InSrcVectorSize_" << InSrcVectorSize
<< "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "ck/utility/math_v2.hpp" #include "ck/utility/math_v2.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/quantization_operation.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
......
#pragma once
#include "ck/utility/data_type.hpp"
namespace ck {
namespace tensor_operation {
namespace element_wise {
// Element-wise functor computing  y = clamp(multiplier * Activation(x), -128, 127).
// Intended for activations that are piecewise linear functions (relu, leaky relu, ...etc);
// the activation is applied to the converted input before the multiplier.
template <typename Activation>
struct Activation_Mul_Clamp
{
    // multiplier   : factor applied to the activated value before clamping
    // activationOp : functor invoked as activationOp(out, in) on float values
    Activation_Mul_Clamp(float multiplier, Activation activationOp)
        : multiplier_(multiplier), activationOp_(activationOp)
    {
    }

    // int8 output: the clamped float result is narrowed with ck::type_convert<int8_t>.
    __host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const
    {
        y = ck::type_convert<int8_t>(ActivateScaleClamp(x));
    }

    // float output: the clamped result is stored as-is; a later stage may
    // type_convert it to int8.
    __host__ __device__ constexpr void operator()(float& y, const int32_t& x) const
    {
        y = ActivateScaleClamp(x);
    }

    float multiplier_;
    Activation activationOp_;

    private:
    // Shared by both overloads: int32 -> float, in-place activation, scaling,
    // then clamping to [-128, 127].
    __host__ __device__ constexpr float ActivateScaleClamp(const int32_t& x) const
    {
        float activated = ck::type_convert<float>(x);
        activationOp_(activated, activated);
        return math::clamp(multiplier_ * activated, -128.f, 127.f);
    }
};
// Element-wise functor computing  y = clamp(multiplier * Activation(x1 + x2), -128, 127).
// Intended for activations that are piecewise linear functions (relu, leaky relu, ...etc);
// the activation is applied to the converted sum before the multiplier.
template <typename Activation>
struct Add_Activation_Mul_Clamp
{
    // multiplier   : factor applied to the activated sum before clamping
    // activationOp : functor invoked as activationOp(out, in) on float values
    Add_Activation_Mul_Clamp(float multiplier, Activation activationOp)
        : multiplier_(multiplier), activationOp_(activationOp)
    {
    }

    // y  : int8 result
    // x1 : first int32 addend
    // x2 : second int32 addend
    __host__ __device__ constexpr void
    operator()(int8_t& y, const int32_t& x1, const int32_t& x2) const
    {
        // The addition happens in int32, before the conversion to float.
        const int32_t sum = x1 + x2;

        float acc = ck::type_convert<float>(sum);
        activationOp_(acc, acc);

        const float clamped = math::clamp(multiplier_ * acc, -128.f, 127.f);
        y                   = ck::type_convert<int8_t>(clamped);
    }

    float multiplier_;
    Activation activationOp_;
};
// Element-wise functor computing
//   y = clamp(multiplier2 * Activation(multiplier1 * (x1 + x2)), -128, 127).
// Intended for activations that are non piecewise linear functions (TanH, Sigmoid ...etc):
// the sum is scaled by multiplier1 before the activation and by multiplier2 after it.
template <typename Activation>
struct Add_Mul_Activation_Mul_Clamp
{
    // multiplier1  : factor applied to the sum before the activation
    // multiplier2  : factor applied to the activated value before clamping
    // activationOp : functor invoked as activationOp(out, in) on float values
    Add_Mul_Activation_Mul_Clamp(float multiplier1, float multiplier2, Activation activationOp)
        : multiplier1_(multiplier1), multiplier2_(multiplier2), activationOp_(activationOp)
    {
    }

    // y  : int8 result
    // x1 : first int32 addend
    // x2 : second int32 addend
    __host__ __device__ constexpr void
    operator()(int8_t& y, const int32_t& x1, const int32_t& x2) const
    {
        // The addition happens in int32, before the conversion to float.
        const int32_t sum = x1 + x2;

        float acc = multiplier1_ * ck::type_convert<float>(sum);
        activationOp_(acc, acc);

        const float clamped = math::clamp(multiplier2_ * acc, -128.f, 127.f);
        y                   = ck::type_convert<int8_t>(clamped);
    }

    float multiplier1_;
    float multiplier2_;
    Activation activationOp_;
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/math_v2.hpp" #include "ck/utility/math_v2.hpp"
namespace ck { namespace ck {
......
...@@ -364,14 +364,16 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01 ...@@ -364,14 +364,16 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
index_t M01 = 1, index_t M01 = 1,
index_t N01 = 1, index_t N01 = 1,
index_t KSplit = 1) index_t KSplit = 1)
: M01_(M01), : c_grid_desc_m_n_(c_grid_desc_m_n),
M01_(M01),
N01_(N01), N01_(N01),
KSplit_(KSplit), KSplit_(KSplit),
underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01, KSplit)) underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01, KSplit))
{ {
} }
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const __host__ __device__ constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{ {
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
...@@ -387,7 +389,10 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01 ...@@ -387,7 +389,10 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
template <typename TopIdx> template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{ {
return underlying_map_.CalculateBottomIndex(idx_top); static_assert(TopIdx::Size() == 1);
return underlying_map_.CalculateBottomIndex(
make_multi_index(idx_top[I0] % CalculateGridSize()));
} }
template <typename CTileIdx, typename CTileDim> template <typename CTileIdx, typename CTileDim>
...@@ -418,6 +423,11 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01 ...@@ -418,6 +423,11 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
} }
private: private:
__device__ constexpr index_t CalculateGridSize() const
{
return CalculateGridSize(c_grid_desc_m_n_);
}
__host__ static constexpr auto GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, __host__ static constexpr auto GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01, index_t M01,
index_t N01, index_t N01,
...@@ -450,6 +460,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01 ...@@ -450,6 +460,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
return c_blockid_to_ksplit_m0_n0_block_cluster_adaptor; return c_blockid_to_ksplit_m0_n0_block_cluster_adaptor;
} }
CGridDesc_M_N c_grid_desc_m_n_;
index_t M01_, N01_, KSplit_; index_t M01_, N01_, KSplit_;
using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1, 1)); using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1, 1));
UnderlyingMap underlying_map_; UnderlyingMap underlying_map_;
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
...@@ -74,7 +74,8 @@ template <typename FloatAB, ...@@ -74,7 +74,8 @@ template <typename FloatAB,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched> LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedGemmGemm_Xdl_CShuffle struct GridwiseBatchedGemmGemm_Xdl_CShuffle
{ {
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
...@@ -101,7 +102,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle ...@@ -101,7 +102,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>; using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
template <typename ABlockDesc_AK0_M_AK1> template <typename ABlockDesc_AK0_M_AK1>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
...@@ -486,8 +488,9 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle ...@@ -486,8 +488,9 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
// gridwise GEMM pipeline // gridwise GEMM pipeline
// Only supports LoopScheduler::Default // Only supports LoopScheduler::Default
const auto gridwise_gemm_pipeline = const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector<PipelineVer,
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopScheduler::Default>(); NumGemmKPrefetchStage,
LoopScheduler::Default>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment