Commit bd689f40 authored by illsilin

merge from public repo

parents c160c6cf a94113a9
......@@ -168,15 +168,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// rotating mem
rotating_mem.Next();
// clear c mem
-if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-{
if(arg_.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
-}
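// (note) The unconditional zeroing matters for split-K: partial C tiles are
// accumulated with AtomicAdd, so the destination buffer must start at zero.
// The bf16 compile-time guard that used to skip this path is removed; bf16
// configurations that would need unsupported atomics are now rejected in
// IsSupportedArgument instead (see the added check further below).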
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
......@@ -190,14 +186,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
}
else
{
-if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-{
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
-}
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
......@@ -215,15 +208,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{
if(arg.KBatch > 1)
{
-if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
-}
}
else
{
......@@ -240,118 +230,113 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{
if(arg.KBatch > 1)
{
-if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
-}
}
......@@ -473,28 +458,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{
if(arg.KBatch > 1)
{
-if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
-}
}
else
......@@ -525,28 +507,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{
if(arg.KBatch > 1)
{
-if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
-}
}
else
......@@ -579,18 +558,14 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
-if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
-{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
-}
}
else
{
......@@ -628,6 +603,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return false;
}
+if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
+{
+return false;
+}
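// (note) This added runtime check replaces the per-type compile-time guards
// removed above: split-K accumulates partial results into C with global
// AtomicAdd, so KBatch > 1 with a bf16 C tensor is rejected on devices where
// is_bf16_atomic_supported() reports false, rather than silently skipping
// the kernel launch.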
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <typeinfo>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ReduceDataType = CDataType,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGemm_Xdl_CShuffleV3R1 : public DeviceGemmV2R1<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
BDataType,
DsDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
GemmAccDataType,
CShuffleDataType,
ReduceDataType,
AElementwiseOperation,
BElementwiseOperation,
PassThrough,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
struct Argument : public GridwiseGemm::Argument
{
Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
const std::array<const void*, NumDTensor> p_ds_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<ck::index_t, NumDTensor> StrideDs_,
index_t StrideC_,
index_t k_batch_)
: GridwiseGemm::Argument(p_a_grid_,
p_b_grid_,
reinterpret_cast<ReduceDataType*>(p_c_grid_),
M_,
N_,
K_,
StrideA_,
StrideB_,
StrideC_,
k_batch_,
true),
p_ds(p_ds_),
StrideDs(StrideDs_)
{
}
const std::array<const void*, NumDTensor> p_ds;
std::array<ck::index_t, NumDTensor> StrideDs;
};
using ReduceAdd = ck::reduce::Add;
using OutElementwiseOperation = CElementwiseOperation;
static constexpr auto DsVectorLengthSequence = generate_sequence_v2(
[](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same<CLayout, DLayout>::value)
return Number<CShuffleBlockTransferScalarPerVector_NPerBlock>{};
else
return Number<1>{};
},
Number<NumDTensor>{});
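// (note) Per-D vector length: a D tensor that shares C's layout can reuse
// the CShuffle vector width for its reads; any other layout falls back to
// scalar access (vector length 1).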
using DeviceReduceInstance = DeviceReduceThreadWiseMultiD<
ReduceDataType, // InDataType,
DsDataType, // DsDatatype
GemmAccDataType, // AccDataType,
CDataType, // OutDataType,
3, // Rank
1, // NumReduceDim
ReduceAdd,
PassThrough,
OutElementwiseOperation,
256, // BlockSize_,
CShuffleBlockTransferScalarPerVector_NPerBlock, // MThreadSliceSize_,
1, // KThreadSliceSize_,
0, // InSrcVectorDim_,
CShuffleBlockTransferScalarPerVector_NPerBlock, // InSrcVectorSize_,
CShuffleBlockTransferScalarPerVector_NPerBlock, // OutDstVectorSize_
decltype(DsVectorLengthSequence)>;
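// (note) Second stage of the split-K path: the GEMM writes KBatch partial
// C slices into the workspace, and this threadwise reduction (rank 3, one
// reduce dim) sums them over the KBatch dimension, fuses the D tensors, and
// applies the output elementwise op while converting to CDataType.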
// Invoker
struct Invoker : public BaseInvoker
{
float RunReduce(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
static constexpr index_t NumInDim = 3;
static constexpr index_t NumOutDim = 2;
std::array<ck::index_t, NumInDim> in_lengths = {arg.KBatch, arg.M, arg.N};
std::array<ck::index_t, NumOutDim> out_lengths = {arg.M, arg.N};
std::array<ck::index_t, NumInDim> in_strides;
std::array<ck::index_t, NumOutDim> out_strides;
if constexpr(std::is_same<CLayout, ck::tensor_layout::gemm::RowMajor>::value)
{
in_strides = {arg.M * arg.N, arg.N, 1};
out_strides = {arg.N, 1};
}
else
{
in_strides = {arg.M * arg.N, 1, arg.M};
out_strides = {1, arg.M};
}
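// (note) Worked example, assuming row-major C with KBatch = 4, M = 128,
// N = 256: in_lengths = {4, 128, 256} and in_strides = {32768, 256, 1},
// i.e. each of the 4 partial results is a contiguous M x N slice of the
// workspace; out_strides = {256, 1} addresses the final M x N output.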
std::array<int, 1> reduce_dims{0};
std::array<std::array<index_t, NumOutDim>, NumDTensor> DsLengths;
std::array<std::array<index_t, NumOutDim>, NumDTensor> DsStrides;
static_for<0, NumDTensor, 1>{}([&](auto i) {
DsLengths[i] = out_lengths;
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same<DLayout, ck::tensor_layout::gemm::RowMajor>::value)
{
DsStrides[i] = {arg.StrideDs[i], 1};
}
else
{
DsStrides[i] = {1, arg.StrideDs[i]};
}
});
auto reduce = DeviceReduceInstance{};
auto argument_ptr = reduce.MakeArgumentPointer(in_lengths,
in_strides,
DsLengths,
DsStrides,
out_lengths,
out_strides,
reduce_dims,
arg.p_workspace_,
arg.p_ds,
arg.p_c_grid,
PassThrough{},
OutElementwiseOperation{});
auto invoker_ptr = reduce.MakeInvokerPointer();
float ave_time = 0;
if(reduce.IsSupportedArgument(argument_ptr.get()))
{
ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config);
}
else
{
throw std::runtime_error(
"The runtime parameters seems not supported by the device instance, exiting!");
}
return ave_time;
}
float Run(const Argument& arg_, const StreamConfig& stream_config = StreamConfig{})
{
auto arg = *dynamic_cast<const typename GridwiseGemm::Argument*>(&arg_);
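// (note) When split-K is active (IsReduceAdd), D tensors are fused, or
// CDataType differs from ReduceDataType, the GEMM cannot write C directly:
// its output is redirected into the user-provided workspace below and
// folded back into C by RunReduce at the end of this function.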
if(!(!(arg.IsReduceAdd() || NumDTensor > 0) &&
std::is_same<CDataType, ReduceDataType>::value))
{
if(arg.p_workspace_ == nullptr)
{
throw std::runtime_error("using reduce, but the workspace is empty!");
}
arg.p_c_grid = reinterpret_cast<ReduceDataType*>(arg.p_workspace_);
}
if(stream_config.log_level_ > 0)
{
arg.Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
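// (note) K_split is the per-split K extent used for tail-number dispatch:
// K is rounded up to a multiple of KBatch * KPerBlock, then divided by
// KBatch. Example with K = 1000, KBatch = 4, KPerBlock = 64:
// k_grain = 256, K_split = ceil(1000 / 256) * 64 = 4 * 64 = 256.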
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
ck::utility::RotatingMemWrapper<typename GridwiseGemm::Argument> rotating_mem(
arg,
stream_config.rotating_count,
arg.M * arg.K * sizeof(ADataType),
arg.K * arg.N * sizeof(BDataType));
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg);
}
else
{
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
};
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
if(!(!(arg.IsReduceAdd() || NumDTensor > 0) &&
std::is_same<CDataType, ReduceDataType>::value))
{
// reduce c data
ave_time += RunReduce(arg_, stream_config);
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
const std::array<const void*, NumDTensor> p_ds,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideC,
index_t KBatch,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation)
{
return Argument{p_a, p_b, p_ds, p_c, M, N, K, StrideA, StrideB, StrideDs, StrideC, KBatch};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideC,
index_t KBatch,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideC,
KBatch);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGemmXdlUniversalReduce"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerXDL<<"x"<<NPerXDL << ", "
<< "WaveMap: "
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
// clang-format on
return str.str();
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto arg = *dynamic_cast<const Argument*>(p_arg);
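// (note) The workspace holds one ReduceDataType M x N slice per K-batch.
// Example, assuming M = N = 1024, KBatch = 4, and a 2-byte ReduceDataType:
// 1024 * 1024 * 4 * 2 bytes = 8 MiB.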
if(!(!(arg.IsReduceAdd() || NumDTensor > 0) &&
std::is_same<CDataType, ReduceDataType>::value))
{
std::cout << "using workspace" << std::endl;
return arg.M * arg.N * arg.KBatch * sizeof(ReduceDataType);
}
return 0;
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
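For orientation, a minimal host-side sketch of driving the new two-stage (GEMM + reduce) op. The concrete instantiation DeviceOpInstance, the element ops, the use of the BaseOperator workspace helpers, and the error handling are illustrative assumptions, not part of this commit; includes are elided.

// Hypothetical fp16 instantiation with no fused D tensors (NumDTensor == 0);
// the full template argument list is elided for brevity.
using PassThrough      = ck::tensor_operation::element_wise::PassThrough;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3R1</* ... */>;

void run_splitk_gemm(const ck::half_t* p_a,
                     const ck::half_t* p_b,
                     ck::half_t* p_c,
                     ck::index_t M, ck::index_t N, ck::index_t K,
                     ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC,
                     ck::index_t KBatch)
{
    auto device_op = DeviceOpInstance{};
    auto argument  = device_op.MakeArgument(
        p_a, p_b, {}, p_c, M, N, K, StrideA, StrideB, {}, StrideC, KBatch,
        PassThrough{}, PassThrough{}, PassThrough{});

    // Stage 1 writes KBatch partial (M, N) slices into this workspace;
    // stage 2 reduces them into C.
    void* p_workspace                = nullptr;
    const std::size_t workspace_size = device_op.GetWorkSpaceSize(&argument);
    if(workspace_size != 0)
    {
        hipMalloc(&p_workspace, workspace_size);
        device_op.SetWorkSpacePointer(&argument, p_workspace);
    }

    if(!device_op.IsSupportedArgument(argument))
        throw std::runtime_error("unsupported GEMM configuration");

    device_op.MakeInvoker().Run(argument, StreamConfig{});

    if(p_workspace != nullptr)
        hipFree(p_workspace);
}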
......@@ -36,7 +36,7 @@ template <typename GridwiseGemm,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename ComputePtrOffsetOfBatch,
-index_t NumBatchToMerge,
+index_t NumGroupsToMerge,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
......@@ -56,7 +56,7 @@ __global__ void
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
-const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
+const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const long_index_t a_batch_offset =
......@@ -92,7 +92,7 @@ template <typename GridwiseGemm,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename ComputePtrOffsetOfBatch,
-index_t NumBatchToMerge,
+index_t NumGroupsToMerge,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
......@@ -113,7 +113,7 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__))
// offset base pointer for each work-group
-const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
+const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const long_index_t a_batch_offset =
......@@ -189,7 +189,7 @@ template <ck::index_t NDimSpatial,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
-index_t NumBatchToMerge = 1,
+index_t NumGroupsToMerge = 1,
typename ComputeTypeA = InDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
......@@ -238,7 +238,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
NPerBlock,
K1Number,
KPerBlock / K1Number,
-NumBatchToMerge,
+NumGroupsToMerge,
ConvBackwardWeightSpecialization>{};
static constexpr auto conv_to_gemm_transformer_v1 =
......@@ -638,7 +638,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
-gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumBatchToMerge);
+gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumGroupsToMerge);
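// (note) With group merging, each workgroup covers NumGroupsToMerge
// consecutive groups, so the grid only needs Conv_G_ / NumGroupsToMerge
// batches along z (e.g. Conv_G_ = 32 with NumGroupsToMerge = 4 launches 8).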
float ave_time = 0;
......@@ -724,7 +724,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
......@@ -739,7 +739,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
......@@ -760,7 +760,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
......@@ -777,7 +777,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
......@@ -796,7 +796,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
......@@ -817,7 +817,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
......@@ -838,7 +838,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
......@@ -859,7 +859,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
......@@ -879,7 +879,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
......@@ -900,7 +900,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
......@@ -920,7 +920,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -937,7 +937,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -956,7 +956,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -977,7 +977,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -998,7 +998,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -1019,7 +1019,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -1039,7 +1039,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -1060,7 +1060,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -1084,7 +1084,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
......@@ -1100,7 +1100,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
......@@ -1119,7 +1119,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -1135,7 +1135,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -1157,7 +1157,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
......@@ -1173,7 +1173,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
......@@ -1192,7 +1192,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -1208,7 +1208,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
......@@ -1232,7 +1232,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
......@@ -1247,7 +1247,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-NumBatchToMerge,
+NumGroupsToMerge,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
......@@ -1389,7 +1389,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
}
}
-if constexpr(NumBatchToMerge > 1)
+if constexpr(NumGroupsToMerge > 1)
{
// support only if whole M and N can be processed on one block
if(!(GemmM <= MPerBlock && GemmN <= NPerBlock))
......@@ -1400,7 +1400,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
{
return false;
}
-if(arg.Conv_G_ % NumBatchToMerge != 0)
+if(arg.Conv_G_ % NumGroupsToMerge != 0)
{
return false;
}
......@@ -1563,7 +1563,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
-<< NumBatchToMerge
+<< NumGroupsToMerge
<< ">";
// clang-format on
......
......@@ -238,37 +238,17 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
-static constexpr auto conv_to_gemm_transformer =
-TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
+using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
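// (note) Refactor: the conv-to-GEMM transformer is now a value constructed
// once from the full convolution problem (lengths, strides, dilations,
// paddings) and passed to the Make*Descriptor helpers, instead of
// re-passing the raw length/stride arrays to every call.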
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
template <typename ALay>
static auto
-MakeAGridDescriptor_AK0_M_AK1(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
-const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
-const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
-const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
-const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
-const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
-const std::array<index_t, NDimSpatial>& conv_filter_strides,
-const std::array<index_t, NDimSpatial>& conv_filter_dilations,
-const std::array<index_t, NDimSpatial>& input_left_pads,
-const std::array<index_t, NDimSpatial>& input_right_pads)
+MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
-conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
-a_g_n_c_wis_strides,
-b_g_k_c_xs_lengths,
-b_g_k_c_xs_strides,
-e_g_n_k_wos_lengths,
-e_g_n_k_wos_strides,
-conv_filter_strides,
-conv_filter_dilations,
-input_left_pads,
-input_right_pads,
-a_g_n_c_wis_lengths[I1]);
+conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -286,12 +266,10 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
template <typename BLay>
static auto
-MakeBGridDescriptor_BK0_N_BK1(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
-const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
+MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
-conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
-b_g_k_c_xs_strides);
+conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
......@@ -309,13 +287,10 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
}
template <typename ELay>
-static auto
-MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
-const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
+static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
-conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
-e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
+conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......@@ -323,27 +298,27 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
return out_gemmm_gemmn_desc;
}
-static auto MakeDsGridDescriptor_M_N(
-const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
-const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
+static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
-return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(ds_g_n_k_wos_lengths[i],
-ds_g_n_k_wos_strides[i]);
+return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
},
Number<NumDTensor>{});
}
// desc for problem definition
+constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
-{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
-using BGridDesc_BK0_N_BK1 =
-remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
-using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
-using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
+dummy_conv_to_gemm_transformer))>;
+using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
+dummy_conv_to_gemm_transformer))>;
+using DsGridDesc_M_N =
+remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))>;
+using EGridDesc_M_N =
+remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
// GridwiseGemm
using GridwiseGemm =
......@@ -426,21 +401,22 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]},
+conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
+a_g_n_c_wis_strides,
+b_g_k_c_xs_lengths,
+b_g_k_c_xs_strides,
+e_g_n_k_wos_lengths,
+e_g_n_k_wos_strides,
+conv_filter_strides,
+conv_filter_dilations,
+input_left_pads,
+input_right_pads},
a_grid_desc_ak0_m_ak1_{
-DeviceOp::MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths,
-a_g_n_c_wis_strides,
-b_g_k_c_xs_lengths,
-b_g_k_c_xs_strides,
-e_g_n_k_wos_lengths,
-e_g_n_k_wos_strides,
-conv_filter_strides,
-conv_filter_dilations,
-input_left_pads,
-input_right_pads)},
-b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>(
-b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
-e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
-e_g_n_k_wos_strides)},
+DeviceOp::MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer_)},
+b_grid_desc_bk0_n_bk1_{
+DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>(conv_to_gemm_transformer_)},
+e_grid_desc_m_n_{
+DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_)},
a_grid_desc_k0_m0_m1_k1_{},
b_grid_desc_k0_n0_n1_k1_{},
ds_grid_desc_m0_m10_m11_n0_n10_n11_{},
......@@ -471,6 +447,17 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths[i],
ds_g_n_k_wos_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
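// (note) Each D tensor gets its own transformer because its lengths and
// strides can differ from E's, while the A/B problem description is shared.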
// D pointer
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
......@@ -478,8 +465,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
// D desc
-ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
-ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]);
+ds_grid_desc_m_n_(i) =
+DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
});
// populate desc for Ds/E
......@@ -523,6 +510,9 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// tensor descriptors for problem definition
index_t num_group_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
DsGridDesc_M_N ds_grid_desc_m_n_;
......@@ -846,6 +836,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
cde_element_op};
}
static auto
MakeArgument(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return Argument{p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer(
......@@ -890,6 +953,79 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
cde_element_op);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
......
......@@ -234,37 +234,17 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
-static constexpr auto conv_to_gemm_transformer =
-TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
+using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
template <typename ALay>
static auto
-MakeAGridDescriptor_AK0_M_AK1(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
-const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
-const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
-const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
-const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
-const std::array<index_t, NDimSpatial>& conv_filter_strides,
-const std::array<index_t, NDimSpatial>& conv_filter_dilations,
-const std::array<index_t, NDimSpatial>& input_left_pads,
-const std::array<index_t, NDimSpatial>& input_right_pads)
+MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
-conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
-a_g_n_c_wis_strides,
-b_g_k_c_xs_lengths,
-b_g_k_c_xs_strides,
-c_g_n_k_wos_lengths,
-c_g_n_k_wos_strides,
-conv_filter_strides,
-conv_filter_dilations,
-input_left_pads,
-input_right_pads,
-a_g_n_c_wis_lengths[I1]);
+conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -283,12 +263,10 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
template <typename BLay>
static auto
-MakeBGridDescriptor_BK0_N_BK1(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
-const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
+MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
-conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
-b_g_k_c_xs_strides);
+conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
......@@ -306,13 +284,10 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
}
template <typename CLay>
-static auto
-MakeCGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
+static auto MakeCGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
-conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>(
-c_g_n_k_wos_lengths, c_g_n_k_wos_strides, c_g_n_k_wos_lengths[I1]);
+conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......@@ -321,11 +296,13 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
}
// desc for problem definition
+constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
-{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
-using BGridDesc_BK0_N_BK1 =
-remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
-using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N<CLayout>({}, {}))>;
+dummy_conv_to_gemm_transformer))>;
+using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
+dummy_conv_to_gemm_transformer))>;
+using CGridDesc_M_N =
+remove_cvref_t<decltype(MakeCGridDescriptor_M_N<CLayout>(dummy_conv_to_gemm_transformer))>;
// GridwiseGemm
using GridwiseGemm =
......@@ -396,21 +373,22 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
p_b_grid_{static_cast<const BDataType*>(p_b)},
p_c_grid_{static_cast<CDataType*>(p_c)},
num_group_{a_g_n_c_wis_lengths[0]},
+conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
+a_g_n_c_wis_strides,
+b_g_k_c_xs_lengths,
+b_g_k_c_xs_strides,
+e_g_n_k_wos_lengths,
+e_g_n_k_wos_strides,
+conv_filter_strides,
+conv_filter_dilations,
+input_left_pads,
+input_right_pads},
a_grid_desc_ak0_m_ak1_{
-DeviceOp::MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths,
-a_g_n_c_wis_strides,
-b_g_k_c_xs_lengths,
-b_g_k_c_xs_strides,
-c_g_n_k_wos_lengths,
-c_g_n_k_wos_strides,
-conv_filter_strides,
-conv_filter_dilations,
-input_left_pads,
-input_right_pads)},
-b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>(
-b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
-c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N<CLayout>(c_g_n_k_wos_lengths,
-c_g_n_k_wos_strides)},
+DeviceOp::MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer_)},
+b_grid_desc_bk0_n_bk1_{
+DeviceOp::MakeBGridDescriptor_BK0_N_BK1<BLayout>(conv_to_gemm_transformer_)},
+c_grid_desc_m_n_{
+DeviceOp::MakeCGridDescriptor_M_N<CLayout>(conv_to_gemm_transformer_)},
a_grid_desc_k0_m0_m1_k1_{},
b_grid_desc_k0_n0_n1_k1_{},
c_grid_desc_m0_m10_m11_n0_n10_n11_{},
......@@ -473,6 +451,9 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
// tensor descriptors for problem definition
index_t num_group_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
......
......@@ -86,7 +86,6 @@ __global__ void
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const index_t groups_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -101,10 +100,8 @@ __global__ void
defined(__gfx94__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
const index_t& num_blocks_per_n = groups_count;
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
......@@ -200,7 +197,6 @@ __global__ void
ignore = p_bs_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = groups_count;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
......@@ -318,38 +314,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization,
true /*SplitN*/,
ADataType,
EDataType>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t Conv_N)
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
Conv_N);
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -358,13 +336,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
template <typename BLay>
static auto
MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides);
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
......@@ -373,14 +348,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
template <typename ELay>
static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const index_t Conv_N)
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N);
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......@@ -390,27 +361,27 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// Shape of Ds and E must be aligned. Strides can be different.
// Pass e_g_n_k_wos_lengths for logical broadcast.
static auto MakeDsGridDescriptor_M_N(
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const index_t Conv_N)
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], Conv_N);
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
},
Number<NumDTensor>{});
}
// desc for problem definition
using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>;
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, 1))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}, 1))>;
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_M_K =
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using BGridDesc_N_K =
remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>(dummy_conv_to_gemm_transformer))>;
using DsGridDesc_M_N =
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))>;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
// If we are using multiAB and one of the template datatype parameters is not a tuple,
// convert it to a tuple
......@@ -498,28 +469,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]},
conv_N_per_block_{
conv_to_gemm_transformer.template GetSplitedNSize<ADataType, EDataType>(
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
conv_N_per_block_)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides)},
conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads},
conv_N_per_block_{conv_to_gemm_transformer_.N_},
a_grid_desc_m_k_{
DeviceOp::MakeAGridDescriptor_M_K<ALayout>(conv_to_gemm_transformer_)},
b_grid_desc_n_k_{
DeviceOp::MakeBGridDescriptor_N_K<BLayout>(conv_to_gemm_transformer_)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)},
e_grid_desc_m_n_{
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
......@@ -620,9 +587,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
ds_g_n_k_wos_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
// D desc
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], conv_N_per_block_);
ds_grid_desc_m_n_(i) =
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
});
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
......@@ -687,6 +665,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// tensor descriptors for problem definition
index_t num_group_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
index_t conv_N_per_block_;
AGridDesc_M_K a_grid_desc_m_k_;
......@@ -745,8 +726,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const index_t gdy = arg.num_group_ * num_workgroups_per_Conv_N;
const index_t gdz = 1;
const index_t gdy = arg.num_group_;
const index_t gdz = num_workgroups_per_Conv_N;
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
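// Editor's note: a hedged sketch (illustrative, not repository code) of the
// launch-geometry change above. The group index and the split-N index now get
// their own grid dimensions, so the kernel reads blockIdx.y / blockIdx.z
// directly instead of decoding a flattened blockIdx.y by division.
#include <hip/hip_runtime.h>
inline dim3 MakeSketchGrid(unsigned gdx, unsigned num_group, unsigned n_splits)
{
    // old: dim3(gdx, num_group * n_splits, 1), kernel divides blockIdx.y
    // new: one y-slice per group, one z-slice per split of N
    return dim3(gdx, num_group, n_splits);
}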
......@@ -795,7 +776,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
......@@ -839,7 +819,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
......@@ -1103,11 +1082,84 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
cde_element_op};
}
static auto
MakeArgument(APointers p_as,
BPointers p_bs,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return Argument{p_as,
p_bs,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op};
}
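// Editor's note: a minimal sketch of the array_convert pattern used above —
// element-wise narrowing from long_index_t (64-bit) arrays to index_t
// (32-bit) arrays. The semantics are assumed from the call sites; this is
// not the repository's implementation.
#include <array>
#include <cstddef>
template <typename Dst, typename Src, std::size_t N>
void array_convert_sketch(std::array<Dst, N>& dst, const std::array<Src, N>& src)
{
    for(std::size_t i = 0; i < N; ++i)
        dst[i] = static_cast<Dst>(src[i]); // caller must ensure each value fits in Dst
}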
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer(
APointers p_a,
BPointers p_b,
APointers p_as,
BPointers p_bs,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
......@@ -1126,8 +1178,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
return std::make_unique<Argument>(p_as,
p_bs,
p_ds,
p_e,
a_g_n_c_wis_lengths,
......@@ -1147,6 +1199,80 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
cde_element_op);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(APointers p_as,
BPointers p_bs,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return std::make_unique<Argument>(p_as,
p_bs,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
......
......@@ -293,39 +293,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization,
true /*SplitN*/,
ADataType,
EDataType>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
static auto
MakeAGridDescriptor_AK0_M_AK1(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t Conv_N)
MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
Conv_N);
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -344,12 +327,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
template <typename BLay>
static auto
MakeBGridDescriptor_BK0_N_BK1(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides);
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
......@@ -367,15 +348,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
template <typename ELay>
static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const index_t Conv_N)
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N);
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......@@ -384,7 +361,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
// desc for problem definition
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}, 1))>;
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
#define GridwiseGemmV3TemplateParams \
tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \
......@@ -417,9 +396,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
dummy_conv_to_gemm_transformer))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>(
dummy_conv_to_gemm_transformer))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
......@@ -450,27 +429,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
p_b_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]},
conv_N_per_block_{
conv_to_gemm_transformer.template GetSplitedNSize<ADataType, EDataType>(
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
a_grid_desc_ak0_m_ak1_{MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
conv_N_per_block_)},
conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads},
conv_N_per_block_{conv_to_gemm_transformer_.N_},
a_grid_desc_ak0_m_ak1_{
MakeAGridDescriptor_AK0_M_AK1<ALayout>(conv_to_gemm_transformer_)},
b_grid_desc_bk0_n_bk1_{
MakeBGridDescriptor_BK0_N_BK1<BLayout>(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)},
MakeBGridDescriptor_BK0_N_BK1<BLayout>(conv_to_gemm_transformer_)},
e_grid_desc_m_n_{
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_)},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
compute_ptr_offset_of_groups_{},
compute_ptr_offset_of_n_{},
......@@ -519,6 +494,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// tensor descriptors for problem definition
index_t num_group_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
index_t conv_N_per_block_;
// tensor descriptors for block/thread-wise copy
......@@ -1000,6 +978,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return false;
}
// Gridwise GEMM v3 doesn't verify descriptor sizes
if(!arg.conv_to_gemm_transformer_.AreDescriptorsSmallerThan2GB())
{
return false;
}
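// Editor's note: a hedged sketch of the intent behind the 2 GB guard above —
// 32-bit index_t offsets overflow once a tensor's byte span reaches 2^31, so
// every descriptor must stay below that bound. The actual
// AreDescriptorsSmallerThan2GB() logic lives in the transformer; this
// illustration only shows the arithmetic being protected.
#include <cstdint>
inline bool SketchFitsInt32Offsets(std::int64_t element_space, std::int64_t bytes_per_element)
{
    constexpr std::int64_t kTwoGB = std::int64_t{1} << 31;
    return element_space * bytes_per_element < kTwoGB;
}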
// check Gridwise GEMM
const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1);
......@@ -1059,6 +1043,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
cde_element_op};
}
static auto
MakeArgument(const void* p_as,
const void* p_bs,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return Argument{p_as,
p_bs,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer(
......@@ -1103,6 +1160,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
cde_element_op);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
......
......@@ -309,37 +309,16 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_g_n_c_wis_lengths[I1]);
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -348,13 +327,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
}
template <typename BLay>
static auto
MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides);
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
......@@ -363,13 +339,10 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
}
template <typename ELay>
static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......@@ -447,11 +420,14 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo);
}
using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<DELayout>({}, {}))>;
using RGridDesc_M = remove_cvref_t<decltype(MakeRGridDescriptor_M<RLayout>({}, {}))>;
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc_M_K =
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using BGridDesc_N_K =
remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>(dummy_conv_to_gemm_transformer))>;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<DELayout>(dummy_conv_to_gemm_transformer))>;
using RGridDesc_M = remove_cvref_t<decltype(MakeRGridDescriptor_M<RLayout>({}, {}))>;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
......@@ -551,21 +527,23 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
p_rs_grid_{}, // FIXME
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides)},
conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads},
a_grid_desc_m_k_{
DeviceOp::MakeAGridDescriptor_M_K<ALayout>(conv_to_gemm_transformer_)},
b_grid_desc_n_k_{
DeviceOp::MakeBGridDescriptor_N_K<BLayout>(conv_to_gemm_transformer_)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<DELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
e_grid_desc_m_n_{
DeviceOp::MakeEGridDescriptor_M_N<DELayout>(conv_to_gemm_transformer_)},
r_grid_desc_m_{
DeviceOp::MakeRGridDescriptor_M<RLayout>(r_g_n_wos_lengths, r_g_n_wos_strides)},
a_grid_desc_ak0_m_ak1_{
......@@ -621,9 +599,20 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
// D batch stride
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths[i],
ds_g_n_k_wos_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
// D desc
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DELayout>(
ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]);
ds_grid_desc_m_n_(i) =
DeviceOp::MakeEGridDescriptor_M_N<DELayout>(conv_to_gemm_transformer_d);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
......@@ -660,6 +649,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
EDataType* p_e_grid_;
typename GridwiseGemm::RsGridPointer p_rs_grid_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
// tensor descriptors for problem definition
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
......
......@@ -135,36 +135,16 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
static constexpr auto BEnableLds =
BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1);
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
using ConvToGemmFwdTransformer = TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
static auto MakeAGridDescriptor(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
static auto MakeAGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_g_n_c_wis_lengths[I1]);
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -205,12 +185,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
template <typename BLay>
static auto MakeBGridDescriptor(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
static auto MakeBGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides);
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
......@@ -251,13 +229,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
}
template <typename ELay>
static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......@@ -265,26 +240,27 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
return out_gemmm_gemmn_desc;
}
static auto MakeDsGridDescriptor_M_N(
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(ds_g_n_k_wos_lengths[i],
ds_g_n_k_wos_strides[i]);
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
},
Number<NumDTensor>{});
}
// desc for problem definition
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer;
using AGridDesc =
decltype(DeviceOp::MakeAGridDescriptor<ALayout>({}, {}, {}, {}, {}, {}, {}, {}, {}, {}));
using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor<BLayout>({}, {}));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
decltype(DeviceOp::MakeAGridDescriptor<ALayout>(dummy_conv_to_gemm_transformer));
using BGridDesc =
decltype(DeviceOp::MakeBGridDescriptor<BLayout>(dummy_conv_to_gemm_transformer));
using DsGridDesc_M_N =
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))>;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
// GridwiseOp
using GridwiseOp = GridwiseGemmMultipleD_Wmma<
......@@ -373,21 +349,21 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]},
conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
a_grid_desc_{DeviceOp::MakeAGridDescriptor<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)},
b_grid_desc_{
DeviceOp::MakeBGridDescriptor<BLayout>(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
e_grid_desc_m_n_{
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_)},
a_grid_desc_{DeviceOp::MakeAGridDescriptor<ALayout>(conv_to_gemm_transformer_)},
b_grid_desc_{DeviceOp::MakeBGridDescriptor<BLayout>(conv_to_gemm_transformer_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)},
......@@ -426,8 +402,24 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
});
// D desc
ds_grid_desc_m_n_ =
DeviceOp::MakeDsGridDescriptor_M_N(ds_g_n_k_wos_lengths, ds_g_n_k_wos_strides);
ds_grid_desc_m_n_ = generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths[i],
ds_g_n_k_wos_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
},
Number<NumDTensor>{});
// populate desc for Ds/E
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
......@@ -455,6 +447,9 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
// tensor descriptors for problem definition
index_t num_group_;
ConvToGemmFwdTransformer conv_to_gemm_transformer_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
......@@ -777,6 +772,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
cde_element_op};
}
static auto
MakeArgument(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return Argument{p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
1,
1,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer(
......@@ -823,6 +893,81 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
cde_element_op);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_strides_i32;
std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
std::array<index_t, NDimSpatial> input_left_pads_i32;
std::array<index_t, NDimSpatial> input_right_pads_i32;
array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i32, conv_filter_strides);
array_convert(conv_filter_dilations_i32, conv_filter_dilations);
array_convert(input_left_pads_i32, input_left_pads);
array_convert(input_right_pads_i32, input_right_pads);
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i32,
a_g_n_c_wis_strides_i32,
b_g_k_c_xs_lengths_i32,
b_g_k_c_xs_strides_i32,
ds_g_n_k_wos_lengths_i32,
ds_g_n_k_wos_strides_i32,
e_g_n_k_wos_lengths_i32,
e_g_n_k_wos_strides_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
1,
1,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <queue>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace {
template <typename GridwiseGemm,
index_t MaxGemmsNum,
typename GemmArgs,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename ComputePtrOffset,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle(
Array<GemmArgs, MaxGemmsNum> gemm_desc_kernel_args,
const index_t gemms_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation c_element_op,
const ComputePtrOffset compute_ptr_offset_of_groups,
const ComputePtrOffset compute_ptr_offset_of_n)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id_x = __builtin_amdgcn_readfirstlane(blockIdx.x);
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
const long_index_t a_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
const long_index_t b_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
const long_index_t e_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
index_t left = 0;
index_t right = gemms_count;
index_t group_id = index_t((left + right) / 2);
while((!(block_id_x >= gemm_desc_kernel_args[group_id].BlockStart_ &&
block_id_x < gemm_desc_kernel_args[group_id].BlockEnd_)) &&
left <= right)
{
if(block_id_x < gemm_desc_kernel_args[group_id].BlockStart_)
{
right = group_id;
}
else
{
left = group_id;
}
group_id = index_t((left + right) / 2);
}
GridwiseGemm::template Run<HasMainKBlockLoop>(
gemm_desc_kernel_args[group_id].a_ptr_ + a_group_offset + a_n_offset,
gemm_desc_kernel_args[group_id].b_ptr_ + b_group_offset,
Tuple<>{},
gemm_desc_kernel_args[group_id].e_ptr_ + e_group_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
c_element_op,
gemm_desc_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_desc_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
Tuple<>{},
gemm_desc_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_desc_kernel_args[group_id].block_2_etile_map_);
#else
ignore = gemm_desc_kernel_args;
ignore = gemms_count;
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
ignore = compute_ptr_offset_of_groups;
ignore = compute_ptr_offset_of_n;
#endif
}
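// Editor's note: a self-contained host-side sketch of the group lookup the
// kernel above performs — a binary search over sorted, contiguous
// [BlockStart_, BlockEnd_) ranges to find which sub-GEMM owns a workgroup id.
// Types and names are illustrative, not repository code; termination assumes
// the ranges are contiguous and cover block_id.
#include <vector>
struct SketchGemmArg { int BlockStart_; int BlockEnd_; };
inline int FindGroupId(const std::vector<SketchGemmArg>& args, int block_id)
{
    int left = 0, right = static_cast<int>(args.size());
    int group_id = (left + right) / 2;
    while(!(block_id >= args[group_id].BlockStart_ && block_id < args[group_id].BlockEnd_))
    {
        if(block_id < args[group_id].BlockStart_)
            right = group_id; // answer lies in the lower half
        else
            left = group_id;  // answer lies in the upper half
        group_id = (left + right) / 2;
    }
    return group_id;
}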
} // namespace
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
ConvolutionForwardSpecialization ConvForwardSpecialization,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename AComputeDataType =
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
Number<0>,
ADataType>()), // ComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
typename BComputeDataType = AComputeDataType,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
: public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
AComputeDataType,
BComputeDataType>
{
using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor;
static constexpr index_t NumDTensor = DsDataType::Size();
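// Upper bound on the number of sub-GEMMs produced by the 2GB splitting below.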
static constexpr index_t MaxGemmsNum = 32;
static_assert(NumDTensor == 0, "MultiD not supported.");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using ConvToGemmFwdTransformerIndexT = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization,
true /*SplitN*/,
ADataType,
EDataType,
I1,
index_t>;
using ConvToGemmFwdTransformerLongIndexT = TransformConvFwdToGemm<NDimSpatial,
ConvForwardSpecialization,
true /*SplitN*/,
ADataType,
EDataType,
I1,
long_index_t>;
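// Pads the GEMM M/N/K extents up to block-tile multiples, as dictated by GemmSpec.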
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
template <typename ALay>
static auto
MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
return in_gemmm_gemmk_desc;
}
template <typename BLay>
static auto
MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer)
{
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
return wei_gemmn_gemmk_desc;
}
template <typename ELay>
static auto
MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
return out_gemmm_gemmn_desc;
}
// desc for problem definition
constexpr static ConvToGemmFwdTransformerIndexT dummy_conv_to_gemm_transformer;
using AGridDesc_M_K =
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using BGridDesc_N_K =
remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>(dummy_conv_to_gemm_transformer))>;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;
static auto
GenerateConvToGemmTransforms(ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformer_base,
const ADataType* a_grid_ptr_base,
EDataType* c_grid_ptr_base)
{
// Maximum number of splits; bounds the loop below to avoid an infinite loop
constexpr index_t max_split_numbers = MaxGemmsNum / 2;
// Arrays storing transformers whose descriptors are smaller than 2GB
Array<ConvToGemmFwdTransformerIndexT, MaxGemmsNum> conv_to_gemm_transformers_arr;
Array<const ADataType*, MaxGemmsNum> a_grid_ptrs_arr;
Array<EDataType*, MaxGemmsNum> c_grid_ptrs_arr;
// Queues for splitting
std::queue<ConvToGemmFwdTransformerLongIndexT> conv_to_gemm_transformers_queue(
{conv_to_gemm_transformer_base});
std::queue<const ADataType*> a_grid_ptrs_queue({a_grid_ptr_base});
std::queue<EDataType*> c_grid_ptrs_queue({c_grid_ptr_base});
index_t gemms_number = 0;
index_t split_numbers = 0;
// Algorithm: while the queue is not empty,
// 1. Pop a transformer from the queue.
// 2. If its descriptors are smaller than 2GB, push it to the result array.
// 3. Otherwise split it into left and right transformers and push both back
//    onto the queue.
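// Example (assuming SplitConvProblem divides the work roughly in half):
// a ~5GB problem splits into two ~2.5GB halves, each of which splits again,
// ending with four sub-GEMMs (split_numbers = 3, gemms_number = 4).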
while(!conv_to_gemm_transformers_queue.empty() && split_numbers < max_split_numbers &&
gemms_number < MaxGemmsNum)
{
// Get transformer from the queue
const auto& conv_to_gemm_transformer = conv_to_gemm_transformers_queue.front();
const ADataType* a_grid_ptr = a_grid_ptrs_queue.front();
EDataType* c_grid_ptr = c_grid_ptrs_queue.front();
// Check that the convolution descriptors do not exceed 2GB
if(conv_to_gemm_transformer.AreDescriptorsSmallerThan2GB())
{
// If yes, push into result array
conv_to_gemm_transformers_arr(gemms_number) =
ConvToGemmFwdTransformerIndexT{conv_to_gemm_transformer};
a_grid_ptrs_arr(gemms_number) = a_grid_ptr;
c_grid_ptrs_arr(gemms_number) = c_grid_ptr;
gemms_number++;
}
else
{
// If no, split into left and right convolutions
ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformers_left_part,
conv_to_gemm_transformers_right_part;
const ADataType* a_grid_right_ptr;
EDataType* c_grid_right_ptr;
ck::tie(conv_to_gemm_transformers_left_part,
conv_to_gemm_transformers_right_part,
a_grid_right_ptr,
c_grid_right_ptr) =
conv_to_gemm_transformer.SplitConvProblem(a_grid_ptr, c_grid_ptr);
conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_left_part);
conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_right_part);
// The left part keeps the original pointers; the right part gets offset pointers
a_grid_ptrs_queue.push(a_grid_ptr);
a_grid_ptrs_queue.push(a_grid_right_ptr);
c_grid_ptrs_queue.push(c_grid_ptr);
c_grid_ptrs_queue.push(c_grid_right_ptr);
split_numbers++;
}
// Remove from the queue
conv_to_gemm_transformers_queue.pop();
a_grid_ptrs_queue.pop();
c_grid_ptrs_queue.pop();
}
const bool is_split_valid = conv_to_gemm_transformers_queue.empty();
return ck::make_tuple(conv_to_gemm_transformers_arr,
a_grid_ptrs_arr,
c_grid_ptrs_arr,
gemms_number,
is_split_valid);
}
#define GridwiseGemmTemplateParameters \
ADataType, BDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
AComputeDataType
// Use appropriate gridwise gemm
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>;
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
BGridDesc_N_K{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Structure for each gemm(conv)
struct GemmArgs
{
// pointers
const ADataType* a_ptr_;
const BDataType* b_ptr_;
EDataType* e_ptr_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
ck::index_t BlockStart_, BlockEnd_;
};
// Argument
struct Argument : public BaseArgument
{
Argument(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& /*p_ds*/,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
/*ds_g_n_k_wos_lengths*/,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
/*ds_g_n_k_wos_strides*/,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
: num_group_{static_cast<index_t>(a_g_n_c_wis_lengths[0])},
compute_ptr_offset_of_groups_{},
compute_ptr_offset_of_n_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
// Run as a grouped GEMM: generate an array of conv-to-GEMM transformers
Array<ConvToGemmFwdTransformerIndexT, MaxGemmsNum> conv_to_gemm_transformer_arr;
Array<const ADataType*, MaxGemmsNum> a_grid_ptrs;
Array<EDataType*, MaxGemmsNum> c_grid_ptrs;
ck::tie(conv_to_gemm_transformer_arr,
a_grid_ptrs,
c_grid_ptrs,
gemms_count_,
is_split_valid_) =
GenerateConvToGemmTransforms(
ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_,
a_g_n_c_wis_strides_,
b_g_k_c_xs_lengths_,
b_g_k_c_xs_strides_,
e_g_n_k_wos_lengths_,
e_g_n_k_wos_strides_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_},
static_cast<const ADataType*>(p_a),
static_cast<EDataType*>(p_e));
grid_size_ = 0;
valid_gemms_count_ = 0;
if(is_split_valid_)
{
// Create GemmArg for each gemm(conv)
for(index_t i = 0; i < gemms_count_; i++)
{
const AGridDesc_M_K a_grid_desc_m_k{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(
conv_to_gemm_transformer_arr[i])};
const BGridDesc_N_K b_grid_desc_n_k{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(
conv_to_gemm_transformer_arr[i])};
const auto e_grid_desc_m_n =
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_arr[i]);
const auto block_2_etile_map =
GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
const index_t grid_size_grp =
block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);
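// Each sub-GEMM owns the contiguous slice [BlockStart, BlockEnd) of the 1D
// grid; the kernel binary-searches these ranges to find its group.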
const index_t BlockStart = grid_size_;
const index_t BlockEnd = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp;
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
b_grid_desc_n_k,
Tuple<>{},
e_grid_desc_m_n,
block_2_etile_map))
{
gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{
a_grid_ptrs[i],
static_cast<const BDataType*>(p_b),
c_grid_ptrs[i],
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k),
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k),
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n),
block_2_etile_map,
BlockStart,
BlockEnd};
valid_gemms_count_++;
}
}
// N is the same for all convs
conv_N_per_block_ = static_cast<index_t>(conv_to_gemm_transformer_arr[I0].N_);
}
// Strides for G and N are the same for all sub-convs
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
}
void Print() const
{
for(index_t i = 0; i < valid_gemms_count_; i++)
{
std::cout << "A[AK0, M, AK1]: " << gemm_desc_kernel_args_[i].a_grid_desc_ak0_m_ak1_
<< std::endl;
std::cout << "B[BK0, N, BK1]: " << gemm_desc_kernel_args_[i].b_grid_desc_bk0_n_bk1_
<< std::endl;
std::cout
<< "E[MBlock, MPerBlock, NBlock, NPerBlock]: "
<< gemm_desc_kernel_args_[i].e_grid_desc_mblock_mperblock_nblock_nperblock_
<< std::endl;
}
}
index_t num_group_;
index_t conv_N_per_block_;
Array<GemmArgs, MaxGemmsNum> gemm_desc_kernel_args_;
index_t grid_size_;
index_t gemms_count_;
index_t valid_gemms_count_;
bool is_split_valid_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_groups_;
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_n_;
// element-wise op
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// for checking IsSupportedArgument()
std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
std::array<long_index_t, NDimSpatial> conv_filter_strides_;
std::array<long_index_t, NDimSpatial> conv_filter_dilations_;
std::array<long_index_t, NDimSpatial> input_left_pads_;
std::array<long_index_t, NDimSpatial> input_right_pads_;
};
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const DeviceOp::Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
const index_t num_workgroups_per_Conv_N =
arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
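// 3D launch grid: x spans all sub-GEMM tiles, y the conv groups, z the
// split-N batches.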
const index_t gdx = arg.grid_size_;
const index_t gdy = arg.num_group_;
const index_t gdz = num_workgroups_per_Conv_N;
// K is constant for all gemms
const auto K = arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle<
GridwiseGemm,
MaxGemmsNum,
GemmArgs,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.gemm_desc_kernel_args_,
arg.gemms_count_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_);
};
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
namespace ctc = tensor_layout::convolution;
const long_index_t K = arg.b_g_k_c_xs_lengths_[I1];
const long_index_t C = arg.b_g_k_c_xs_lengths_[I2];
// Check if all descs are valid
if(!(arg.is_split_valid_ && arg.gemms_count_ == arg.valid_gemms_count_))
{
return false;
}
// check device
if(get_device_name() == "gfx908")
{
// FIXME: re-enable fp64 when SWDEV-335738 is fixed
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
return false;
}
}
if(!ck::is_xdl_supported())
{
return false;
}
// check ConvolutionForwardSpecialization
if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
{
// check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t ConvStride = arg.conv_filter_strides_[i];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
}
}
else if constexpr(ConvForwardSpecialization ==
ConvolutionForwardSpecialization::Filter1x1Pad0)
{
// check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
}
}
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3)
{
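// check it's a 3x3 conv: C == 1, every spatial filter dim == 3, supported layout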
if(C != 1)
{
return false;
}
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3];
if(filter_spatial_dim != I3)
{
return false;
}
}
if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
{
return false;
}
}
// check vector access of A
// FIXME: layout
if constexpr(is_same_v<ALayout, ctc::G_NW_C> || is_same_v<ALayout, ctc::G_NHW_C> ||
is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> ||
is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
is_same_v<ALayout, ctc::NDHWGC>)
{
// Check access per C
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
return false;
}
// check vector access of B
// FIXME: layout
if constexpr(is_same_v<BLayout, ctc::G_K_X_C> || is_same_v<BLayout, ctc::G_K_YX_C> ||
is_same_v<BLayout, ctc::G_K_ZYX_C> || is_same_v<BLayout, ctc::GKXC> ||
is_same_v<BLayout, ctc::GKYXC> || is_same_v<BLayout, ctc::GKZYXC> ||
is_same_v<BLayout, ctc::KXGC> || is_same_v<BLayout, ctc::KYXGC> ||
is_same_v<BLayout, ctc::KZYXGC>)
{
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
return false;
}
// check vector access of E
if constexpr(is_same_v<ELayout, ctc::G_NW_K> || is_same_v<ELayout, ctc::G_NHW_K> ||
is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> ||
is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> ||
is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
is_same_v<ELayout, ctc::NDHWGK>)
{
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
}
else
{
return false;
}
return true;
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(
const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i64;
std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i64;
std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i64;
std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i64;
std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i64;
std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i64;
std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i64;
std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i64;
std::array<long_index_t, NDimSpatial> conv_filter_strides_i64;
std::array<long_index_t, NDimSpatial> conv_filter_dilations_i64;
std::array<long_index_t, NDimSpatial> input_left_pads_i64;
std::array<long_index_t, NDimSpatial> input_right_pads_i64;
array_convert(a_g_n_c_wis_lengths_i64, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i64, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i64, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i64, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i64[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i64[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i64, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i64, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i64, conv_filter_strides);
array_convert(conv_filter_dilations_i64, conv_filter_dilations);
array_convert(input_left_pads_i64, input_left_pads);
array_convert(input_right_pads_i64, input_right_pads);
return Argument{p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i64,
a_g_n_c_wis_strides_i64,
b_g_k_c_xs_lengths_i64,
b_g_k_c_xs_strides_i64,
ds_g_n_k_wos_lengths_i64,
ds_g_n_k_wos_strides_i64,
e_g_n_k_wos_lengths_i64,
e_g_n_k_wos_strides_i64,
conv_filter_strides_i64,
conv_filter_dilations_i64,
input_left_pads_i64,
input_right_pads_i64,
a_element_op,
b_element_op,
cde_element_op};
}
static auto
MakeArgument(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op)
{
return Argument{p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths,
ds_g_n_k_wos_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i64;
std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i64;
std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i64;
std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i64;
std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i64;
std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i64;
std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i64;
std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i64;
std::array<long_index_t, NDimSpatial> conv_filter_strides_i64;
std::array<long_index_t, NDimSpatial> conv_filter_dilations_i64;
std::array<long_index_t, NDimSpatial> input_left_pads_i64;
std::array<long_index_t, NDimSpatial> input_right_pads_i64;
array_convert(a_g_n_c_wis_lengths_i64, a_g_n_c_wis_lengths);
array_convert(a_g_n_c_wis_strides_i64, a_g_n_c_wis_strides);
array_convert(b_g_k_c_xs_lengths_i64, b_g_k_c_xs_lengths);
array_convert(b_g_k_c_xs_strides_i64, b_g_k_c_xs_strides);
for(index_t d = 0; d < NumDTensor; d++)
{
array_convert(ds_g_n_k_wos_lengths_i64[d], ds_g_n_k_wos_lengths[d]);
array_convert(ds_g_n_k_wos_strides_i64[d], ds_g_n_k_wos_strides[d]);
}
array_convert(e_g_n_k_wos_lengths_i64, e_g_n_k_wos_lengths);
array_convert(e_g_n_k_wos_strides_i64, e_g_n_k_wos_strides);
array_convert(conv_filter_strides_i64, conv_filter_strides);
array_convert(conv_filter_dilations_i64, conv_filter_dilations);
array_convert(input_left_pads_i64, input_left_pads);
array_convert(input_right_pads_i64, input_right_pads);
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths_i64,
a_g_n_c_wis_strides_i64,
b_g_k_c_xs_lengths_i64,
b_g_k_c_xs_strides_i64,
ds_g_n_k_wos_lengths_i64,
ds_g_n_k_wos_strides_i64,
e_g_n_k_wos_lengths_i64,
e_g_n_k_wos_strides_i64,
conv_filter_strides_i64,
conv_filter_dilations_i64,
input_left_pads_i64,
input_right_pads_i64,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
const std::array<long_index_t, NDimSpatial>& input_left_pads,
const std::array<long_index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_k_wos_lengths,
ds_g_n_k_wos_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CDEBlockTransferScalarPerVector_NPerBlock << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -59,6 +59,22 @@ constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>;
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NSpatialGK_GKSpatial_NSpatialGC()
{
return is_NWGK_GKXC_NWGC<InLayout, WeiLayout, OutLayout>() ||
is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>();
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_GNSpatialK_GKSpatial_GNSpatialC()
{
return is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>() ||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>();
}
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
struct ComputePtrOffsetOfStridedBatch
{
......
......@@ -57,8 +57,8 @@ struct DeviceImageToColumnImpl
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>{};
using ConvToGemmFwdTransformer =
TransformConvFwdToGemm<NDimSpatial, ConvolutionForwardSpecialization::Default>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpecialization::MKPadding, index_t, index_t, index_t>{
......@@ -97,19 +97,19 @@ struct DeviceImageToColumnImpl
b_g_k_c_xs_lengths[I2] = C;
c_g_n_k_wos_lengths[I1] = N;
ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
image_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
{}, // not needed for A Descriptor
c_g_n_k_wos_lengths,
{}, // not needed for A Descriptor
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>(
a_g_n_c_wis_lengths,
image_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
{}, // not needed for A Descriptor
c_g_n_k_wos_lengths,
{}, // not needed for A Descriptor
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
N);
conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>();
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -19,7 +19,7 @@ namespace device {
template <index_t Rank, int NumReduceDim>
std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>& inLengths)
{
static_assert(Rank <= 6, "bigger Rank size not supported!");
static_assert(Rank <= 12, "bigger Rank size not supported!");
long_index_t invariant_total_length = 1;
long_index_t reduce_total_length = 1;
......@@ -38,7 +38,7 @@ std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>&
template <index_t Rank, int NumReduceDim>
std::pair<long_index_t, long_index_t> get_2d_lengths(const std::array<index_t, Rank>& inLengths)
{
static_assert(Rank <= 6, "bigger Rank size not supported!");
static_assert(Rank <= 12, "bigger Rank size not supported!");
long_index_t invariant_total_length = 1;
long_index_t reduce_total_length = 1;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -51,7 +51,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InDataType,
PropagateNan,
OutputIndex>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(Rank <= 12, "Bigger Rank size is not supported!");
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"Invalid thread cluster size assignments!");
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -47,7 +47,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
OutputIndex>
{
static_assert(Rank <= 6, "Bigger Rank size is not supported!");
static_assert(Rank <= 12, "Bigger Rank size is not supported!");
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <array>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_multi_d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataType,
typename DsDataType,
typename AccDataType,
typename OutDataType,
index_t Rank,
index_t NumReduceDim,
typename ReduceOperation,
typename InElementwiseOperation,
typename OutElementwiseOperation,
index_t BlockSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
index_t InSrcVectorDim,
index_t InSrcVectorSize,
index_t OutDstVectorSize,
typename DsVectorSizeSequence>
struct DeviceReduceThreadWiseMultiD : public DeviceReduceMultiD<InDataType,
DsDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
OutElementwiseOperation>
{
static_assert(Rank <= 12, "Bigger Rank size is not supported!");
static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
(InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
(MThreadSliceSize % OutDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
using IndexDataType = int32_t;
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr index_t NumSrcDim = Rank;
static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;
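// Collapse the input tensor into a 2D [M = invariant dims, K = reduce dims]
// descriptor, padded up to block-tile multiples.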
static auto MakeSrc2dDescriptor(const std::array<index_t, Rank>& inLengths,
const std::array<index_t, Rank>& inStrides)
{
const auto tupleSrcLengths =
generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
const auto tupleSrcStrides =
generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto in_grid_desc_m_k = [&]() {
if constexpr(reduceAllDim)
{
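// All dimensions are reduced: merge them into a single K axis and prepend a
// unit-length M axis.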
const auto one_dim_inDesc = transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(tupleSrcLengths)),
make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}),
make_tuple(Sequence<0>{}));
return transform_tensor_descriptor(one_dim_inDesc,
make_tuple(make_unmerge_transform(make_tuple(
1, one_dim_inDesc.GetLength(Number<0>{})))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
}
else
{
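// Dimensions were shuffled so the leading ones are invariant (M) and the
// trailing ones are reduced (K); merge each set into one axis.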
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths = generate_tuple(
[&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
const auto invariantDimLengths =
generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
return transform_tensor_descriptor(
inDesc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}();
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
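// Pad M and K up to block-tile multiples so the grid evenly covers the problem.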
const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K =
math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, inPad_M),
make_right_pad_transform(reduceLength, inPad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return (in_grid_desc_m_k_padded);
};
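// Merge the output's invariant dimensions into a single M axis and pad it to a
// block-tile multiple.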
static auto MakeDst1dDescriptor(const std::array<index_t, NumDstDim>& outLengths,
const std::array<index_t, NumDstDim>& outStrides)
{
const auto tupleDstLengths =
generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
const auto tupleDstStrides =
generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
auto out_grid_desc_m = transform_tensor_descriptor(
outDesc,
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
const auto outPad =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
auto out_grid_desc_m_padded = transform_tensor_descriptor(
out_grid_desc_m,
make_tuple(make_right_pad_transform(invariantLength, outPad)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}));
return (out_grid_desc_m_padded);
};
static auto
MakeDsDescriptor(const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths,
std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides)
{
return generate_tuple(
[&](auto i) {
return DeviceReduceThreadWiseMultiD::MakeDst1dDescriptor(DsLengths[i],
DsStrides[i]);
},
Number<NumDTensor>{});
}
using InGridDesc_M_K = decltype(MakeSrc2dDescriptor({}, {}));
using OutGridDesc_M = decltype(MakeDst1dDescriptor({}, {}));
using DsGridDesc_M = decltype(MakeDsDescriptor({}, {}));
using GridwiseReduce =
GridwiseReduction_mk_to_m_threadwise_multi_d<InDataType,
DsDataType,
OutDataType,
AccDataType,
InGridDesc_M_K,
DsGridDesc_M,
OutGridDesc_M,
ReduceOperation,
InElementwiseOperation,
OutElementwiseOperation,
InMemoryDataOperationEnum::Set,
BlockSize,
MThreadSliceSize,
KThreadSliceSize,
InSrcVectorDim,
InSrcVectorSize,
OutDstVectorSize,
DsVectorSizeSequence>;
using DsGridPointer = typename GridwiseReduce::DsGridPointer;
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, Rank> inLengths,
const std::array<index_t, Rank> inStrides,
const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths,
const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides,
const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
const InDataType* in_dev,
const std::array<const void*, NumDTensor> ds_dev,
OutDataType* out_dev,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation out_elementwise_op)
: DsLengths_{DsLengths},
DsStrides_{DsStrides},
outLengths_{outLengths},
outStrides_{outStrides},
in_dev_{in_dev},
out_dev_{out_dev},
in_elementwise_op_{in_elementwise_op},
out_elementwise_op_{out_elementwise_op}
{
inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
std::tie(invariant_total_length, reduce_total_length) =
get_2d_lengths<Rank, NumReduceDim>(inLengths_);
if constexpr(NumInvariantDim == 0)
invariant_lowest_length = 1;
else
invariant_lowest_length = inLengths_[NumInvariantDim - 1];
reduce_lowest_length = inLengths_[Rank - 1];
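// numBlockTileIteration = ceil(reduce_total_length / K_BlockTileSize) K tiles
// per thread; gridSize = ceil(invariant_total_length / M_BlockTileSize) blocks.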
numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid_(i) = static_cast<const DDataType*>(ds_dev[i]);
});
ds_grid_desc_m_ = MakeDsDescriptor(DsLengths, DsStrides);
}
std::array<index_t, Rank> inLengths_;
std::array<index_t, Rank> inStrides_;
std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths_;
std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides_;
std::array<index_t, NumDstDim> outLengths_;
std::array<index_t, NumDstDim> outStrides_;
const InDataType* in_dev_;
OutDataType* out_dev_;
DsGridPointer p_ds_grid_;
InElementwiseOperation in_elementwise_op_;
OutElementwiseOperation out_elementwise_op_;
DsGridDesc_M ds_grid_desc_m_;
index_t invariant_lowest_length;
index_t reduce_lowest_length;
long_index_t invariant_total_length;
long_index_t reduce_total_length;
int numBlockTileIteration;
size_t gridSize;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto in_grid_desc_m_k =
DeviceReduceThreadWiseMultiD::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_);
const auto out_grid_desc_m =
DeviceReduceThreadWiseMultiD::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_);
float avg_time = 0;
const auto kernel = kernel_reduce_threadwise_multi_d<GridwiseReduce,
InDataType,
OutDataType,
AccDataType,
InGridDesc_M_K,
DsGridDesc_M,
OutGridDesc_M,
InElementwiseOperation,
OutElementwiseOperation,
DsGridPointer>;
avg_time = launch_and_time_kernel(stream_config,
kernel,
dim3(arg.gridSize),
dim3(BlockSize),
0,
in_grid_desc_m_k,
arg.ds_grid_desc_m_,
out_grid_desc_m,
arg.in_elementwise_op_,
arg.out_elementwise_op_,
arg.in_dev_,
arg.p_ds_grid_,
arg.out_dev_);
return (avg_time);
};
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if constexpr(InSrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
{
return (false);
}
else
{
if(pArg->inStrides_[NumInvariantDim - 1] != 1)
return (false);
if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
return (false);
};
}
else
{
if(pArg->inStrides_[Rank - 1] != 1)
return (false);
if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
return (false);
};
// To improve: this output-vectorization requirement is conservative
if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
return (false);
std::cerr << "reduce_total_length = " << pArg->reduce_total_length
<< " KThreadSliceSize = " << KThreadSliceSize << std::endl;
// cases with big reduce_total_length should be handled by Blockwise kernel
if(pArg->reduce_total_length / KThreadSliceSize >= 32)
return (false);
return (true);
};
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
const std::array<index_t, Rank> inStrides,
const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsLengths,
const std::array<std::array<index_t, NumDstDim>, NumDTensor> DsStrides,
const std::array<index_t, NumDstDim> outLengths,
const std::array<index_t, NumDstDim> outStrides,
const std::array<int, NumReduceDim> reduceDims,
const void* in_dev,
const std::array<const void*, NumDTensor> ds_dev,
void* out_dev,
const InElementwiseOperation in_elementwise_op,
const OutElementwiseOperation out_elementwise_op) override
{
return std::make_unique<Argument>(inLengths,
inStrides,
DsLengths,
DsStrides,
outLengths,
outStrides,
reduceDims,
static_cast<const InDataType*>(in_dev),
ds_dev,
static_cast<OutDataType*>(out_dev),
in_elementwise_op,
out_elementwise_op);
};
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceReduceThreadWiseMultiD<" << BlockSize << ",";
str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ",";
str << "K_C" << 1 << "_S" << KThreadSliceSize << ",";
str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -638,6 +638,32 @@ struct AddSilu
}
};
struct ConvScaleAdd
{
__host__ __device__ ConvScaleAdd(float scale_in = 1.f,
float scale_wei = 1.f,
float scale_out = 1.f)
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
{
}
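// Computes e = f8((c * scale_in * scale_wei + d) * scale_out)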
template <typename E, typename C, typename D>
__host__ __device__ void operator()(E& e, const C& c, const D& d) const;
template <>
__host__ __device__ void
operator()<f8_t, float, float>(f8_t& e, const float& c, const float& d) const
{
float x;
Add{}.template operator()<float>(x, c * scale_in_ * scale_wei_, d);
e = type_convert<f8_t>(x * scale_out_);
};
float scale_in_;
float scale_wei_;
float scale_out_;
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -249,6 +249,31 @@ struct MultiplyAdd
}
};
struct MultiplyMultiply
{
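// Computes e = convert(c * d0 * d1); specialized for half and bhalf outputs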
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, float, float, float>(
ck::half_t& e, const float& c, const float& d0, const float& d1) const
{
const float x0_f = c * d0 * d1;
e = ck::type_convert<ck::half_t>(x0_f);
}
template <>
__host__ __device__ constexpr void operator()<ck::bhalf_t, float, float, float>(
ck::bhalf_t& e, const float& c, const float& d0, const float& d1) const
{
const float x0_f = c * d0 * d1;
e = ck::type_convert<ck::bhalf_t>(x0_f);
}
};
struct MultiplyAddFastGelu
{
template <typename E, typename C, typename D0, typename D1>
......
......@@ -431,7 +431,7 @@ struct Relu
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
// host code use higher accuracy "exp" and "div"
// gpu code use lower accuracy "__expf" and "rcp" function
// gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function
struct FastGelu
{
template <typename Y, typename X>
......@@ -451,7 +451,7 @@ struct FastGelu
y = x / (1.f + emu);
}
// device code, use lower precision "__expf" and "rcp"
// device code, use lower precision "__ocml_exp_f32" and "rcp"
template <>
__device__ void operator()<float, float>(float& y, const float& x) const
{
......@@ -459,7 +459,7 @@ struct FastGelu
const float c1 = -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const float u = x * (c1 * x * x + c2);
const float emu = __expf(u);
const float emu = __ocml_exp_f32(u);
y = x * ck::math::rcp(1.f + emu);
}
......@@ -1025,6 +1025,31 @@ struct ConvScale
float scale_out_;
};
struct ConvScaleRelu
{
__host__ __device__ ConvScaleRelu(float scale_in = 1.f,
float scale_wei = 1.f,
float scale_out = 1.f)
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
{
}
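// Computes e = f8(relu(c * scale_in * scale_wei) * scale_out)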
template <typename E, typename C>
__host__ __device__ void operator()(E& e, const C& c) const;
template <>
__host__ __device__ void operator()<f8_t, float>(f8_t& e, const float& c) const
{
float x;
Relu{}.template operator()<float>(x, c * scale_in_ * scale_wei_);
e = type_convert<f8_t>(x * scale_out_);
};
float scale_in_;
float scale_wei_;
float scale_out_;
};
// support fast conversion of int8 to fp16
template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -258,7 +258,7 @@ struct GridwiseMultipleReduction_mk_to_m_multiblock
if(thread_k_cluster_id == 0)
{
if(block_group_size == 0 && !float_equal_zero{}(beta_values[iR]))
if(!float_equal_zero{}(beta_values[iR]))
{
StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
priorDstValueBuf;
......