Commit f6ceef78 authored by ThomasNing's avatar ThomasNing
Browse files

merge with the develop branch

parents 536c5458 25935b57
......@@ -359,14 +359,14 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
__host__ __device__ static auto
MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
......@@ -379,7 +379,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template <typename BLay>
__host__ __device__ static auto
MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
......@@ -392,7 +392,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template <typename ELay>
__host__ __device__ static auto
MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
......@@ -405,7 +405,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// Shape of Ds and E must be aligned. Strides can be different.
// Pass e_g_n_k_wos_lengths for logical broadcast.
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
return generate_tuple(
[&](auto i) {
......@@ -417,7 +417,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
// desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_M_K =
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using BGridDesc_N_K =
......@@ -617,7 +617,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// D batch stride
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
......@@ -686,7 +686,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// tensor descriptors for problem definiton
index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
......@@ -727,6 +727,181 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
ck::Array<index_t, NDimSpatial> input_right_pads_;
};
static __device__ __host__ bool IsSupportedArgument(const Argument& arg)
{
namespace ctc = tensor_layout::convolution;
// check ConvolutionForwardSpecialization
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t ConvStride = arg.conv_filter_strides_[i];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
}
}
// check vector access of A
// FIXME: layout
if constexpr(is_same_v<ALayout, ctc::G_NW_C> || is_same_v<ALayout, ctc::G_NHW_C> ||
is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> ||
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
is_same_v<ALayout, ctc::NDHWGC>)
{
const index_t C = arg.a_g_n_c_wis_lengths_[2];
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
return false;
}
// check vector access of B
// FIXME: layout
if constexpr(is_same_v<BLayout, ctc::G_K_X_C> || is_same_v<BLayout, ctc::G_K_YX_C> ||
is_same_v<BLayout, ctc::G_K_ZYX_C> || is_same_v<BLayout, ctc::GKXC> ||
is_same_v<BLayout, ctc::GKYXC> || is_same_v<BLayout, ctc::GKZYXC> ||
is_same_v<BLayout, ctc::KXGC> || is_same_v<BLayout, ctc::KYXGC> ||
is_same_v<BLayout, ctc::KZYXGC>)
{
const index_t C = arg.b_g_k_c_xs_lengths_[2];
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
return false;
}
// check vector access of Ds
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
// FIXME: layout
if constexpr(is_same_v<DLayout, ctc::G_NW_K> || is_same_v<DLayout, ctc::G_NHW_K> ||
is_same_v<DLayout, ctc::G_NDHW_K> || is_same_v<DLayout, ctc::GNWK> ||
is_same_v<DLayout, ctc::GNHWK> || is_same_v<DLayout, ctc::GNDHWK> ||
is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> ||
is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::G_K>)
{
const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
valid = false;
}
if constexpr(is_same_v<DLayout, ctc::G_K>)
{
// G and K must be the same
if(arg.ds_g_n_k_wos_lengths_[i][0] != arg.e_g_n_k_wos_lengths_[0] ||
arg.ds_g_n_k_wos_lengths_[i][2] != arg.e_g_n_k_wos_lengths_[2])
{
valid = false;
}
}
else
{
// E and D must have the same shape
for(index_t d = 0; d < NDimSpatial + 3; d++)
{
if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d])
{
valid = false;
}
}
}
}
else
{
valid = false;
}
});
if(!valid)
{
return false;
}
// check vector access of E
if constexpr(is_same_v<ELayout, ctc::G_NW_K> || is_same_v<ELayout, ctc::G_NHW_K> ||
is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> ||
is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> ||
is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
is_same_v<ELayout, ctc::NDHWGK>)
{
const index_t K = arg.e_g_n_k_wos_lengths_[2];
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
}
else
{
return false;
}
// check Gridwise GEMM
if constexpr(isMultiA || isMultiB)
{
// Genarate tuples with the same descriptors
const auto as_grid_desc_ak0_m_ak1 =
generate_tuple([&](auto) { return arg.a_grid_desc_m_k_; }, Number<NumATensor>{});
const auto bs_grid_desc_bk0_n_bk1 =
generate_tuple([&](auto) { return arg.b_grid_desc_n_k_; }, Number<NumBTensor>{});
return GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
else
{
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
}
static __device__ __host__ auto MakeArgument(
APointers p_as,
BPointers p_bs,
......@@ -768,6 +943,77 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
b_element_op,
cde_element_op};
}
static __device__ __host__ auto MakeArgument(
APointers p_as,
BPointers p_bs,
const ck::Array<const void*, NumDTensor>& p_ds,
void* p_e,
const ck::Array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const ck::Array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const ck::Array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const ck::Array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const ck::Array<ck::Array<long_index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const ck::Array<ck::Array<long_index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const ck::Array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const ck::Array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const ck::Array<long_index_t, NDimSpatial>& conv_filter_strides,
const ck::Array<long_index_t, NDimSpatial>& conv_filter_dilations,
const ck::Array<long_index_t, NDimSpatial>& input_left_pads,
const ck::Array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return Argument{p_as,
p_bs,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op};
}
};
} // namespace device
......
......@@ -64,7 +64,7 @@ struct DeviceColumnToImageImpl
static constexpr auto spatial_offset = Number<3>{};
using GemmToConvFwdTransformer =
using ConvToGemmFwdTransformer =
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpecialization::MKPadding, index_t, index_t, index_t>{
......@@ -233,7 +233,7 @@ struct DeviceColumnToImageImpl
: independent_filter_stride;
}
GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
image_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
{}, // not needed for A Descriptor
......
......@@ -15,6 +15,7 @@
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
namespace ck {
namespace tensor_operation {
......@@ -68,17 +69,17 @@ template <typename ALayout,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ComputeTypeA,
typename LDSTypeB = ComputeTypeB>
struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
BDataType,
DsDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
BDataType,
DsDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
......@@ -163,14 +164,61 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto Run = [&](const auto& kernel) {
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
if(stream_config.flush_cache)
{
std::array<std::size_t, NumDTensor> DsSize;
Argument arg_ = arg;
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
});
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
arg_,
stream_config.rotating_count,
arg_.M * arg_.K * sizeof(ADataType),
arg_.K * arg_.N * sizeof(BDataType),
DsSize);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg_);
}
else
{
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
};
constexpr index_t minimum_occupancy =
......@@ -182,38 +230,161 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(arg.KBatch > 1)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::One>;
kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Full>;
kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
......@@ -221,12 +392,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Two>;
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
......@@ -236,12 +407,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Three>;
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
......@@ -251,12 +422,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Four>;
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
......@@ -266,12 +437,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Five>;
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
......@@ -280,12 +451,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Six>;
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
......@@ -295,12 +466,12 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Seven>;
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
......@@ -309,51 +480,98 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
......@@ -364,12 +582,22 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
kernel_gemm_xdl_cshuffle_v3_multi_d<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
......@@ -399,6 +627,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
return false;
}
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
......@@ -427,6 +660,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideC,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
......@@ -442,7 +676,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
StrideB,
StrideDs,
StrideC,
1,
KBatch,
a_element_op,
b_element_op,
c_element_op};
......@@ -462,6 +696,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideC,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
......@@ -477,7 +712,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
StrideB,
StrideDs,
StrideC,
1,
KBatch,
a_element_op,
b_element_op,
c_element_op);
......
......@@ -168,15 +168,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// rotating mem
rotating_mem.Next();
// clear c mem
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(arg_.KBatch > 1)
hipGetErrorString(
hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
}
if(arg_.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
......@@ -190,14 +186,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
}
else
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
}
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
......@@ -215,15 +208,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
......@@ -240,118 +230,113 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
TailNumber::Two>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Six)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
......@@ -473,28 +458,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
......@@ -525,28 +507,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
......@@ -579,18 +558,14 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
......@@ -628,6 +603,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return false;
}
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
......
......@@ -238,14 +238,14 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
template <typename ALay>
static auto
MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
......@@ -266,7 +266,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
template <typename BLay>
static auto
MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
......@@ -287,7 +287,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
}
template <typename ELay>
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
......@@ -298,7 +298,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
return out_gemmm_gemmn_desc;
}
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
return generate_tuple(
[&](auto i) {
......@@ -310,7 +310,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
}
// desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
dummy_conv_to_gemm_transformer))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
......@@ -447,7 +447,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
......@@ -511,7 +511,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// tensor descriptors for problem definiton
index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
......@@ -836,6 +836,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
cde_element_op};
}
static auto
MakeArgument(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return Argument{p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer(
......@@ -880,6 +953,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
cde_element_op);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
......
......@@ -234,14 +234,14 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
template <typename ALay>
static auto
MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
......@@ -263,7 +263,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
template <typename BLay>
static auto
MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
......@@ -284,7 +284,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
}
template <typename CLay>
static auto MakeCGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeCGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>();
......@@ -296,7 +296,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
}
// desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
dummy_conv_to_gemm_transformer))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
......@@ -452,7 +452,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
// tensor descriptors for problem definiton
index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
......
......@@ -86,7 +86,6 @@ __global__ void
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const index_t groups_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -101,10 +100,8 @@ __global__ void
defined(__gfx94__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
const index_t& num_blocks_per_n = groups_count;
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
......@@ -200,7 +197,6 @@ __global__ void
ignore = p_bs_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = groups_count;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
......@@ -318,17 +314,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization,
true /*SplitN*/,
ALayout,
ELayout>;
ADataType,
EDataType>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
static auto MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
......@@ -340,7 +336,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
template <typename BLay>
static auto MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
......@@ -352,7 +348,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
template <typename ELay>
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
......@@ -365,7 +361,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// Shape of Ds and E must be aligned. Strides can be different.
// Pass e_g_n_k_wos_lengths for logical broadcast.
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
return generate_tuple(
[&](auto i) {
......@@ -377,7 +373,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
// desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_M_K =
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using BGridDesc_N_K =
......@@ -591,7 +587,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
......@@ -670,7 +666,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// tensor descriptors for problem definiton
index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
index_t conv_N_per_block_;
......@@ -730,8 +726,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const index_t gdy = arg.num_group_ * num_workgroups_per_Conv_N;
const index_t gdz = 1;
const index_t gdy = arg.num_group_;
const index_t gdz = num_workgroups_per_Conv_N;
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
......@@ -780,7 +776,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
......@@ -824,7 +819,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
......@@ -1088,11 +1082,84 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
cde_element_op};
}
static auto
MakeArgument(APointers p_as,
BPointers p_bs,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return Argument{p_as,
p_bs,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer(
APointers p_a,
BPointers p_b,
APointers p_as,
BPointers p_bs,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
......@@ -1111,8 +1178,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
return std::make_unique<Argument>(p_as,
p_bs,
p_ds,
p_e,
a_g_n_c_wis_lengths,
......@@ -1132,6 +1199,80 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
cde_element_op);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(APointers p_as,
BPointers p_bs,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return std::make_unique<Argument>(p_as,
p_bs,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
......
......@@ -293,7 +293,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization,
true /*SplitN*/,
ADataType,
......@@ -304,7 +304,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
template <typename ALay>
static auto
MakeAGridDescriptor_AK0_M_AK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
......@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
template <typename BLay>
static auto
MakeBGridDescriptor_BK0_N_BK1(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
......@@ -348,7 +348,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
template <typename ELay>
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
......@@ -361,7 +361,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
// desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
......@@ -495,7 +495,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// tensor descriptors for problem definiton
index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
index_t conv_N_per_block_;
......@@ -978,6 +978,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return false;
}
// Gridwise gemm v3 doesn't verify descriptors size
if(!arg.conv_to_gemm_transformer_.AreDescriptorsSmallerThan2GB())
{
return false;
}
// check Gridwise GEMM
const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1);
......@@ -1037,6 +1043,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
cde_element_op};
}
static auto
MakeArgument(const void* p_as,
const void* p_bs,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return Argument{p_as,
p_bs,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer(
......@@ -1081,6 +1160,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
cde_element_op);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
......
......@@ -309,13 +309,13 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
static auto MakeAGridDescriptor_M_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
......@@ -327,7 +327,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
}
template <typename BLay>
static auto MakeBGridDescriptor_N_K(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
......@@ -339,7 +339,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
}
template <typename ELay>
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
......@@ -420,7 +420,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo);
}
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_M_K =
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using BGridDesc_N_K =
......@@ -599,7 +599,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
// D batch stride
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
......@@ -649,7 +649,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
EDataType* p_e_grid_;
typename GridwiseGemm::RsGridPointer p_rs_grid_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
// tensor descriptors for problem definiton
AGridDesc_M_K a_grid_desc_m_k_;
......
......@@ -135,13 +135,13 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
static constexpr auto BEnableLds =
BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1);
using GemmToConvFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
static auto MakeAGridDescriptor(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeAGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
......@@ -185,7 +185,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
template <typename BLay>
static auto MakeBGridDescriptor(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeBGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
......@@ -229,7 +229,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
template <typename ELay>
static auto MakeEGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
......@@ -240,7 +240,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
return out_gemmm_gemmn_desc;
}
static auto MakeDsGridDescriptor_M_N(const GemmToConvFwdTransformer& conv_to_gemm_transformer)
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
return generate_tuple(
[&](auto i) {
......@@ -252,7 +252,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
// desc for problem definition
constexpr static GemmToConvFwdTransformer dummy_conv_to_gemm_transformer;
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc =
decltype(DeviceOp::MakeAGridDescriptor<ALayout>(dummy_conv_to_gemm_transformer));
using BGridDesc =
......@@ -406,7 +406,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
GemmToConvFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
......@@ -448,7 +448,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// tensor descriptors for problem definiton
index_t num_group_;
GemmToConvFwdTransformer conv_to_gemm_transformer_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
......@@ -772,6 +772,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
cde_element_op};
}
static auto
MakeArgument(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return Argument{p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
1,
1,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer(
......@@ -818,6 +893,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
cde_element_op);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
1,
1,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <queue>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace {
template <typename GridwiseGemm,
index_t MaxGemmsNum,
typename GemmArgs,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename ComputePtrOffset,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle(
Array<GemmArgs, MaxGemmsNum> gemm_desc_kernel_args,
const index_t gemms_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation c_element_op,
const ComputePtrOffset compute_ptr_offset_of_groups,
const ComputePtrOffset compute_ptr_offset_of_n)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id_x = __builtin_amdgcn_readfirstlane(blockIdx.x);
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
const long_index_t a_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
const long_index_t b_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
const long_index_t e_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
index_t left = 0;
index_t right = gemms_count;
index_t group_id = index_t((left + right) / 2);
while((!(block_id_x >= gemm_desc_kernel_args[group_id].BlockStart_ &&
block_id_x < gemm_desc_kernel_args[group_id].BlockEnd_)) &&
left <= right)
{
if(block_id_x < gemm_desc_kernel_args[group_id].BlockStart_)
{
right = group_id;
}
else
{
left = group_id;
}
group_id = index_t((left + right) / 2);
}
GridwiseGemm::template Run<HasMainKBlockLoop>(
gemm_desc_kernel_args[group_id].a_ptr_ + a_group_offset + a_n_offset,
gemm_desc_kernel_args[group_id].b_ptr_ + b_group_offset,
Tuple<>{},
gemm_desc_kernel_args[group_id].e_ptr_ + e_group_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
c_element_op,
gemm_desc_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_desc_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
Tuple<>{},
gemm_desc_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_desc_kernel_args[group_id].block_2_etile_map_);
#else
ignore = gemm_desc_kernel_args;
ignore = gemms_count;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = compute_ptr_offset_of_groups;
ignore = compute_ptr_offset_of_n;
#endif
}
} // namespace
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename AComputeDataType =
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
Number<0>,
ADataType>()), // ComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
typename BComputeDataType = AComputeDataType,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
: public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
AComputeDataType,
BComputeDataType>
{
using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr index_t MaxGemmsNum = 32;
static_assert(NumDTensor == 0, "MultiD not supported.");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using ConvToGemmFwdTransformerIndexT = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization,
true /*SplitN*/,
ADataType,
EDataType,
I1,
index_t>;
using ConvToGemmFwdTransformerLongIndexT = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization,
true /*SplitN*/,
ADataType,
EDataType,
I1,
long_index_t>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
static auto
MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
return in_gemmm_gemmk_desc;
}
template <typename BLay>
static auto
MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
return wei_gemmn_gemmk_desc;
}
template <typename ELay>
static auto
MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
return out_gemmm_gemmn_desc;
}
// desc for problem definition
constexpr static ConvToGemmFwdTransformerIndexT dummy_conv_to_gemm_transformer;
using AGridDesc_M_K =
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using BGridDesc_N_K =
remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>(dummy_conv_to_gemm_transformer))>;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
static auto
GenerateConvToGemmTransforms(ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformer_base,
const ADataType* a_grid_ptr_base,
EDataType* c_grid_ptr_base)
{
// Max number of splits
// We need to use it to avoid infinity loop
constexpr index_t max_split_numbers = MaxGemmsNum / 2;
// Arrays to store transformers with smaller descs than 2GB
Array<ConvToGemmFwdTransformerIndexT, MaxGemmsNum> conv_to_gemm_transformers_arr;
Array<const ADataType*, MaxGemmsNum> a_grid_ptrs_arr;
Array<EDataType*, MaxGemmsNum> c_grid_ptrs_arr;
// Queue for spliting
std::queue<ConvToGemmFwdTransformerLongIndexT> conv_to_gemm_transformers_queue(
{conv_to_gemm_transformer_base});
std::queue<const ADataType*> a_grid_ptrs_queue({a_grid_ptr_base});
std::queue<EDataType*> c_grid_ptrs_queue({c_grid_ptr_base});
index_t gemms_number = 0;
index_t split_numbers = 0;
// Algorithm:
// While queue is not empty:
// 1. Get transformer from queue.
// 2. If descs are smaller than 2GB push to result array.
// 3. If descs are bigger than 2GB split into left and right transformer.
// and push the both into the queue.
while(!conv_to_gemm_transformers_queue.empty() && split_numbers < max_split_numbers &&
gemms_number < MaxGemmsNum)
{
// Get transformer from the queue
const auto& conv_to_gemm_transformer = conv_to_gemm_transformers_queue.front();
const ADataType* a_grid_ptr = a_grid_ptrs_queue.front();
EDataType* c_grid_ptr = c_grid_ptrs_queue.front();
// Check if convolution not exceed 2GB
if(conv_to_gemm_transformer.AreDescriptorsSmallerThan2GB())
{
// If yes, push into result array
conv_to_gemm_transformers_arr(gemms_number) =
ConvToGemmFwdTransformerIndexT{conv_to_gemm_transformer};
a_grid_ptrs_arr(gemms_number) = a_grid_ptr;
c_grid_ptrs_arr(gemms_number) = c_grid_ptr;
gemms_number++;
}
else
{
// If no, split into left and right convolutions
ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformers_left_part,
conv_to_gemm_transformers_right_part;
const ADataType* a_grid_right_ptr;
EDataType* c_grid_right_ptr;
ck::tie(conv_to_gemm_transformers_left_part,
conv_to_gemm_transformers_right_part,
a_grid_right_ptr,
c_grid_right_ptr) =
conv_to_gemm_transformer.SplitConvProblem(a_grid_ptr, c_grid_ptr);
conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_left_part);
conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_right_part);
// Left offsets remain the same
a_grid_ptrs_queue.push(a_grid_ptr);
a_grid_ptrs_queue.push(a_grid_right_ptr);
c_grid_ptrs_queue.push(c_grid_ptr);
c_grid_ptrs_queue.push(c_grid_right_ptr);
split_numbers++;
}
// Remove from the queue
conv_to_gemm_transformers_queue.pop();
a_grid_ptrs_queue.pop();
c_grid_ptrs_queue.pop();
}
const bool is_split_valid = conv_to_gemm_transformers_queue.empty();
return ck::make_tuple(conv_to_gemm_transformers_arr,
a_grid_ptrs_arr,
c_grid_ptrs_arr,
gemms_number,
is_split_valid);
}
#define GridwiseGemmTemplateParameters \
ADataType, BDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
AComputeDataType
// Use appropriate gridwise gemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>;
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
BGridDesc_N_K{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Structure for each gemm(conv)
struct GemmArgs
{
// pointers
const ADataType* a_ptr_;
const BDataType* b_ptr_;
EDataType* e_ptr_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
ck::index_t BlockStart_, BlockEnd_;
};
// Argument
struct Argument : public BaseArgument
{
Argument(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& /*p_ds*/,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
/*ds_g_n_k_wos_lengths*/,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
/*ds_g_n_k_wos_strides*/,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
: num_group_{static_cast<index_t>(a_g_n_c_wis_lengths[0])},
compute_ptr_offset_of_groups_{},
compute_ptr_offset_of_n_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
// Perform grouped gemm, generate array of tranformer for convolution
Array<ConvToGemmFwdTransformerIndexT, MaxGemmsNum> conv_to_gemm_transformer_arr;
Array<const ADataType*, MaxGemmsNum> a_grid_ptrs;
Array<EDataType*, MaxGemmsNum> c_grid_ptrs;
ck::tie(conv_to_gemm_transformer_arr,
a_grid_ptrs,
c_grid_ptrs,
gemms_count_,
is_split_valid_) =
GenerateConvToGemmTransforms(
ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_,
a_g_n_c_wis_strides_,
b_g_k_c_xs_lengths_,
b_g_k_c_xs_strides_,
e_g_n_k_wos_lengths_,
e_g_n_k_wos_strides_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_},
static_cast<const ADataType*>(p_a),
static_cast<EDataType*>(p_e));
grid_size_ = 0;
valid_gemms_count_ = 0;
if(is_split_valid_)
{
// Create GemmArg for each gemm(conv)
for(index_t i = 0; i < gemms_count_; i++)
{
const AGridDesc_M_K a_grid_desc_m_k{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(
conv_to_gemm_transformer_arr[i])};
const BGridDesc_N_K b_grid_desc_n_k{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(
conv_to_gemm_transformer_arr[i])};
const auto e_grid_desc_m_n =
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_arr[i]);
const auto block_2_etile_map =
GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
const index_t grid_size_grp =
block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);
const index_t BlockStart = grid_size_;
const index_t BlockEnd = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp;
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
b_grid_desc_n_k,
Tuple<>{},
e_grid_desc_m_n,
block_2_etile_map))
{
gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{
a_grid_ptrs[i],
static_cast<const BDataType*>(p_b),
c_grid_ptrs[i],
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k),
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k),
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n),
block_2_etile_map,
BlockStart,
BlockEnd};
valid_gemms_count_++;
}
}
// N is the same for all convs
conv_N_per_block_ = static_cast<index_t>(conv_to_gemm_transformer_arr[I0].N_);
}
// Strides for G and N remain the same
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
}
void Print() const
{
for(index_t i = 0; i < valid_gemms_count_; i++)
{
std::cout << "A[AK0, M, AK1]: " << gemm_desc_kernel_args_[i].a_grid_desc_ak0_m_ak1_
<< std::endl;
std::cout << "B[BK0, N, BK1]: " << gemm_desc_kernel_args_[i].b_grid_desc_bk0_n_bk1_
<< std::endl;
std::cout
<< "E[MBlock, MPerBlock, NBlock, NPerBlock]: "
<< gemm_desc_kernel_args_[i].e_grid_desc_mblock_mperblock_nblock_nperblock_
<< std::endl;
}
}
index_t num_group_;
index_t conv_N_per_block_;
Array<GemmArgs, MaxGemmsNum> gemm_desc_kernel_args_;
index_t grid_size_;
index_t gemms_count_;
index_t valid_gemms_count_;
bool is_split_valid_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_groups_;
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_n_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// for checking IsSupportedArgument()
std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
std::array<long_index_t, NDimSpatial> conv_filter_strides_;
std::array<long_index_t, NDimSpatial> conv_filter_dilations_;
std::array<long_index_t, NDimSpatial> input_left_pads_;
std::array<long_index_t, NDimSpatial> input_right_pads_;
};
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const DeviceOp::Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
const index_t num_workgroups_per_Conv_N =
arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
const index_t gdx = arg.grid_size_;
const index_t gdy = arg.num_group_;
const index_t gdz = num_workgroups_per_Conv_N;
// K is constant for all gemms
const auto K = arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle<
GridwiseGemm,
MaxGemmsNum,
GemmArgs,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.gemm_desc_kernel_args_,
arg.gemms_count_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_);
};
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
namespace ctc = tensor_layout::convolution;
const long_index_t K = arg.b_g_k_c_xs_lengths_[I1];
const long_index_t C = arg.b_g_k_c_xs_lengths_[I2];
// Check if all descs are valid
if(!(arg.is_split_valid_ && arg.gemms_count_ == arg.valid_gemms_count_))
{
return false;
}
// check device
if(get_device_name() == "gfx908")
{
// FIXME: re-enable fp64 when SWDEV-335738 is fixed
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
return false;
}
}
if(!ck::is_xdl_supported())
{
return false;
}
// check ConvolutionForwardSpecialization
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t ConvStride = arg.conv_filter_strides_[i];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
}
}
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3)
{
if(C != 1)
{
return false;
}
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3];
if(filter_spatial_dim != I3)
{
return false;
}
}
if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
{
return false;
}
}
// check vector access of A
// FIXME: layout
if constexpr(is_same_v<ALayout, ctc::G_NW_C> || is_same_v<ALayout, ctc::G_NHW_C> ||
is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> ||
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
is_same_v<ALayout, ctc::NDHWGC>)
{
// Check access per C
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
return false;
}
// check vector access of B
// FIXME: layout
if constexpr(is_same_v<BLayout, ctc::G_K_X_C> || is_same_v<BLayout, ctc::G_K_YX_C> ||
is_same_v<BLayout, ctc::G_K_ZYX_C> || is_same_v<BLayout, ctc::GKXC> ||
is_same_v<BLayout, ctc::GKYXC> || is_same_v<BLayout, ctc::GKZYXC> ||
is_same_v<BLayout, ctc::KXGC> || is_same_v<BLayout, ctc::KYXGC> ||
is_same_v<BLayout, ctc::KZYXGC>)
{
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
return false;
}
// check vector access of E
if constexpr(is_same_v<ELayout, ctc::G_NW_K> || is_same_v<ELayout, ctc::G_NHW_K> ||
is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> ||
is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> ||
is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
is_same_v<ELayout, ctc::NDHWGK>)
{
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
}
else
{
return false;
}
return true;
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(
const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i64;
std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i64;
std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i64;
std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i64;
std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i64;
std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i64;
std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i64;
std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i64;
std::array<long_index_t, NDimSpatial> conv_filter_strides_i64;
std::array<long_index_t, NDimSpatial> conv_filter_dilations_i64;
std::array<long_index_t, NDimSpatial> input_left_pads_i64;
std::array<long_index_t, NDimSpatial> input_right_pads_i64;
array_convert(a_g_n_c_wis_lengths_i64, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i64, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i64, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i64, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i64[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i64[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i64, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i64, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i64, conv_filter_strides);
array_convert(conv_filter_dilations_i64, conv_filter_dilations);
array_convert(input_left_pads_i64, input_left_pads);
array_convert(input_right_pads_i64, input_right_pads);
return Argument{p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i64,
a_g_n_c_wis_strides_i64,
b_g_k_c_xs_lengths_i64,
b_g_k_c_xs_strides_i64,
ds_g_n_k_wos_lengths_i64,
ds_g_n_k_wos_strides_i64,
e_g_n_k_wos_lengths_i64,
e_g_n_k_wos_strides_i64,
conv_filter_strides_i64,
conv_filter_dilations_i64,
input_left_pads_i64,
input_right_pads_i64,
a_element_op,
b_element_op,
cde_element_op};
}
static auto
MakeArgument(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
return Argument{p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths,
ds_g_n_k_wos_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i64;
std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i64;
std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i64;
std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i64;
std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i64;
std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i64;
std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i64;
std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i64;
std::array<long_index_t, NDimSpatial> conv_filter_strides_i64;
std::array<long_index_t, NDimSpatial> conv_filter_dilations_i64;
std::array<long_index_t, NDimSpatial> input_left_pads_i64;
std::array<long_index_t, NDimSpatial> input_right_pads_i64;
array_convert(a_g_n_c_wis_lengths_i64, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i64, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i64, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i64, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i64[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i64[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i64, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i64, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i64, conv_filter_strides);
array_convert(conv_filter_dilations_i64, conv_filter_dilations);
array_convert(input_left_pads_i64, input_left_pads);
array_convert(input_right_pads_i64, input_right_pads);
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i64,
a_g_n_c_wis_strides_i64,
b_g_k_c_xs_lengths_i64,
b_g_k_c_xs_strides_i64,
ds_g_n_k_wos_lengths_i64,
ds_g_n_k_wos_strides_i64,
e_g_n_k_wos_lengths_i64,
e_g_n_k_wos_strides_i64,
conv_filter_strides_i64,
conv_filter_dilations_i64,
input_left_pads_i64,
input_right_pads_i64,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths,
ds_g_n_k_wos_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CDEBlockTransferScalarPerVector_NPerBlock << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -57,7 +57,7 @@ struct DeviceImageToColumnImpl
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
using GemmToConvFwdTransformer =
using ConvToGemmFwdTransformer =
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>;
static constexpr auto matrix_padder =
......@@ -97,7 +97,7 @@ struct DeviceImageToColumnImpl
b_g_k_c_xs_lengths[I2] = C;
c_g_n_k_wos_lengths[I1] = N;
GemmToConvFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
image_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
{}, // not needed for A Descriptor
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -19,7 +19,7 @@ namespace device {
template <index_t Rank, int NumReduceDim>
std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>& inLengths)
{
static_assert(Rank <= 6, "bigger Rank size not supported!");
static_assert(Rank <= 12, "bigger Rank size not supported!");
long_index_t invariant_total_length = 1;
long_index_t reduce_total_length = 1;
......@@ -38,7 +38,7 @@ std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>&
template <index_t Rank, int NumReduceDim>
std::pair<long_index_t, long_index_t> get_2d_lengths(const std::array<index_t, Rank>& inLengths)
{
static_assert(Rank <= 6, "bigger Rank size not supported!");
static_assert(Rank <= 12, "bigger Rank size not supported!");
long_index_t invariant_total_length = 1;
long_index_t reduce_total_length = 1;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -51,7 +51,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InDataType,
PropagateNan,
OutputIndex>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(Rank <= 12, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -47,7 +47,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
OutputIndex>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(Rank <= 12, "Bigger Rank size is not supported!");
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
......
......@@ -45,7 +45,7 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduceMultiD<InDataType,
OutElementwiseOperation>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(Rank <= 12, "Bigger Rank size is not supported!");
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
......
......@@ -3,7 +3,6 @@
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
......@@ -107,6 +106,9 @@ struct TrinaryWithUnaryCombinedOp
UnaryOp2 unary_op2_{};
};
using ScaleScalePass = UnaryCombinedOp<Scale, Scale, PassThrough>;
using ScaleScaleRelu = UnaryCombinedOp<Scale, Scale, Relu>;
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -258,7 +258,7 @@ struct GridwiseMultipleReduction_mk_to_m_multiblock
if(thread_k_cluster_id == 0)
{
if(block_group_size == 0 && !float_equal_zero{}(beta_values[iR]))
if(!float_equal_zero{}(beta_values[iR]))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -244,7 +244,7 @@ struct GridwiseReduction_mk_to_m_multiblock
if(thread_k_cluster_id == 0)
{
if(block_group_size == 0 && !float_equal_zero{}(beta))
if(!float_equal_zero{}(beta))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -111,6 +111,15 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
const CGridDesc_M_N& c_grid_desc_m_n)
{
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
{
return false;
}
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment