Unverified commit e1a5137e, authored by arai713, committed by GitHub

Merge branch 'develop' into transpose_5d

parents eb57178d 718065eb
@@ -510,12 +510,15 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
         CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
         LoopSched>;

-    using A0GridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(A0GridDesc_M_K{}))>;
-    using B0GridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(B0GridDesc_N_K{}))>;
-    using B1GridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(B1GridDesc_N_K{}))>;
+    using A0GridDesc_AK0_M_AK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(
+            A0GridDesc_M_K{}))>;
+    using B0GridDesc_BK0_N_BK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(
+            B0GridDesc_N_K{}))>;
+    using B1GridDesc_BK0_N_BK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(
+            B1GridDesc_N_K{}))>;

     // Argument
     struct Argument : public BaseArgument
...
@@ -185,7 +185,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
         AElementwiseOperation,
         BElementwiseOperation,
         CElementwiseOperation,
-        GemmSpecialization::MNPadding,
+        GemmSpecialization::MNKPadding,
         MPerBlock,
         NPerBlock,
         K0PerBlock,
@@ -315,11 +315,6 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
             return false;
         }

-        if(problem.K % K1 != 0)
-        {
-            return false;
-        }
-
        return GridwiseGemm::CheckValidity(problem);
    }
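The two hunks above are linked: once the instance uses GemmSpecialization::MNKPadding, the grid descriptors pad K (along with M and N) up to tile multiples, so a K that does not divide K1 no longer has to be rejected up front. A minimal sketch of the padding arithmetic, assuming a hypothetical standalone helper and illustrative tile values:

constexpr int pad_to_multiple(int len, int tile) { return (len + tile - 1) / tile * tile; }

// With K1 = 8 and K0PerBlock = 4, a raw K of 1000 is padded to
// pad_to_multiple(1000, 8 * 4) == 1024 instead of being rejected.
static_assert(pad_to_multiple(1000, 32) == 1024, "illustration only");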
@@ -416,7 +411,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
-           << K0PerBlock
+           << K0PerBlock << ", "
+           << K1 << ", "
+           << MPerXDL << ", "
+           << NPerXDL << ", "
+           << MXdlPerWave << ", "
+           << NXdlPerWave << ", "
            << ">"
            << " NumGemmKPrefetchStage: "
            << NumGemmKPrefetchStage << ", "
...
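With the extra fields, GetTypeString now encodes the full XDL tiling. For a hypothetical instance (parameter values illustrative only, not taken from a real config) the result would look like:

// DeviceBatchedGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2, > NumGemmKPrefetchStage: 1, ...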
@@ -310,9 +310,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>;
     using EGridDesc_M_N  = decltype(MakeEGridDescriptor_M_N({}, {}));

+    using ComputeDataType = ADataType;
+
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
+        BDataType,
+        ComputeDataType,
         AccDataType,
         CShuffleDataType,
         DsDataType,
@@ -355,14 +359,18 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
         LoopSched>;

     // desc for blockwise copy
-    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
-    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
-    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
-    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
+    using AGridDesc_AK0_M_AK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
+            AGridDesc_M_K{}))>;
+    using BGridDesc_BK0_N_BK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
+            BGridDesc_N_K{}))>;
+    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
+        decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            DsGridDesc_M_N{}))>;
+    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
+        remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            EGridDesc_M_N{}))>;

     // block-to-e-tile map
     using Block2ETileMap =
...
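The new ComputeDataType parameter lets the gridwise GEMM load A and B in their storage types but multiply in a possibly different type before accumulating. A toy sketch of the idea with hypothetical names (the real pipeline does this inside the blockwise GEMM, not in a scalar loop):

template <typename ADataType, typename BDataType, typename ComputeDataType,
          typename AccDataType>
AccDataType dot_product_sketch(const ADataType* a, const BDataType* b, int k)
{
    AccDataType acc{0};
    for(int i = 0; i < k; ++i)
    {
        // convert the storage types to the compute type before the multiply-accumulate
        acc += static_cast<AccDataType>(static_cast<ComputeDataType>(a[i]) *
                                        static_cast<ComputeDataType>(b[i]));
    }
    return acc;
}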
@@ -11,7 +11,6 @@
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_gemm.hpp"
-#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
 #include "ck/host_utility/device_prop.hpp"
@@ -60,7 +59,6 @@ template <
     typename CThreadTransferSrcDstAccessOrder,
     index_t CThreadTransferSrcDstVectorDim,
     index_t CThreadTransferDstScalarPerVector,
-    GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default,
     enable_if_t<
         is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
         is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
@@ -238,8 +236,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
         BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
         CThreadTransferSrcDstAccessOrder,
         CThreadTransferSrcDstVectorDim,
-        CThreadTransferDstScalarPerVector,
-        GemmDlAlg>;
+        CThreadTransferDstScalarPerVector>;

     using AGridDesc_K0_M0_M1_K1 =
         decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
@@ -276,6 +273,9 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
               block_2_ctile_map_{},
               M01_{M01},
               N01_{N01},
+              M_raw_{M},
+              N_raw_{N},
+              K_raw_{K},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               c_element_op_{c_element_op}
@@ -317,6 +317,10 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
         index_t M01_;
         index_t N01_;

+        index_t M_raw_;
+        index_t N_raw_;
+        index_t K_raw_;
+
         // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
@@ -375,8 +379,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
                     remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
                     remove_reference_t<DefaultBlock2CTileMap>,
                     true,
-                    true,
-                    GemmDlAlg>;
+                    true>;

                 ave_time = launch_and_time_kernel(stream_config,
                                                   kernel,
@@ -402,8 +405,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
                     remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
                     remove_reference_t<DefaultBlock2CTileMap>,
                     true,
-                    false,
-                    GemmDlAlg>;
+                    false>;

                 ave_time = launch_and_time_kernel(stream_config,
                                                   kernel,
@@ -429,8 +431,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
                     remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
                     remove_reference_t<DefaultBlock2CTileMap>,
                     false,
-                    true,
-                    GemmDlAlg>;
+                    true>;

                 ave_time = launch_and_time_kernel(stream_config,
                                                   kernel,
@@ -456,8 +457,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
                     remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
                     remove_reference_t<DefaultBlock2CTileMap>,
                     false,
-                    false,
-                    GemmDlAlg>;
+                    false>;

                 ave_time = launch_and_time_kernel(stream_config,
                                                   kernel,
@@ -492,14 +492,48 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if constexpr(GemmDlAlg == GemmDlAlgorithm::Dpp8)
-        {
-            if(ck::get_device_name() == "gfx1030")
-            {
-                return GridwiseGemm::CheckValidity(
-                    arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
-            }
-            return false;
-        }
+        // Make sure that the M, N, K dimensions before padding are divisible by respective
+        // vector lengths.
+        if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
+        {
+            constexpr auto A_K_vec_length =
+                ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I0) *
+                ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I3);
+            if(arg.K_raw_ % A_K_vec_length != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            constexpr auto A_M_vec_length =
+                ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I1) *
+                ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I2);
+            if(arg.M_raw_ % A_M_vec_length != 0)
+            {
+                return false;
+            }
+        }
+
+        if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
+        {
+            constexpr auto B_N_vec_length =
+                BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I1) *
+                BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I2);
+            if(arg.N_raw_ % B_N_vec_length != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            constexpr auto B_K_vec_length =
+                BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I0) *
+                BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I3);
+            if(arg.K_raw_ % B_K_vec_length != 0)
+            {
+                return false;
+            }
+        }

         if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
...
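The check intentionally uses the raw, unpadded sizes stored in M_raw_/N_raw_/K_raw_: padding changes the logical descriptor but not the physical row length the vectorized loads read from. An illustration with hypothetical tuning values: for row-major A the effective source vector length along K is At(I0) * At(I3) of ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, e.g. Sequence<4, 1, 1, 2>:

constexpr int A_K_vec_length = 4 * 2; // At(I0) * At(I3), illustrative values
static_assert(1024 % A_K_vec_length == 0, "K_raw = 1024 would be accepted");
static_assert(1000 % A_K_vec_length != 0, "K_raw = 1000 would be rejected");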
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t K1,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
typename M1N1ThreadClusterM1Xs,
typename M1N1ThreadClusterN1Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false>
struct DeviceGemmDlDpp8 : public DeviceGemmDl<ADataType,
BDataType,
CDataType,
AccDataType,
ALayout,
BLayout,
CLayout,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
GemmDlAlgorithm::Dpp8>
{
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGemmDlDpp8"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< K0PerBlock << ", "
<< K1 << ", "
<< M1PerThread << ", "
<< N1PerThread << ", "
<< KPerThread
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
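DeviceGemmDlDpp8 adds no machinery of its own: it forwards every template parameter to DeviceGemmDl, pins the last argument to GemmDlAlgorithm::Dpp8, and overrides only the reported name. The pattern in miniature, with hypothetical, heavily simplified types:

#include <string>

template <int MPerBlock, int NPerBlock>
struct DeviceGemmBaseSketch
{
    virtual ~DeviceGemmBaseSketch()           = default;
    virtual std::string GetTypeString() const { return "DeviceGemmBaseSketch"; }
};

// derived op: same parameters, different algorithm selection and name only
template <int MPerBlock, int NPerBlock>
struct DeviceGemmVariantSketch : DeviceGemmBaseSketch<MPerBlock, NPerBlock>
{
    std::string GetTypeString() const override { return "DeviceGemmVariantSketch"; }
};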
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <map>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerDpp,
ck::index_t NPerDpp,
ck::index_t MDppPerWave,
ck::index_t NDppPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsAddExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BBlockLdsAddExtraN,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector,
ck::index_t NumPrefetch = 1,
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
struct DeviceGemmDpp : public DeviceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
using GridwiseGemm = GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp<
BlockSize,
ADataType,
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
ALayout,
BLayout,
CLayout,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
MPerBlock,
NPerBlock,
KPerBlock,
MPerDpp,
NPerDpp,
AK1,
BK1,
MDppPerWave,
NDppPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
Sequence<0, 2, 4, 1, 3, 5>, // CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
NumPrefetch,
PipelineVer>;
using Argument = typename GridwiseGemm::Argument;
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
karg.Print();
}
if(!GridwiseGemm::CheckValidity(karg))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_dpp has invalid setting");
}
const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K))
{
const auto kernel = kernel_gemm_dpp<GridwiseGemm, true>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
}
else
{
const auto kernel = kernel_gemm_dpp<GridwiseGemm, false>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& karg)
{
if(ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1100" ||
ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102")
{
return GridwiseGemm::CheckValidity(karg);
}
return false;
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation)
{
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
// clang-format off
str << "DeviceGemmDpp"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerDpp << ", "
<< NPerDpp << ", "
<< MDppPerWave << ", "
<< MDppPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< ABlockTransferDstScalarPerVector_K1 << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< BBlockTransferDstScalarPerVector_K1
<< ">"
<< " NumPrefetch: "
<< NumPrefetch << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
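A hedged sketch of how such a device op is typically driven, assuming a concrete instantiation alias DeviceGemmDppInstance and already-allocated device buffers p_a, p_b, p_c (all sizes and strides below are illustrative only):

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

auto gemm     = DeviceGemmDppInstance{};
auto invoker  = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(p_a, p_b, p_c,
                                  /*M=*/1024, /*N=*/1024, /*K=*/512,
                                  /*StrideA=*/512, /*StrideB=*/512, /*StrideC=*/1024,
                                  PassThrough{}, PassThrough{}, PassThrough{});

if(DeviceGemmDppInstance::IsSupportedArgument(argument))
{
    const float ave_time_ms = invoker.Run(argument, StreamConfig{});
}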
@@ -364,11 +364,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;

     // We have to separate mean var descriptor for gemm and layernorm because of different grid
     // layout(different padding)
-    using GemmMeanVarGridDesc_M_NBlock = decltype(
-        MakeMeanVarDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(1, 1));
+    using GemmMeanVarGridDesc_M_NBlock =
+        decltype(MakeMeanVarDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(
+            1, 1));

-    using GemmCountGridDesc_M_NBlock = decltype(
-        MakeCountDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(1, 1));
+    using GemmCountGridDesc_M_NBlock =
+        decltype(MakeCountDescriptor_M_N<Sequence<true, false>, GemmMPerBlock, GemmNPerBlock>(
+            1, 1));

     using LayernormMeanVarGridDesc_M_NBlock =
         decltype(MakeMeanVarDescriptor_M_N<Sequence<true, true>,
...
@@ -337,10 +337,12 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
         RThreadTransferDstScalarPerVector_MPerBlock,
         LoopSched>;

-    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
-    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
+    using AGridDesc_AK0_M_AK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
+            AGridDesc_M_K{}))>;
+    using BGridDesc_BK0_N_BK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
+            BGridDesc_N_K{}))>;

     using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
...
@@ -20,7 +20,8 @@
 namespace ck {

 template <typename GridwiseGemm,
-          typename ABDataType,
+          typename ADataType,
+          typename BDataType,
           typename DsPointer,
           typename EDataType,
           typename AElementwiseOperation,
@@ -36,8 +37,8 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_gemm_multiple_d_xdl_cshuffle(const ABDataType* __restrict__ p_a_grid,
-                                            const ABDataType* __restrict__ p_b_grid,
+        kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid,
+                                            const BDataType* __restrict__ p_b_grid,
                                             DsPointer p_ds_grid,
                                             EDataType* __restrict__ p_e_grid,
                                             const AElementwiseOperation a_element_op,
@@ -143,7 +144,8 @@ template <typename ALayout,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler(),
-          PipelineVersion PipelineVer = PipelineVersion::v1>
+          PipelineVersion PipelineVer = PipelineVersion::v1,
+          typename ComputeDataType = EDataType>
 struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
                                                                      BLayout,
                                                                      DsLayout,
@@ -244,7 +246,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
-        ADataType, // TODO: distinguish A/B datatype
+        ADataType,
+        BDataType,
+        ComputeDataType,
         AccDataType,
         CShuffleDataType,
         DsDataType,
@@ -288,14 +292,18 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
         PipelineVer>;

     // desc for blockwise copy
-    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
-    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
-    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
-    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
+    using AGridDesc_AK0_M_AK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
+            AGridDesc_M_K{}))>;
+    using BGridDesc_BK0_N_BK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
+            BGridDesc_N_K{}))>;
+    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
+        decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            DsGridDesc_M_N{}))>;
+    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
+        remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            EGridDesc_M_N{}))>;

     // block-to-e-tile map
     using Block2ETileMap =
@@ -438,6 +446,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
             const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle<
                 GridwiseGemm,
                 ADataType, // TODO: distinguish A/B datatype
+                BDataType, // TODO: distinguish A/B datatype
                 typename GridwiseGemm::DsGridPointer,
                 EDataType,
                 AElementwiseOperation,
...
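Splitting the single ABDataType into ADataType and BDataType lets the kernel consume mixed-precision operands. A toy HIP-style sketch of the signature change (hypothetical kernel, not the CK one; one element per thread, no tiling):

#include <hip/hip_runtime.h>

template <typename ADataType, typename BDataType, typename EDataType>
__global__ void gemm_kernel_sketch(const ADataType* __restrict__ p_a_grid,
                                   const BDataType* __restrict__ p_b_grid,
                                   EDataType* __restrict__ p_e_grid)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    p_e_grid[i] += static_cast<EDataType>(p_a_grid[i]) * static_cast<EDataType>(p_b_grid[i]);
}

// A call site that previously instantiated <half, half, float> could now also
// instantiate, e.g., gemm_kernel_sketch<ck::half_t, int8_t, float>.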
@@ -57,7 +57,10 @@ template <typename ADataType,
           index_t CShuffleMRepeatPerShuffle,
           index_t CShuffleNRepeatPerShuffle,
           typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-          index_t CBlockTransferScalarPerVector_NWaveNPerXDL>
+          index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
+          typename ComputeType = CDataType,
+          PipelineVersion PipelineVer = PipelineVersion::v1>
 struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                                                              BLayout,
                                                              CLayout,
@@ -76,11 +79,11 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
     // TODO: should be exposed as Tparams.
     static constexpr index_t NumGemmKPrefetchStage = 1;
     static constexpr LoopScheduler LoopSched       = make_default_loop_scheduler();
-    static constexpr PipelineVersion PipelineVer   = PipelineVersion::v1;

     using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
         BlockSize,
-        ADataType, // TODO: distinguish A/B datatype
+        ADataType,
+        BDataType,
         AccDataType,
         CDataType,
         ALayout,
@@ -120,7 +123,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
         CBlockTransferScalarPerVector_NWaveNPerXDL,
         CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
         LoopSched,
-        PipelineVer>;
+        PipelineVer,
+        ComputeType>;

     using Argument              = typename GridwiseGemm::Argument;
     using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
...
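Hoisting PipelineVer from a struct constant to a defaulted template parameter keeps every existing instantiation source-compatible while letting new code opt in. The pattern in miniature, with hypothetical names:

enum class PipelineVersionSketch { v1, v2 };

// before: the value was a constant inside the struct;
// after: a defaulted parameter preserves the old behavior by default.
template <typename ComputeType = float,
          PipelineVersionSketch PipelineVer = PipelineVersionSketch::v1>
struct DeviceOpSketch
{
    static constexpr PipelineVersionSketch pipeline = PipelineVer;
};

static_assert(DeviceOpSketch<>::pipeline == PipelineVersionSketch::v1, "default unchanged");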
@@ -248,10 +248,12 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
         CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
         CShuffleBlockTransferScalarPerVector_NPerBlock>;

-    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
-    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
+    using AGridDesc_AK0_M_AK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
+            AGridDesc_M_K{}))>;
+    using BGridDesc_BK0_N_BK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
+            BGridDesc_N_K{}))>;

     using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
...
@@ -355,9 +355,13 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>;
     using EGridDesc_M_N  = decltype(MakeEGridDescriptor_M_N({}, {}));

+    using ComputeDataType = ADataType;
+
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
+        BDataType,
+        ComputeDataType,
         AccDataType,
         CShuffleDataType,
         DsDataType,
@@ -400,14 +404,18 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
         LoopSched>;

     // desc for blockwise copy
-    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
-    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
-    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
-    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
+    using AGridDesc_AK0_M_AK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
+            AGridDesc_M_K{}))>;
+    using BGridDesc_BK0_N_BK1 =
+        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
+            BGridDesc_N_K{}))>;
+    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
+        decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            DsGridDesc_M_N{}))>;
+    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
+        remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            EGridDesc_M_N{}))>;

     struct GroupedContractionBlock2ETileMap
     {
...
@@ -280,6 +280,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
             BK1,
             MPerBlock,
             NPerBlock,
+            KPerBlock,
             DoPadGemmM,
             DoPadGemmN>{};
@@ -355,6 +356,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
+        ABDataType, // TODO: distinguish A/B datatype
+        ABDataType, // TODO: distinguish A/B datatype
         ABDataType, // TODO: distinguish A/B datatype
         AccDataType,
         CShuffleDataType,
@@ -422,10 +425,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
     using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{}));
     using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{}));

-    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
-        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}));
-    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
-        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));
+    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
+        decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            DsGridDesc_M_N{}));
+    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
+        decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            EGridDesc_M_N{}));

     // block-to-e-tile map
     using Block2ETileMap =
...
@@ -14,6 +14,7 @@
 #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
 #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
+#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -72,6 +73,9 @@ __global__ void
         const Block2CTileMap block_2_ctile_map,
         const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \
+    defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__) || defined(__gfx1100__) || \
+    defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
@@ -96,9 +100,23 @@ __global__ void
         block_2_ctile_map,
         integral_constant<bool, HasMainKBlockLoop>{},
         integral_constant<bool, HasDoubleTailKBlockLoop>{});
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = batch_count;
+    ignore = a_grid_desc_kbatch_k0_m0_m1_k1;
+    ignore = b_grid_desc_kbatch_k0_n0_n1_k1;
+    ignore = c_grid_desc_m0_m10_m11_n0_n10_n11;
+    ignore = block_2_ctile_map;
+    ignore = compute_ptr_offset_of_batch;
+#endif
 }

 template <ck::index_t NDimSpatial,
+          typename InLayout,
+          typename WeiLayout,
+          typename OutLayout,
           typename InDataType,
           typename WeiDataType,
           typename OutDataType,
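The #if guard compiles the kernel body only on the listed gfx targets (and in host passes), while the ignore assignments keep the stub body free of unused-parameter warnings elsewhere. The same pattern in a hypothetical, self-contained form:

#include <hip/hip_runtime.h>

__global__ void kernel_sketch(float* p)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__))
    // real work happens only on supported targets
    p[0] = 0.f;
#else
    // stub body: reference the parameter so the compiler stays quiet
    (void)p;
#endif
}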
@@ -134,29 +152,46 @@ template <ck::index_t NDimSpatial,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector>
-struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
-    : public DeviceGroupedConvBwdWeight<
-          NDimSpatial,
-          ck::tuple_element_t<NDimSpatial - 1,
-                              ck::Tuple<ck::tensor_layout::convolution::GNWC,
-                                        ck::tensor_layout::convolution::GNHWC,
-                                        ck::tensor_layout::convolution::GNDHWC>>,
-          ck::tuple_element_t<NDimSpatial - 1,
-                              ck::Tuple<ck::tensor_layout::convolution::GKXC,
-                                        ck::tensor_layout::convolution::GKYXC,
-                                        ck::tensor_layout::convolution::GKZYXC>>,
-          ck::tuple_element_t<NDimSpatial - 1,
-                              ck::Tuple<ck::tensor_layout::convolution::GNWK,
-                                        ck::tensor_layout::convolution::GNHWK,
-                                        ck::tensor_layout::convolution::GNDHWK>>,
-          InDataType,
-          WeiDataType,
-          OutDataType,
-          InElementwiseOperation,
-          WeiElementwiseOperation,
-          OutElementwiseOperation>
+struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpatial,
+                                                                         InLayout,
+                                                                         WeiLayout,
+                                                                         OutLayout,
+                                                                         InDataType,
+                                                                         WeiDataType,
+                                                                         OutDataType,
+                                                                         InElementwiseOperation,
+                                                                         WeiElementwiseOperation,
+                                                                         OutElementwiseOperation>
 {
-    using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl;
+    // 1d
+    static constexpr bool is_NWGK_GKXC_NWGC =
+        is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
+        is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
+        is_same_v<OutLayout, tensor_layout::convolution::NWGK>;
+    static constexpr bool is_GNWK_GKXC_GNWC =
+        is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
+        is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
+        is_same_v<OutLayout, tensor_layout::convolution::GNWK>;
+    // 2d
+    static constexpr bool is_NHWGK_GKYXC_NHWGC =
+        is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
+        is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
+        is_same_v<OutLayout, tensor_layout::convolution::NHWGK>;
+    static constexpr bool is_GNHWK_GKYXC_GNHWC =
+        is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
+        is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
+        is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
+    // 3d
+    static constexpr bool is_NDHWGK_GKZYXC_NDHWGC =
+        is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
+        is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
+        is_same_v<OutLayout, tensor_layout::convolution::NDHWGK>;
+    static constexpr bool is_GNDHWK_GKZYXC_GNDHWC =
+        is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
+        is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
+        is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>;
+
+    using DeviceOp = DeviceGroupedConvBwdWeight_Dl;

     using ADataType = OutDataType;
     using BDataType = InDataType;
@@ -176,6 +211,8 @@ struct DeviceGroupedConvBwdWeight_Dl
     static constexpr auto I4 = Number<4>{};
     static constexpr auto I5 = Number<5>{};

+    static constexpr auto spatial_offset = I3;
+
     static constexpr auto K1Number     = Number<K1>{};
     static constexpr auto GemmK1Number = K1Number;
@@ -195,12 +232,12 @@ struct DeviceGroupedConvBwdWeight_Dl
     template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
     static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
-        const ck::index_t N,
-        const ck::index_t K,
-        const ck::index_t C,
-        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
-        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
-        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
+        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
+        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
+        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
+        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
+        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
         const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
         const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
         const std::array<ck::index_t, NDimSpatial>& input_left_pads,
@@ -209,90 +246,102 @@ struct DeviceGroupedConvBwdWeight_Dl
     {
         using namespace ck;

-        const index_t Wi = input_spatial_lengths[0];
-        const index_t Wo = output_spatial_lengths[0];
-        const index_t X  = filter_spatial_lengths[0];
-
-        const index_t InLeftPadW  = input_left_pads[0];
-        const index_t InRightPadW = input_right_pads[0];
-
-        const index_t ConvStrideW   = conv_filter_strides[0];
-        const index_t ConvDilationW = conv_filter_dilations[0];
+        const index_t N  = a_g_n_c_wis_lengths[I1];
+        const index_t K  = b_g_k_c_xs_lengths[I1];
+        const index_t C  = a_g_n_c_wis_lengths[I2];
+        const index_t Wi = a_g_n_c_wis_lengths[spatial_offset];
+        const index_t Wo = e_g_n_k_wos_lengths[spatial_offset];
+        const index_t X  = b_g_k_c_xs_lengths[spatial_offset];
+
+        const index_t InLeftPadW  = input_left_pads[I0];
+        const index_t InRightPadW = input_right_pads[I0];
+
+        const index_t ConvStrideW   = conv_filter_strides[I0];
+        const index_t ConvDilationW = conv_filter_dilations[I0];
+
+        const auto InNStride  = a_g_n_c_wis_strides[I1];
+        const auto InCStride  = a_g_n_c_wis_strides[I2];
+        const auto InWStride  = a_g_n_c_wis_strides[spatial_offset];
+        const auto WeiKStride = b_g_k_c_xs_strides[I1];
+        const auto WeiCStride = b_g_k_c_xs_strides[I2];
+        const auto OutKStride = e_g_n_k_wos_strides[I2];
+        const auto OutWStride = e_g_n_k_wos_strides[spatial_offset];

         const index_t GemmKTotal = N * Wo;
-        const index_t GemmM      = K;
-        const index_t GemmN      = C * X;

         const index_t GemmKBatch = batch_k;
         const index_t GemmK0 =
             math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
             K0PerBlock;
-        const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;

         if constexpr(ConvBackwardWeightSpecialization ==
                      ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
         {
             // A: output tensor
-            const auto out_gemmktotal_gemmm_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K));
-
-            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-                out_gemmktotal_gemmm_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride));
+
+            const auto out_gemmkpad_gemmmpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    out_gemmktotal_gemmm_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+                    Sequence<true, true>{});

             const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                out_gemmkpad_gemmm_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
+                out_gemmkpad_gemmmpad_grid_desc,
+                make_tuple(
+                    make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                    make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

             // B: input tensor
-            const auto in_gemmktotal_gemmn_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N * Wi, C));
-
-            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-                in_gemmktotal_gemmn_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N * Wi, C), make_tuple(InWStride, InCStride));
+
+            const auto in_gemmkpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    in_gemmktotal_gemmn_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+                    Sequence<true, true>{});

             const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmkpad_gemmn_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
+                in_gemmkpad_gemmnpad_grid_desc,
+                make_tuple(
+                    make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                    make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

             // C: weights tensor
-            const auto wei_gemmm_gemmn_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
+            const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride));
+
+            const auto wei_gemmmpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
+                                                                  make_tuple(MPerBlock, NPerBlock),
+                                                                  Sequence<true, true>{});

             return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                               in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                              wei_gemmm_gemmn_grid_desc);
+                              wei_gemmmpad_gemmnpad_grid_desc);
         }
         else
         {
-            const auto out_gemmktotal_gemmm_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K));
-            const auto in_n_wi_c_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
+            const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride));
+            const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(N, Wi, C), make_tuple(InNStride, InWStride, InCStride));

             // A: output tensor
-            const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-                out_gemmktotal_gemmm_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmM)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto out_gemmkpad_gemmmpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    out_gemmktotal_gemmm_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+                    Sequence<true, true>{});

             const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-                out_gemmkpad_gemmm_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmM)),
+                out_gemmkpad_gemmmpad_grid_desc,
+                make_tuple(
+                    make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                    make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
@@ -321,38 +370,43 @@ struct DeviceGroupedConvBwdWeight_Dl
                 make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}),
                 make_tuple(Sequence<1>{}, Sequence<0>{}));

-            const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-                in_gemmktotal_gemmn_grid_desc,
-                make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                           make_pass_through_transform(GemmN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
+            const auto in_gemmkpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(
+                    in_gemmktotal_gemmn_grid_desc,
+                    make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+                    Sequence<true, true>{});

             const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-                in_gemmkpad_gemmn_grid_desc,
-                make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                           make_pass_through_transform(GemmN)),
+                in_gemmkpad_gemmnpad_grid_desc,
+                make_tuple(
+                    make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                    make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

             // C: weight tensor
-            const auto wei_gemmm_gemmn_grid_desc =
-                make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
+            const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
+                make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride));
+
+            const auto wei_gemmmpad_gemmnpad_grid_desc =
+                ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
+                                                                  make_tuple(MPerBlock, NPerBlock),
+                                                                  Sequence<true, true>{});

             return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                               in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                              wei_gemmm_gemmn_grid_desc);
+                              wei_gemmmpad_gemmnpad_grid_desc);
         }
     } // function end
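PadTensorDescriptor rounds each listed dimension up to the next multiple of the matching tile length, with the Sequence of bools selecting which dimensions participate. A hedged sketch of the call shape used above, assuming CK's usual namespaces and illustrative numbers:

// pad dim 0 up to a multiple of the K tile (128) and dim 1 up to MPerBlock (64)
const auto desc = make_naive_tensor_descriptor(make_tuple(900, 100),  // lengths
                                               make_tuple(100, 1));   // strides
const auto padded = ck::tensor_operation::device::PadTensorDescriptor(
    desc,
    make_tuple(128, 64),      // per-dimension tile lengths
    Sequence<true, true>{});  // pad both dimensions
// padded.GetLength(I0) == 1024 and padded.GetLength(I1) == 128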
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
const ck::index_t N, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
const ck::index_t K, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::index_t C, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides, const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations, const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads, const std::array<ck::index_t, NDimSpatial>& input_left_pads,
...@@ -361,103 +415,111 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl ...@@ -361,103 +415,111 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
    {
        using namespace ck;

-       const index_t Hi = input_spatial_lengths[0];
-       const index_t Wi = input_spatial_lengths[1];
-
-       const index_t Ho = output_spatial_lengths[0];
-       const index_t Wo = output_spatial_lengths[1];
-
-       const index_t Y = filter_spatial_lengths[0];
-       const index_t X = filter_spatial_lengths[1];
-
-       const index_t InLeftPadH = input_left_pads[0];
-       const index_t InLeftPadW = input_left_pads[1];
-
-       const index_t InRightPadH = input_right_pads[0];
-       const index_t InRightPadW = input_right_pads[1];
-
-       const index_t ConvStrideH = conv_filter_strides[0];
-       const index_t ConvStrideW = conv_filter_strides[1];
-
-       const index_t ConvDilationH = conv_filter_dilations[0];
-       const index_t ConvDilationW = conv_filter_dilations[1];
+       const index_t N  = a_g_n_c_wis_lengths[I1];
+       const index_t K  = b_g_k_c_xs_lengths[I1];
+       const index_t C  = a_g_n_c_wis_lengths[I2];
+       const index_t Hi = a_g_n_c_wis_lengths[spatial_offset];
+       const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I1];
+       const index_t Ho = e_g_n_k_wos_lengths[spatial_offset];
+       const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I1];
+       const index_t Y  = b_g_k_c_xs_lengths[spatial_offset];
+       const index_t X  = b_g_k_c_xs_lengths[spatial_offset + I1];
+
+       const index_t InLeftPadH    = input_left_pads[I0];
+       const index_t InLeftPadW    = input_left_pads[I1];
+       const index_t InRightPadH   = input_right_pads[I0];
+       const index_t InRightPadW   = input_right_pads[I1];
+       const index_t ConvStrideH   = conv_filter_strides[I0];
+       const index_t ConvStrideW   = conv_filter_strides[I1];
+       const index_t ConvDilationH = conv_filter_dilations[I0];
+       const index_t ConvDilationW = conv_filter_dilations[I1];
+
+       const auto InNStride  = a_g_n_c_wis_strides[I1];
+       const auto InCStride  = a_g_n_c_wis_strides[I2];
+       const auto InHStride  = a_g_n_c_wis_strides[spatial_offset];
+       const auto InWStride  = a_g_n_c_wis_strides[spatial_offset + I1];
+       const auto WeiKStride = b_g_k_c_xs_strides[I1];
+       const auto WeiCStride = b_g_k_c_xs_strides[I2];
+       const auto OutKStride = e_g_n_k_wos_strides[I2];
+       const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I1];

        const index_t GemmKTotal = N * Ho * Wo;
-       const index_t GemmM = K;
-       const index_t GemmN = C * X * Y;

        const index_t GemmKBatch = batch_k;
        const index_t GemmK0 =
            math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
            K0PerBlock;
-       const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
        if constexpr(ConvBackwardWeightSpecialization ==
                     ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
        {
            // A: output tensor
-           const auto out_gemmktotal_gemmm_grid_desc =
-               make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
-           const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-               out_gemmktotal_gemmm_grid_desc,
-               make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                          make_pass_through_transform(GemmM)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
+           const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
+               make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
+           const auto out_gemmkpad_gemmmpad_grid_desc =
+               ck::tensor_operation::device::PadTensorDescriptor(
+                   out_gemmktotal_gemmm_grid_desc,
+                   make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+                   Sequence<true, true>{});

            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-               out_gemmkpad_gemmm_grid_desc,
-               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                          make_pass_through_transform(GemmM)),
+               out_gemmkpad_gemmmpad_grid_desc,
+               make_tuple(
+                   make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                   make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

            // B: input tensor
-           const auto in_gemmktotal_gemmn_grid_desc =
-               make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C));
-           const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-               in_gemmktotal_gemmn_grid_desc,
-               make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                          make_pass_through_transform(GemmM)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
+           const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
+               make_tuple(N * Hi * Wi, C), make_tuple(InWStride, InCStride));
+           const auto in_gemmkpad_gemmnpad_grid_desc =
+               ck::tensor_operation::device::PadTensorDescriptor(
+                   in_gemmktotal_gemmn_grid_desc,
+                   make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+                   Sequence<true, true>{});

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-               in_gemmkpad_gemmn_grid_desc,
-               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                          make_pass_through_transform(GemmM)),
+               in_gemmkpad_gemmnpad_grid_desc,
+               make_tuple(
+                   make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                   make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

            // C: weight tensor
-           const auto wei_gemmm_gemmn_grid_desc =
-               make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
+           const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
+               make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride));
+           const auto wei_gemmmpad_gemmnpad_grid_desc =
+               ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
+                                                                 make_tuple(MPerBlock, NPerBlock),
+                                                                 Sequence<true, true>{});

            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                              in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                             wei_gemmm_gemmn_grid_desc);
+                             wei_gemmmpad_gemmnpad_grid_desc);
        }
        else
        {
-           const auto out_gemmktotal_gemmm_grid_desc =
-               make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
-           const auto in_n_hi_wi_c_grid_desc =
-               make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
+           const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
+               make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
+           const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
+               make_tuple(N, Hi, Wi, C), make_tuple(InNStride, InHStride, InWStride, InCStride));

            // A: output tensor
-           const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-               out_gemmktotal_gemmm_grid_desc,
-               make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                          make_pass_through_transform(GemmM)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
+           const auto out_gemmkpad_gemmmpad_grid_desc =
+               ck::tensor_operation::device::PadTensorDescriptor(
+                   out_gemmktotal_gemmm_grid_desc,
+                   make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+                   Sequence<true, true>{});

            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-               out_gemmkpad_gemmm_grid_desc,
-               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                          make_pass_through_transform(GemmM)),
+               out_gemmkpad_gemmmpad_grid_desc,
+               make_tuple(
+                   make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                   make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

@@ -488,39 +550,44 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
                make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
                make_tuple(Sequence<1>{}, Sequence<0>{}));

-           const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-               in_gemmktotal_gemmn_grid_desc,
-               make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                          make_pass_through_transform(GemmN)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
+           const auto in_gemmkpad_gemmnpad_grid_desc =
+               ck::tensor_operation::device::PadTensorDescriptor(
+                   in_gemmktotal_gemmn_grid_desc,
+                   make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+                   Sequence<true, true>{});

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-               in_gemmkpad_gemmn_grid_desc,
-               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                          make_pass_through_transform(GemmN)),
+               in_gemmkpad_gemmnpad_grid_desc,
+               make_tuple(
+                   make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                   make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

            // C: weight tensor
-           const auto wei_gemmm_gemmn_grid_desc =
-               make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
+           const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
+               make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride));
+           const auto wei_gemmmpad_gemmnpad_grid_desc =
+               ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
+                                                                 make_tuple(MPerBlock, NPerBlock),
+                                                                 Sequence<true, true>{});

            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                              in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                             wei_gemmm_gemmn_grid_desc);
+                             wei_gemmmpad_gemmnpad_grid_desc);
        }
    } // function end
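The make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)) step splits the padded, flattened reduction axis into a split-K batch index, a block-loop index, and a vector index (Sequence<0, 1, 3> in the output mapping). A hedged sketch of the index relation the unmerge encodes, as plain index math rather than the CK descriptor machinery:

// Hedged sketch: the (kbatch, k0, k1) coordinates assigned to a flat K
// index, with K1 fastest-varying.
#include <cstdio>

int main()
{
    const int GemmK0 = 4, GemmK1 = 8; // made-up tile sizes; K = 2 * 4 * 8
    const int k = 45;                 // a flat coordinate in the padded K axis
    const int k1 = k % GemmK1;
    const int k0 = (k / GemmK1) % GemmK0;
    const int kbatch = k / (GemmK1 * GemmK0);
    std::printf("k=%d -> (kbatch=%d, k0=%d, k1=%d)\n", k, kbatch, k0, k1);
    // 45 = 1 * 32 + 1 * 8 + 5, so (1, 1, 5)
}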
    template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
    static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
-       const ck::index_t N,
-       const ck::index_t K,
-       const ck::index_t C,
-       const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
-       const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
-       const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+       const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
+       const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
+       const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
+       const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
+       const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
+       const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
        const std::array<ck::index_t, NDimSpatial>& input_left_pads,

@@ -529,110 +596,120 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
    {
        using namespace ck;

-       const index_t Di = input_spatial_lengths[0];
-       const index_t Hi = input_spatial_lengths[1];
-       const index_t Wi = input_spatial_lengths[2];
-
-       const index_t Do = output_spatial_lengths[0];
-       const index_t Ho = output_spatial_lengths[1];
-       const index_t Wo = output_spatial_lengths[2];
-
-       const index_t Z = filter_spatial_lengths[0];
-       const index_t Y = filter_spatial_lengths[1];
-       const index_t X = filter_spatial_lengths[2];
-
-       const index_t InLeftPadD = input_left_pads[0];
-       const index_t InLeftPadH = input_left_pads[1];
-       const index_t InLeftPadW = input_left_pads[2];
-
-       const index_t InRightPadD = input_right_pads[0];
-       const index_t InRightPadH = input_right_pads[1];
-       const index_t InRightPadW = input_right_pads[2];
-
-       const index_t ConvStrideD = conv_filter_strides[0];
-       const index_t ConvStrideH = conv_filter_strides[1];
-       const index_t ConvStrideW = conv_filter_strides[2];
-
-       const index_t ConvDilationD = conv_filter_dilations[0];
-       const index_t ConvDilationH = conv_filter_dilations[1];
-       const index_t ConvDilationW = conv_filter_dilations[2];
+       const index_t N  = a_g_n_c_wis_lengths[I1];
+       const index_t K  = b_g_k_c_xs_lengths[I1];
+       const index_t C  = a_g_n_c_wis_lengths[I2];
+       const index_t Di = a_g_n_c_wis_lengths[spatial_offset + I0];
+       const index_t Hi = a_g_n_c_wis_lengths[spatial_offset + I1];
+       const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I2];
+       const index_t Do = e_g_n_k_wos_lengths[spatial_offset + I0];
+       const index_t Ho = e_g_n_k_wos_lengths[spatial_offset + I1];
+       const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I2];
+       const index_t Z  = b_g_k_c_xs_lengths[spatial_offset + I0];
+       const index_t Y  = b_g_k_c_xs_lengths[spatial_offset + I1];
+       const index_t X  = b_g_k_c_xs_lengths[spatial_offset + I2];
+
+       const index_t InLeftPadD    = input_left_pads[I0];
+       const index_t InLeftPadH    = input_left_pads[I1];
+       const index_t InLeftPadW    = input_left_pads[I2];
+       const index_t InRightPadD   = input_right_pads[I0];
+       const index_t InRightPadH   = input_right_pads[I1];
+       const index_t InRightPadW   = input_right_pads[I2];
+       const index_t ConvStrideD   = conv_filter_strides[I0];
+       const index_t ConvStrideH   = conv_filter_strides[I1];
+       const index_t ConvStrideW   = conv_filter_strides[I2];
+       const index_t ConvDilationD = conv_filter_dilations[I0];
+       const index_t ConvDilationH = conv_filter_dilations[I1];
+       const index_t ConvDilationW = conv_filter_dilations[I2];
+
+       const auto InNStride  = a_g_n_c_wis_strides[I1];
+       const auto InCStride  = a_g_n_c_wis_strides[I2];
+       const auto InDStride  = a_g_n_c_wis_strides[spatial_offset];
+       const auto InHStride  = a_g_n_c_wis_strides[spatial_offset + I1];
+       const auto InWStride  = a_g_n_c_wis_strides[spatial_offset + I2];
+       const auto WeiKStride = b_g_k_c_xs_strides[I1];
+       const auto WeiCStride = b_g_k_c_xs_strides[I2];
+       const auto OutKStride = e_g_n_k_wos_strides[I2];
+       const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I2];

        const index_t GemmKTotal = N * Do * Ho * Wo;
-       const index_t GemmM = K;
-       const index_t GemmN = C * Z * X * Y;

        const index_t GemmKBatch = batch_k;
        const index_t GemmK0 =
            math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
            K0PerBlock;
-       const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
        if constexpr(ConvBackwardWeightSpecialization ==
                     ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
        {
            // A: output tensor
-           const auto out_gemmktotal_gemmm_grid_desc =
-               make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
-           const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-               out_gemmktotal_gemmm_grid_desc,
-               make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                          make_pass_through_transform(GemmM)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
+           const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
+               make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
+           const auto out_gemmkpad_gemmmpad_grid_desc =
+               ck::tensor_operation::device::PadTensorDescriptor(
+                   out_gemmktotal_gemmm_grid_desc,
+                   make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+                   Sequence<true, true>{});

            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-               out_gemmkpad_gemmm_grid_desc,
-               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                          make_pass_through_transform(GemmM)),
+               out_gemmkpad_gemmmpad_grid_desc,
+               make_tuple(
+                   make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                   make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

            // B: input tensor
-           const auto in_gemmktotal_gemmn_grid_desc =
-               make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C));
-           const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-               in_gemmktotal_gemmn_grid_desc,
-               make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                          make_pass_through_transform(GemmM)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
+           const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
+               make_tuple(N * Di * Hi * Wi, C), make_tuple(InWStride, InCStride));
+           const auto in_gemmkpad_gemmnpad_grid_desc =
+               ck::tensor_operation::device::PadTensorDescriptor(
+                   in_gemmktotal_gemmn_grid_desc,
+                   make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+                   Sequence<true, true>{});

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-               in_gemmkpad_gemmn_grid_desc,
-               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                          make_pass_through_transform(GemmM)),
+               in_gemmkpad_gemmnpad_grid_desc,
+               make_tuple(
+                   make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                   make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

            // C: weight tensor
-           const auto wei_gemmm_gemmn_grid_desc =
-               make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
+           const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
+               make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride));
+           const auto wei_gemmmpad_gemmnpad_grid_desc =
+               ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
+                                                                 make_tuple(MPerBlock, NPerBlock),
+                                                                 Sequence<true, true>{});

            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                              in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                             wei_gemmm_gemmn_grid_desc);
+                             wei_gemmmpad_gemmnpad_grid_desc);
        }
        else
        {
-           const auto out_gemmktotal_gemmm_grid_desc =
-               make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
-           const auto in_n_di_hi_wi_c_grid_desc =
-               make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
+           const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
+               make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
+           const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
+               make_tuple(N, Di, Hi, Wi, C),
+               make_tuple(InNStride, InDStride, InHStride, InWStride, InCStride));

            // A: output tensor
-           const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
-               out_gemmktotal_gemmm_grid_desc,
-               make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                          make_pass_through_transform(GemmM)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
+           const auto out_gemmkpad_gemmmpad_grid_desc =
+               ck::tensor_operation::device::PadTensorDescriptor(
+                   out_gemmktotal_gemmm_grid_desc,
+                   make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
+                   Sequence<true, true>{});

            const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
-               out_gemmkpad_gemmm_grid_desc,
-               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                          make_pass_through_transform(GemmM)),
+               out_gemmkpad_gemmmpad_grid_desc,
+               make_tuple(
+                   make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                   make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

@@ -672,27 +749,32 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
                make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
                make_tuple(Sequence<1>{}, Sequence<0>{}));

-           const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
-               in_gemmktotal_gemmn_grid_desc,
-               make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
-                          make_pass_through_transform(GemmN)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0>{}, Sequence<1>{}));
+           const auto in_gemmkpad_gemmnpad_grid_desc =
+               ck::tensor_operation::device::PadTensorDescriptor(
+                   in_gemmktotal_gemmn_grid_desc,
+                   make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
+                   Sequence<true, true>{});

            const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
-               in_gemmkpad_gemmn_grid_desc,
-               make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
-                          make_pass_through_transform(GemmN)),
+               in_gemmkpad_gemmnpad_grid_desc,
+               make_tuple(
+                   make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
+                   make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

            // C: weight tensor
-           const auto wei_gemmm_gemmn_grid_desc =
-               make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
+           const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
+               make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride));
+           const auto wei_gemmmpad_gemmnpad_grid_desc =
+               ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
+                                                                 make_tuple(MPerBlock, NPerBlock),
+                                                                 Sequence<true, true>{});

            return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
                              in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
-                             wei_gemmm_gemmn_grid_desc);
+                             wei_gemmmpad_gemmnpad_grid_desc);
        }
    } // function end
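In all three NDim specializations, GemmK0 is chosen so that the padded K axis is exactly GemmKBatch * GemmK0 * GemmK1Number and each split-K batch holds a whole number of K0PerBlock chunks. A small worked example of that rounding rule (standalone reimplementation; integer_divide_ceil mirrors math::integer_divide_ceil):

// Hedged sketch of the GemmK0 rounding used by the descriptor builders.
#include <cstdio>

int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int GemmKTotal = 1000, GemmK1Number = 8, K0PerBlock = 4, GemmKBatch = 2;
    const int GemmK0 =
        integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock;
    const int GemmKPadded = GemmKBatch * GemmK0 * GemmK1Number; // what PadTensorDescriptor pads K to
    std::printf("GemmK0 = %d, padded K = %d (from %d)\n", GemmK0, GemmKPadded, GemmKTotal);
    // GemmK0 = ceil(1000 / 64) * 4 = 16 * 4 = 64; padded K = 2 * 64 * 8 = 1024.
}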
@@ -701,22 +783,22 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
    static auto GetABCGridDesc()
    {
        return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
-           1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1);
+           {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1);
    }

    template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
    static auto GetABCGridDesc()
    {
        return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
-           1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1);
+           {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1);
    }

    template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
    static auto GetABCGridDesc()
    {
-       return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1,
-                                                                 1,
-                                                                 1,
+       return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>({1, 1, 1},
+                                                                 {1, 1, 1},
+                                                                 {1, 1, 1},
                                                                  {1, 1, 1},
                                                                  {1, 1, 1},
                                                                  {1, 1, 1},
@@ -785,11 +867,11 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
                 WeiDataType* p_wei_grid,
                 const OutDataType* p_out_grid,
                 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
-                const std::array<index_t, NDimSpatial + 3>& /*a_g_n_c_wis_strides*/,
+                const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
-                const std::array<index_t, NDimSpatial + 3>& /*b_g_k_c_xs_strides*/,
+                const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
                 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
-                const std::array<index_t, NDimSpatial + 3>& /*e_g_n_k_wos_strides*/,
+                const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
                 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
                 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
                 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
@@ -809,38 +891,24 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
          a_element_op_{out_element_op},
          b_element_op_{wei_element_op},
          c_element_op_{in_element_op},
-         Conv_G_{a_g_n_c_wis_lengths[0]},
-         Conv_N_{a_g_n_c_wis_lengths[1]},
-         Conv_K_{b_g_k_c_xs_lengths[1]},
-         Conv_C_{a_g_n_c_wis_lengths[2]},
-         input_spatial_lengths_{},
-         filter_spatial_lengths_{},
-         output_spatial_lengths_{},
+         Conv_G_{a_g_n_c_wis_lengths[I0]},
+         Conv_K_{b_g_k_c_xs_lengths[I1]},
+         Conv_C_{a_g_n_c_wis_lengths[I2]},
+         filter_lengths_{b_g_k_c_xs_lengths},
          conv_filter_strides_{conv_filter_strides},
          conv_filter_dilations_{conv_filter_dilations},
          input_left_pads_{input_left_pads},
          input_right_pads_{input_right_pads},
          k_batch_{split_k}
    {
-       constexpr index_t spatial_offset = 3;
-       std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset,
-                 end(a_g_n_c_wis_lengths),
-                 begin(input_spatial_lengths_));
-       std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset,
-                 end(b_g_k_c_xs_lengths),
-                 begin(filter_spatial_lengths_));
-       std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset,
-                 end(e_g_n_k_wos_lengths),
-                 begin(output_spatial_lengths_));
-
        const auto descs =
            DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
-               Conv_N_,
-               Conv_K_,
-               Conv_C_,
-               input_spatial_lengths_,
-               filter_spatial_lengths_,
-               output_spatial_lengths_,
+               a_g_n_c_wis_lengths, // input
+               a_g_n_c_wis_strides,
+               b_g_k_c_xs_lengths, // weight
+               b_g_k_c_xs_strides,
+               e_g_n_k_wos_lengths, // output
+               e_g_n_k_wos_strides,
                conv_filter_strides,
                conv_filter_dilations,
                input_left_pads,

@@ -863,24 +931,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
            GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);

        // A/B/C Batch Stride
-       compute_ptr_offset_of_batch_.BatchStrideA_ =
-           Conv_N_ * Conv_K_ *
-           std::accumulate(begin(output_spatial_lengths_),
-                           end(output_spatial_lengths_),
-                           index_t{1},
-                           std::multiplies<>{});
-       compute_ptr_offset_of_batch_.BatchStrideB_ =
-           Conv_N_ * Conv_C_ *
-           std::accumulate(begin(input_spatial_lengths_),
-                           end(input_spatial_lengths_),
-                           index_t{1},
-                           std::multiplies<>{});
-       compute_ptr_offset_of_batch_.BatchStrideC_ =
-           Conv_K_ * Conv_C_ *
-           std::accumulate(begin(filter_spatial_lengths_),
-                           end(filter_spatial_lengths_),
-                           index_t{1},
-                           std::multiplies<>{});
+       compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[I0];
+       compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[I0];
+       compute_ptr_offset_of_batch_.BatchStrideC_ = b_g_k_c_xs_strides[I0];
    }
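With explicit G strides available, the per-group pointer offsets reduce to strides[I0] instead of a product over the remaining lengths, which silently assumed packed tensors. A hedged sketch contrasting the two (the lengths below are made-up):

// Hedged sketch, not CK code: why reading strides[0] is more general than
// recomputing the batch stride from lengths.
#include <array>
#include <cstdio>
#include <functional>
#include <numeric>

int main()
{
    // Made-up output lengths [G, N, K, Ho, Wo] for a packed tensor.
    const std::array<int, 5> lengths{4, 2, 16, 8, 8};
    // Old approach: assume packed layout, multiply the non-G lengths.
    const int packed_g_stride =
        std::accumulate(lengths.begin() + 1, lengths.end(), 1, std::multiplies<>{});
    std::printf("packed G stride: %d\n", packed_g_stride); // 2 * 16 * 8 * 8 = 2048
    // New approach: read strides[0] supplied by the caller; for a non-packed
    // layout (e.g. NHWGK, where G is not outermost) the product formula is wrong.
}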
    const ADataType* p_a_grid_;

@@ -908,13 +961,10 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
    // for checking IsSupportedArgument()
    const index_t Conv_G_;
-   const index_t Conv_N_;
    const index_t Conv_K_;
    const index_t Conv_C_;

-   std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
-   std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
-   std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
+   std::array<ck::index_t, NDimSpatial + 3> filter_lengths_;
    const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
    const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
    const std::array<ck::index_t, NDimSpatial>& input_left_pads_;

@@ -1036,10 +1086,14 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
    static bool IsSupportedArgument(const Argument& arg)
    {
-       // check device
-       if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
-            ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-            ck::get_device_name() == "gfx1102"))
+       // DL version only supports split_k equal to 1
+       if(arg.k_batch_ != 1)
+           return false;
+
+       if constexpr(!((NDimSpatial == 1 && (is_NWGK_GKXC_NWGC || is_GNWK_GKXC_GNWC)) ||
+                      (NDimSpatial == 2 && (is_NHWGK_GKYXC_NHWGC || is_GNHWK_GKYXC_GNHWC)) ||
+                      (NDimSpatial == 3 && (is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC))))
        {
            return false;
        }
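The is_* flags above appear to be compile-time layout predicates supplied elsewhere in the device op; replacing the runtime GPU-name check with them rejects unsupported layout combinations at instantiation time. A hedged, self-contained sketch of the idiom with stand-in layout tags (not the ck::tensor_layout definitions):

// Hedged sketch: compile-time layout gating inside a support check.
#include <type_traits>

struct GNHWC {};
struct NHWGC {};
struct NCHW {};

template <typename InLayout>
constexpr bool is_supported_layout =
    std::is_same_v<InLayout, GNHWC> || std::is_same_v<InLayout, NHWGC>;

template <typename InLayout>
bool is_supported_argument(int split_k)
{
    if(split_k != 1)
        return false; // runtime check, mirrors arg.k_batch_ != 1
    if constexpr(!is_supported_layout<InLayout>)
    {
        return false; // compile-time layout gate, mirrors the is_* flags
    }
    return true;
}

int main()
{
    return is_supported_argument<GNHWC>(1) && !is_supported_argument<NCHW>(1) ? 0 : 1;
}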
@@ -1050,8 +1104,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
            // check if it's 1x1, stride=1 pad = 0 conv
            for(int i = 0; i < NDimSpatial; i++)
            {
-               if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
-                    arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
+               if(!(arg.filter_lengths_[spatial_offset + i] == 1 &&
+                    arg.conv_filter_strides_[i] == 1 && arg.input_left_pads_[i] == 0 &&
+                    arg.input_right_pads_[i] == 0))
                {
                    return false;
                }
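The Filter1x1Stride1Pad0 test now reads filter lengths from the full [G, K, C, spatial...] array at spatial_offset + i rather than from a separate spatial-only copy. A hedged standalone sketch of the admission test (array contents are made-up):

// Hedged sketch of the 1x1/stride-1/pad-0 check over full G_K_C_Xs lengths.
#include <array>
#include <cstdio>

constexpr int NDimSpatial    = 2;
constexpr int spatial_offset = 3; // skip the G, K, C dims

bool is_1x1_stride1_pad0(const std::array<int, NDimSpatial + 3>& filter_lengths,
                         const std::array<int, NDimSpatial>& strides,
                         const std::array<int, NDimSpatial>& left_pads,
                         const std::array<int, NDimSpatial>& right_pads)
{
    for(int i = 0; i < NDimSpatial; i++)
    {
        if(!(filter_lengths[spatial_offset + i] == 1 && strides[i] == 1 &&
             left_pads[i] == 0 && right_pads[i] == 0))
            return false;
    }
    return true;
}

int main()
{
    // [G, K, C, Y, X] = [1, 64, 32, 1, 1]: a genuine 1x1 filter
    std::printf("%d\n", is_1x1_stride1_pad0({1, 64, 32, 1, 1}, {1, 1}, {0, 0}, {0, 0}));
}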
@@ -1206,7 +1261,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
        auto str = std::stringstream();

        // clang-format off
-       str << "DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl"
+       str << "DeviceGroupedConvBwdWeight_Dl"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
...
@@ -381,8 +381,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
    }

    // desc for problem definition
-   using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-       MakeAGridDescriptor_AK0_M_AK1<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
+   using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
+       {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
    using BGridDesc_BK0_N_BK1 =
        remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
...
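The hunks in these forward-convolution files are formatting-only re-wraps of the same type-deduction idiom: call a descriptor factory with value-initialized dummy arguments inside decltype to name the descriptor type at class scope, then strip const/volatile/reference. A hedged toy version (Desc and MakeDesc are stand-ins, not CK types):

// Hedged sketch of the remove_cvref_t<decltype(Factory({}, {}))> idiom.
#include <type_traits>

template <typename T>
using remove_cvref_t = std::remove_cv_t<std::remove_reference_t<T>>;

struct Desc
{
    int length;
};

static auto MakeDesc(int m, int n) { return Desc{m * n}; } // stand-in factory

// The alias exists before any real sizes are known; {} value-initializes
// each argument, and decltype never evaluates the call.
using DescType = remove_cvref_t<decltype(MakeDesc({}, {}))>;

static_assert(std::is_same_v<DescType, Desc>);
int main() {}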
@@ -320,8 +320,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
    }

    // desc for problem definition
-   using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-       MakeAGridDescriptor_AK0_M_AK1<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
+   using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
+       {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
    using BGridDesc_BK0_N_BK1 =
        remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
    using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N<CLayout>({}, {}))>;
...
@@ -446,8 +446,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
        return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo);
    }

-   using AGridDesc_M_K = remove_cvref_t<decltype(
-       MakeAGridDescriptor_M_K<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
+   using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
+       {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
    using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
    using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<DELayout>({}, {}))>;
    using RGridDesc_M = remove_cvref_t<decltype(MakeRGridDescriptor_M<RLayout>({}, {}))>;

@@ -507,10 +507,12 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
        RThreadTransferDstScalarPerVector_MPerBlock,
        LoopSched>;

-   using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-       GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
-   using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
-       GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
+   using AGridDesc_AK0_M_AK1 =
+       remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
+           AGridDesc_M_K{}))>;
+   using BGridDesc_BK0_N_BK1 =
+       remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
+           BGridDesc_N_K{}))>;

    using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
...
@@ -245,8 +245,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
    }

    // desc for problem definition
-   using AGridDesc_M_K = remove_cvref_t<decltype(
-       MakeAGridDescriptor_M_K<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
+   using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
+       {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
    using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
    using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;

@@ -599,7 +599,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
        // check if it's 1x1, stride=1 conv
        for(index_t i = 0; i < NDimSpatial; ++i)
        {
-           const index_t X = arg.b_g_k_c_xs_lengths_[i + 2];
+           const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
            const index_t ConvStride = arg.conv_filter_strides_[i];
            const index_t LeftPad = arg.input_left_pads_[i];
            const index_t RightPad = arg.input_right_pads_[i];

@@ -616,7 +616,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
        // check if it's 1x1 conv
        for(index_t i = 0; i < NDimSpatial; ++i)
        {
-           const index_t X = arg.b_g_k_c_xs_lengths_[i + 2];
+           const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
            const index_t LeftPad = arg.input_left_pads_[i];
            const index_t RightPad = arg.input_right_pads_[i];
...
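The [i + 2] to [i + 3] change is the substantive fix in this file: b_g_k_c_xs_lengths_ is ordered [G, K, C, spatial...], so spatial dimension i lives at offset 3 + i, and 2 + i would read C for i == 0. A small demonstration (lengths are made-up):

// Hedged sketch of the off-by-one the fix removes.
#include <array>
#include <cstdio>

int main()
{
    // [G, K, C, Y, X] for a 3x3 filter with C = 32
    const std::array<int, 5> b_g_k_c_xs_lengths{1, 64, 32, 3, 3};
    const int i = 0;
    std::printf("wrong (i + 2): %d, right (i + 3): %d\n",
                b_g_k_c_xs_lengths[i + 2],  // 32: the channel count, not a filter dim
                b_g_k_c_xs_lengths[i + 3]); // 3: the Y filter length
}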
@@ -361,15 +361,19 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
    }

    // desc for problem definition
-   using AGridDesc_M_K = remove_cvref_t<decltype(
-       MakeAGridDescriptor_M_K<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
+   using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
+       {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
    using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
    using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;

+   using ComputeDataType = ADataType;
+
    // GridwiseGemm
    using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
        ADataType, // TODO: distinguish A/B datatype
+       BDataType,
+       ComputeDataType,
        AccDataType,
        CShuffleDataType,
        DsDataType,

@@ -412,14 +416,18 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
        LoopSched>;

    // desc for blockwise copy
-   using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-       GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
-   using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
-       GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
-   using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-       GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
-   using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-       GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
+   using AGridDesc_AK0_M_AK1 =
+       remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
+           AGridDesc_M_K{}))>;
+   using BGridDesc_BK0_N_BK1 =
+       remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
+           BGridDesc_N_K{}))>;
+   using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
+       decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+           DsGridDesc_M_N{}))>;
+   using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
+       remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+           EGridDesc_M_N{}))>;

    // block-to-e-tile map
    using Block2ETileMap =
...
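GridwiseGemmMultipleD_xdl_cshuffle now takes separate BDataType and ComputeDataType parameters; aliasing ComputeDataType to ADataType keeps behavior unchanged here while opening a seam for mixed-precision inputs later. A hedged scalar sketch of what such a type seam allows (a simplified stand-in, not the gridwise kernel):

// Hedged sketch: inputs stored in one type, converted to a compute type
// before the inner product, accumulated in yet another type.
#include <cstdio>

template <typename ADataType, typename BDataType, typename ComputeDataType,
          typename AccDataType>
AccDataType dot(const ADataType* a, const BDataType* b, int k)
{
    AccDataType acc{0};
    for(int i = 0; i < k; ++i)
        acc += static_cast<AccDataType>(static_cast<ComputeDataType>(a[i]) *
                                        static_cast<ComputeDataType>(b[i]));
    return acc;
}

int main()
{
    const float a[4] = {1, 2, 3, 4}, b[4] = {4, 3, 2, 1};
    // ComputeDataType == ADataType reproduces the old single-type behavior.
    std::printf("%f\n", dot<float, float, float, double>(a, b, 4));
}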
@@ -735,12 +735,12 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
        }

        // Check vector load/store requirement
        const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2
                                         ? device_arg.a_mz_kz_strides_[1]
                                         : device_arg.a_mz_kz_strides_[0];
        const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2
                                         ? device_arg.b_nz_kz_strides_[1]
                                         : device_arg.b_nz_kz_strides_[0];
        const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2
                                          ? device_arg.b1_nz_kz_strides_[1]
                                          : device_arg.b1_nz_kz_strides_[0];
...
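The selections above feed a vector-access requirement (the comparison itself sits outside this hunk): the dimension a block transfer vectorizes over must have unit stride. A hedged standalone sketch of that requirement (strides are made-up):

// Hedged sketch: vectorized loads need the vectorized dimension contiguous.
#include <array>
#include <cstdio>

int main()
{
    constexpr int ABlockTransferSrcVectorDim = 2; // vectorize along K
    const std::array<int, 2> a_mz_kz_strides{256, 1}; // row-major [M, K]
    const int a_stride_lowest = ABlockTransferSrcVectorDim == 2
                                    ? a_mz_kz_strides[1]
                                    : a_mz_kz_strides[0];
    // A stride of 1 permits vector loads; anything else forces scalar access.
    std::printf("vector load ok: %d\n", a_stride_lowest == 1);
}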