Commit 7e493730 authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge branch 'develop' into wavelet_model

parents b89a88b5 40942b90
...@@ -503,13 +503,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout ...@@ -503,13 +503,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, if(!DeviceOp::IsSupportedArgument(arg))
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! unsupported argument");
} }
const index_t grid_size = const index_t grid_size =
......
...@@ -333,10 +333,6 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout, ...@@ -333,10 +333,6 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -370,12 +366,19 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout, ...@@ -370,12 +366,19 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( // desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype( using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; // block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -478,10 +481,9 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout, ...@@ -478,10 +481,9 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_; ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
e_grid_desc_mblock_mperblock_nblock_nperblock_;
// for calculating batch offset // for calculating batch offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
...@@ -520,21 +522,21 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout, ...@@ -520,21 +522,21 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
auto launch_kernel = [&](auto has_main_k_block_loop) { auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value; constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_batched_gemm_xdl< const auto kernel =
GridwiseGemm, kernel_batched_gemm_xdl<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer, typename GridwiseGemm::DsGridPointer,
EDataType, EDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch, ComputePtrOffsetOfStridedBatch,
Block2ETileMap, Block2ETileMap,
has_main_loop>; has_main_loop>;
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
kernel, kernel,
......
...@@ -35,6 +35,7 @@ template <typename GridwiseGemm, ...@@ -35,6 +35,7 @@ template <typename GridwiseGemm,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2CTileMap, typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch, typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
...@@ -57,7 +58,8 @@ __global__ void ...@@ -57,7 +58,8 @@ __global__ void
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const index_t batch_count, const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -88,7 +90,8 @@ __global__ void ...@@ -88,7 +90,8 @@ __global__ void
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map); block_2_ctile_map,
c0_matrix_mask);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
...@@ -106,6 +109,7 @@ __global__ void ...@@ -106,6 +109,7 @@ __global__ void
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
ignore = batch_count; ignore = batch_count;
ignore = compute_base_ptr_of_batch; ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
...@@ -168,6 +172,7 @@ template <typename ALayout, ...@@ -168,6 +172,7 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool MaskOutUpperTriangle,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
: public DeviceBatchedGemmSoftmaxGemmPermute<ALayout, : public DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
...@@ -194,9 +199,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -194,9 +199,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{ GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock}; MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
// FIXME: pad K
static_assert(!matrix_padder.PadK, "KPadding is currently not supported");
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{ {
const auto a_grid_desc_mraw_kraw = [&]() { const auto a_grid_desc_mraw_kraw = [&]() {
...@@ -398,6 +400,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -398,6 +400,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N({}, {})); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N({}, {}));
using CGridDesc_G_M_N = decltype(MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(MakeCGridDescriptor_G_M_N({}, {}));
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
struct C0MatrixMask
{
C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
return n >= NRaw_;
}
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
{
return IsUpperTriangle(m, n) || IsNOutOfBound(n);
}
private:
// index_t MRaw_;
index_t NRaw_;
};
struct ComputeBasePtrOfStridedBatch struct ComputeBasePtrOfStridedBatch
{ {
ComputeBasePtrOfStridedBatch(index_t BatchStrideA, ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
...@@ -498,7 +523,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -498,7 +523,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
matrix_padder.PadN>; matrix_padder.PadN,
MaskOutUpperTriangle>;
// Argument // Argument
// FIXME: constness // FIXME: constness
...@@ -548,6 +574,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -548,6 +574,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
batch_count_(Batch), batch_count_(Batch),
compute_base_ptr_of_batch_{ compute_base_ptr_of_batch_{
BatchStrideA, BatchStrideB, BatchStrideB1, c_grid_desc_g_m_n_}, BatchStrideA, BatchStrideB, BatchStrideB1, c_grid_desc_g_m_n_},
c0_matrix_mask_{NRaw},
raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw}, raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw},
c_extent_lowest_{c_gs_ms_gemm1ns_lengths.back()}, c_extent_lowest_{c_gs_ms_gemm1ns_lengths.back()},
c_stride_lowest_{c_gs_ms_gemm1ns_strides.back()} c_stride_lowest_{c_gs_ms_gemm1ns_strides.back()}
...@@ -585,6 +612,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -585,6 +612,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
index_t batch_count_; index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
// check C0 masking and padding
C0MatrixMask c0_matrix_mask_;
// For robust IsSupportedArgument() check // For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_m_n_k_o_; std::vector<index_t> raw_lengths_m_n_k_o_;
index_t c_extent_lowest_; index_t c_extent_lowest_;
...@@ -632,6 +662,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -632,6 +662,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch, ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_>; has_main_k_block_loop_>;
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
...@@ -654,7 +685,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -654,7 +685,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_, arg.block_2_ctile_map_,
arg.batch_count_, arg.batch_count_,
arg.compute_base_ptr_of_batch_); arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_);
}; };
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
......
...@@ -35,6 +35,7 @@ template <typename GridwiseGemm, ...@@ -35,6 +35,7 @@ template <typename GridwiseGemm,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2CTileMap, typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch, typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
...@@ -57,7 +58,8 @@ __global__ void ...@@ -57,7 +58,8 @@ __global__ void
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const index_t batch_count, const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -88,7 +90,8 @@ __global__ void ...@@ -88,7 +90,8 @@ __global__ void
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map); block_2_ctile_map,
c0_matrix_mask);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
...@@ -106,6 +109,7 @@ __global__ void ...@@ -106,6 +109,7 @@ __global__ void
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
ignore = batch_count; ignore = batch_count;
ignore = compute_base_ptr_of_batch; ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
...@@ -177,6 +181,7 @@ template <typename ALayout, ...@@ -177,6 +181,7 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool MaskOutUpperTriangle,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
: public DeviceBatchedGemmSoftmaxGemm<ALayout, : public DeviceBatchedGemmSoftmaxGemm<ALayout,
...@@ -203,9 +208,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -203,9 +208,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{ GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock}; MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
// FIXME: pad K
static_assert(!matrix_padder.PadK, "KPadding is currently not supported");
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{ {
const auto a_grid_desc_mraw_kraw = [&]() { const auto a_grid_desc_mraw_kraw = [&]() {
...@@ -313,6 +315,29 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -313,6 +315,29 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
} }
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
struct C0MatrixMask
{
C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
return n >= NRaw_;
}
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
{
return IsUpperTriangle(m, n) || IsNOutOfBound(n);
}
private:
// index_t MRaw_;
index_t NRaw_;
};
struct ComputeBasePtrOfStridedBatch struct ComputeBasePtrOfStridedBatch
{ {
ComputeBasePtrOfStridedBatch(index_t BatchStrideA, ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
...@@ -418,7 +443,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -418,7 +443,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
matrix_padder.PadN>; matrix_padder.PadN,
MaskOutUpperTriangle>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -463,6 +489,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -463,6 +489,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
c_element_op_{c_element_op}, c_element_op_{c_element_op},
batch_count_(Batch), batch_count_(Batch),
compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC}, compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC},
c0_matrix_mask_{NRaw},
raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw} raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw}
{ {
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
...@@ -497,6 +524,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -497,6 +524,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
index_t batch_count_; index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
// check C0 masking and padding
C0MatrixMask c0_matrix_mask_;
// For robust IsSupportedArgument() check // For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_m_n_k_o_; std::vector<index_t> raw_lengths_m_n_k_o_;
}; };
...@@ -542,6 +572,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -542,6 +572,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch, ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_>; has_main_k_block_loop_>;
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
...@@ -564,7 +595,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -564,7 +595,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_, arg.block_2_ctile_map_,
arg.batch_count_, arg.batch_count_,
arg.compute_base_ptr_of_batch_); arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_);
}; };
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
......
...@@ -320,10 +320,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -320,10 +320,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -357,12 +353,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -357,12 +353,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( // desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype( using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; // block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -475,10 +478,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -475,10 +478,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_; ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map // block-to-e-tile map
Block2ETileMap block_2_etile_map_; Block2ETileMap block_2_etile_map_;
...@@ -535,9 +537,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -535,9 +537,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
CDEElementwiseOperation, CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2ETileMap, DeviceOp::Block2ETileMap,
has_main_loop>; has_main_loop>;
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
......
...@@ -222,14 +222,9 @@ struct DeviceElementwise ...@@ -222,14 +222,9 @@ struct DeviceElementwise
} }
}; };
bool IsSupportedArgument(const BaseArgument* p_arg) override static bool IsSupportedArgument(const Argument& arg)
{ {
const Argument* pArg = dynamic_cast<const Argument*>(p_arg); if(arg.lengths_.back() % MPerThread != 0)
if(pArg == nullptr)
return false;
if(pArg->lengths_.back() % MPerThread != 0)
return false; return false;
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths, auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
...@@ -247,19 +242,40 @@ struct DeviceElementwise ...@@ -247,19 +242,40 @@ struct DeviceElementwise
bool valid = true; bool valid = true;
static_for<0, NumInput, 1>{}([&](auto I) { static_for<0, NumInput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid( if(!IsScalarPerVectorValid(
pArg->lengths_, pArg->inStridesArray_[I.value], InScalarPerVectorSeq::At(I))) arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
valid = false; valid = false;
}); });
static_for<0, NumOutput, 1>{}([&](auto I) { static_for<0, NumOutput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid( if(!IsScalarPerVectorValid(
pArg->lengths_, pArg->outStridesArray_[I.value], OutScalarPerVectorSeq::At(I))) arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
valid = false; valid = false;
}); });
return valid; return valid;
}; };
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto
MakeArgument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op)
{
return Argument{lengths,
inStridesArray,
outStridesArray,
in_dev_buffers,
out_dev_buffers,
elementwise_op};
}
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, NumDim> lengths, MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray, const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
......
...@@ -237,10 +237,6 @@ struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -237,10 +237,6 @@ struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
......
...@@ -365,10 +365,6 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle ...@@ -365,10 +365,6 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -402,17 +398,21 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle ...@@ -402,17 +398,21 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( // desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype( using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
struct GroupedContractionBlock2ETileMap struct GroupedContractionBlock2ETileMap
{ {
static_assert( // block-to-e-tile map
std::is_same<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{})), using Block2ETileMap =
typename GridwiseGemm::DefaultBlock2ETileMap>::value, remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
"Wrong! Should be the same type name");
GroupedContractionBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, GroupedContractionBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n,
ck::index_t BlockStart) ck::index_t BlockStart)
...@@ -441,7 +441,7 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle ...@@ -441,7 +441,7 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
return default_block_2_etile_map_.CheckValidity(e_grid_desc_m_n); return default_block_2_etile_map_.CheckValidity(e_grid_desc_m_n);
} }
typename GridwiseGemm::DefaultBlock2ETileMap default_block_2_etile_map_; Block2ETileMap default_block_2_etile_map_;
ck::index_t block_start_; ck::index_t block_start_;
}; };
...@@ -456,10 +456,9 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle ...@@ -456,10 +456,9 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_; ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
e_grid_desc_mblock_mperblock_nblock_nperblock_;
// lock-to-e-tile map // lock-to-e-tile map
GroupedContractionBlock2ETileMap block_2_etile_map_; GroupedContractionBlock2ETileMap block_2_etile_map_;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// Conv backward data multiple D:
// input : output image A[G, N, K, Ho, Wo]
// input : weight B[G, K, C, Y, X],
// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
// output : input image E[G, N, C, Hi, Wi],
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a, // output image
const void* p_b, // weight
const std::array<const void*, NumDTensor>& p_ds, // bias
void* p_e, // input image
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output image
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides, // output image
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_lengths, // bias
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
ds_g_n_k_wos_strides, // bias
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides, // input image
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment