Commit 7e493730 authored by Adam Osewski

Merge branch 'develop' into wavelet_model

parents b89a88b5 40942b90
......@@ -503,13 +503,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
if(!DeviceOp::IsSupportedArgument(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
throw std::runtime_error("wrong! unsupported argument");
}
const index_t grid_size =
......
......@@ -333,10 +333,6 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
......@@ -370,12 +366,19 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Argument
struct Argument : public BaseArgument
......@@ -478,10 +481,9 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
// for calculating batch offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
......@@ -520,21 +522,21 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_batched_gemm_xdl<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch,
Block2ETileMap,
has_main_loop>;
const auto kernel =
kernel_batched_gemm_xdl<GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch,
Block2ETileMap,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
......
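The hunks above for DeviceBatchedGemmMultiD_Xdl (and the matching changes to DeviceContractionMultipleD_Xdl_CShuffle and DeviceGroupedContractionMultipleD_Xdl_CShuffle further down) drop the AGridDesc_M_K/BGridDesc_N_K/DsGridDesc_M_N/EGridDesc_M_N template parameters of the gridwise GEMM and instead derive the block-tiled descriptor types at device-op scope from the gridwise MakeDefault* factories. A minimal standalone sketch of that decltype idiom, using illustrative type names rather than the real CK descriptors:

```cpp
#include <type_traits>

namespace sketch {

template <typename T>
using remove_cvref_t = std::remove_cv_t<std::remove_reference_t<T>>;

// illustrative M_K descriptor and gridwise GEMM exposing a descriptor factory
struct AGridDesc_M_K
{
};

struct GridwiseGemm
{
    struct AGridDesc_AK0_M_AK1_Impl
    {
    };

    static auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K&)
    {
        return AGridDesc_AK0_M_AK1_Impl{};
    }
};

// the device op names only the factory's return type instead of restating it,
// so the descriptor used for the blockwise copy cannot drift from the one the
// gridwise GEMM actually builds
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
    GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;

static_assert(std::is_same<AGridDesc_AK0_M_AK1,
                           GridwiseGemm::AGridDesc_AK0_M_AK1_Impl>::value,
              "decltype deduction yields the factory's return type");

} // namespace sketch
```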
......@@ -35,6 +35,7 @@ template <typename GridwiseGemm,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
......@@ -57,7 +58,8 @@ __global__ void
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -88,7 +90,8 @@ __global__ void
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map);
block_2_ctile_map,
c0_matrix_mask);
#else
ignore = p_a_grid;
ignore = p_b_grid;
......@@ -106,6 +109,7 @@ __global__ void
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......@@ -168,6 +172,7 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool MaskOutUpperTriangle,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
: public DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
......@@ -194,9 +199,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
// FIXME: pad K
static_assert(!matrix_padder.PadK, "KPadding is currently not supported");
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
......@@ -398,6 +400,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N({}, {}));
using CGridDesc_G_M_N = decltype(MakeCGridDescriptor_G_M_N({}, {}));
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding values, because they will not be stored out.
struct C0MatrixMask
{
C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
return n >= NRaw_;
}
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
{
return IsUpperTriangle(m, n) || IsNOutOfBound(n);
}
private:
// index_t MRaw_;
index_t NRaw_;
};
struct ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
......@@ -498,7 +523,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
matrix_padder.PadN>;
matrix_padder.PadN,
MaskOutUpperTriangle>;
// Argument
// FIXME: constness
......@@ -548,6 +574,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
batch_count_(Batch),
compute_base_ptr_of_batch_{
BatchStrideA, BatchStrideB, BatchStrideB1, c_grid_desc_g_m_n_},
c0_matrix_mask_{NRaw},
raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw},
c_extent_lowest_{c_gs_ms_gemm1ns_lengths.back()},
c_stride_lowest_{c_gs_ms_gemm1ns_strides.back()}
......@@ -585,6 +612,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
// check C0 masking and padding
C0MatrixMask c0_matrix_mask_;
// For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_m_n_k_o_;
index_t c_extent_lowest_;
......@@ -632,6 +662,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_>;
return launch_and_time_kernel(stream_config,
......@@ -654,7 +685,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_,
arg.batch_count_,
arg.compute_base_ptr_of_batch_);
arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
......
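The C0MatrixMask introduced in DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle above (and repeated in DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle below) marks which elements of the first GEMM's output C0 must be treated as -inf before the softmax: those strictly above the diagonal and those in the padded N region beyond NRaw. A host-only sketch of the same predicate, with plain int standing in for index_t and an illustrative main:

```cpp
#include <cstdio>

// host-side stand-in for the C0MatrixMask added in the diff
struct C0MatrixMaskSketch
{
    explicit C0MatrixMaskSketch(int NRaw) : NRaw_(NRaw) {}

    bool IsUpperTriangle(int m, int n) const { return n > m; }
    bool IsNOutOfBound(int n) const { return n >= NRaw_; }
    bool IsMaskedElement(int m, int n) const
    {
        return IsUpperTriangle(m, n) || IsNOutOfBound(n);
    }

    private:
    int NRaw_;
};

int main()
{
    // NRaw = 3 with one column of N padding: column 3 is always masked,
    // and every element above the diagonal is masked as well.
    const C0MatrixMaskSketch mask{3};

    for(int m = 0; m < 4; ++m)
    {
        for(int n = 0; n < 4; ++n)
        {
            std::printf("%c ", mask.IsMaskedElement(m, n) ? 'x' : '.');
        }
        std::printf("\n");
    }
    // prints:
    // . x x x
    // . . x x
    // . . . x
    // . . . x
}
```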
......@@ -35,6 +35,7 @@ template <typename GridwiseGemm,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
......@@ -57,7 +58,8 @@ __global__ void
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -88,7 +90,8 @@ __global__ void
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map);
block_2_ctile_map,
c0_matrix_mask);
#else
ignore = p_a_grid;
ignore = p_b_grid;
......@@ -106,6 +109,7 @@ __global__ void
ignore = block_2_ctile_map;
ignore = batch_count;
ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......@@ -177,6 +181,7 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool MaskOutUpperTriangle,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
: public DeviceBatchedGemmSoftmaxGemm<ALayout,
......@@ -203,9 +208,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
// FIXME: pad K
static_assert(!matrix_padder.PadK, "KPadding is currently not supported");
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{
const auto a_grid_desc_mraw_kraw = [&]() {
......@@ -313,6 +315,29 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
}
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding values, because they will not be stored out.
struct C0MatrixMask
{
C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
return n >= NRaw_;
}
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
{
return IsUpperTriangle(m, n) || IsNOutOfBound(n);
}
private:
// index_t MRaw_;
index_t NRaw_;
};
struct ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
......@@ -418,7 +443,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
matrix_padder.PadN>;
matrix_padder.PadN,
MaskOutUpperTriangle>;
// Argument
struct Argument : public BaseArgument
......@@ -463,6 +489,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
c_element_op_{c_element_op},
batch_count_(Batch),
compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC},
c0_matrix_mask_{NRaw},
raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
......@@ -497,6 +524,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
// check C0 masking and padding
C0MatrixMask c0_matrix_mask_;
// For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_m_n_k_o_;
};
......@@ -542,6 +572,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_>;
return launch_and_time_kernel(stream_config,
......@@ -564,7 +595,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_,
arg.batch_count_,
arg.compute_base_ptr_of_batch_);
arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
......
......@@ -320,10 +320,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
......@@ -357,12 +353,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Argument
struct Argument : public BaseArgument
......@@ -475,10 +478,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
......@@ -535,9 +537,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2ETileMap,
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::Block2ETileMap,
has_main_loop>;
return launch_and_time_kernel(stream_config,
......
......@@ -222,14 +222,9 @@ struct DeviceElementwise
}
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
static bool IsSupportedArgument(const Argument& arg)
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg == nullptr)
return false;
if(pArg->lengths_.back() % MPerThread != 0)
if(arg.lengths_.back() % MPerThread != 0)
return false;
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
......@@ -247,19 +242,40 @@ struct DeviceElementwise
bool valid = true;
static_for<0, NumInput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(
pArg->lengths_, pArg->inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
valid = false;
});
static_for<0, NumOutput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid(
pArg->lengths_, pArg->outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
valid = false;
});
return valid;
};
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto
MakeArgument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op)
{
return Argument{lengths,
inStridesArray,
outStridesArray,
in_dev_buffers,
out_dev_buffers,
elementwise_op};
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
......
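The DeviceElementwise change above rewrites the argument check as a static function over the concrete Argument and adds a value-returning MakeArgument next to the existing MakeArgumentPointer, while the virtual BaseArgument* override becomes a thin forwarder. A simplified sketch of that pattern, with illustrative types and an explicit null check in the forwarder for the case where the downcast fails:

```cpp
#include <memory>

struct BaseArgument
{
    virtual ~BaseArgument() = default;
};

struct DeviceOpSketch
{
    static constexpr int MPerThread = 8;

    struct Argument : BaseArgument
    {
        explicit Argument(int innermost_length) : innermost_length_(innermost_length) {}
        int innermost_length_;
    };

    // validity rule written once against the concrete Argument type
    static bool IsSupportedArgument(const Argument& arg)
    {
        return arg.innermost_length_ % MPerThread == 0;
    }

    // polymorphic entry point downcasts and forwards to the static check
    bool IsSupportedArgument(const BaseArgument* p_arg) const
    {
        const auto* p = dynamic_cast<const Argument*>(p_arg);
        return p != nullptr && IsSupportedArgument(*p);
    }

    // value-returning factory alongside the pointer-returning one
    static Argument MakeArgument(int innermost_length) { return Argument{innermost_length}; }

    static std::unique_ptr<BaseArgument> MakeArgumentPointer(int innermost_length)
    {
        return std::make_unique<Argument>(innermost_length);
    }
};
```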
......@@ -237,10 +237,6 @@ struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
......
......@@ -365,10 +365,6 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
......@@ -402,17 +398,21 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
struct GroupedContractionBlock2ETileMap
{
static_assert(
std::is_same<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{})),
typename GridwiseGemm::DefaultBlock2ETileMap>::value,
"Wrong! Should be the same type name");
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
GroupedContractionBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n,
ck::index_t BlockStart)
......@@ -441,7 +441,7 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
return default_block_2_etile_map_.CheckValidity(e_grid_desc_m_n);
}
typename GridwiseGemm::DefaultBlock2ETileMap default_block_2_etile_map_;
Block2ETileMap default_block_2_etile_map_;
ck::index_t block_start_;
};
......@@ -456,10 +456,9 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
GroupedContractionBlock2ETileMap block_2_etile_map_;
......
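GroupedContractionBlock2ETileMap above now owns its Block2ETileMap alias (deduced from MakeDefaultBlock2ETileMap) instead of asserting that it matches GridwiseGemm::DefaultBlock2ETileMap. The wrapper exists because a grouped launch concatenates every group's tile grid into one kernel grid, so each group's map shifts the global block id by its BlockStart before delegating to the default map. A simplified stand-in for that behaviour (the real CalculateBottomIndex is not shown in this diff, so the details below are assumptions):

```cpp
#include <cstdio>

// simplified default map: linear block id -> (m-block, n-block) over an
// n_blocks_m x n_blocks_n tile grid
struct DefaultBlock2ETileMapSketch
{
    int n_blocks_n_;

    void CalculateBottomIndex(int block_id, int& m_block, int& n_block) const
    {
        m_block = block_id / n_blocks_n_;
        n_block = block_id % n_blocks_n_;
    }
};

// per-group wrapper: remembers where this group's blocks start in the fused grid
// and shifts into the group's local block-id space before delegating
struct GroupedBlock2ETileMapSketch
{
    DefaultBlock2ETileMapSketch default_map_;
    int block_start_;

    void CalculateBottomIndex(int block_id, int& m_block, int& n_block) const
    {
        default_map_.CalculateBottomIndex(block_id - block_start_, m_block, n_block);
    }
};

int main()
{
    // a group whose 2x4 tile grid starts at global block id 32
    const GroupedBlock2ETileMapSketch map{{4}, 32};

    int m_block = 0, n_block = 0;
    map.CalculateBottomIndex(37, m_block, n_block); // local block 5 -> tile (1, 1)
    std::printf("m_block=%d n_block=%d\n", m_block, n_block);
}
```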
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Conv backward data multiple D:
// input : output image A[G, N, K, Ho, Wo]
// input : weight B[G, K, C, Y, X],
// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
// output : input image E[G, N, C, Hi, Wi],
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a, // output image
const void* p_b, // weight
const std::array<const void*, NumDTensor>& p_ds, // bias
void* p_e, // input image
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output image
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides, // output image
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths, // bias
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides, // bias
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides, // input image
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
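For reference, the spatial extents of the output image A[G, N, K, Ho, Wo] that this interface consumes and the input image E[G, N, C, Hi, Wi] it produces are tied together by the usual forward-convolution formula. A small host-side sketch of that relationship for one spatial dimension (illustrative values, not part of the header):

```cpp
#include <cstdio>

// forward convolution output length for one spatial dimension
int ConvOutputLength(int input_len, int filter_len, int stride, int dilation,
                     int left_pad, int right_pad)
{
    const int effective_filter = dilation * (filter_len - 1) + 1;
    return (input_len + left_pad + right_pad - effective_filter) / stride + 1;
}

int main()
{
    // e.g. Hi = 28, 3x3 filter, stride 2, dilation 1, pads 1/1  ->  Ho = 14
    std::printf("Ho = %d\n", ConvOutputLength(28, 3, 2, 1, 1, 1));
}
```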