Commit 7e493730 authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge branch 'develop' into wavelet_model

parents b89a88b5 40942b90
...@@ -503,13 +503,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout ...@@ -503,13 +503,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, if(!DeviceOp::IsSupportedArgument(arg))
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! unsupported argument");
} }
const index_t grid_size = const index_t grid_size =
......
...@@ -333,10 +333,6 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout, ...@@ -333,10 +333,6 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -370,12 +366,19 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout, ...@@ -370,12 +366,19 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched>;
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype( using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; // block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -478,10 +481,9 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout, ...@@ -478,10 +481,9 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_; ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
e_grid_desc_mblock_mperblock_nblock_nperblock_;
// for calculating batch offset // for calculating batch offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
...@@ -520,8 +522,8 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout, ...@@ -520,8 +522,8 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
auto launch_kernel = [&](auto has_main_k_block_loop) { auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value; constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_batched_gemm_xdl< const auto kernel =
GridwiseGemm, kernel_batched_gemm_xdl<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer, typename GridwiseGemm::DsGridPointer,
EDataType, EDataType,
...@@ -530,8 +532,8 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout, ...@@ -530,8 +532,8 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD<ALayout,
CDEElementwiseOperation, CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch, ComputePtrOffsetOfStridedBatch,
Block2ETileMap, Block2ETileMap,
has_main_loop>; has_main_loop>;
......
...@@ -35,6 +35,7 @@ template <typename GridwiseGemm, ...@@ -35,6 +35,7 @@ template <typename GridwiseGemm,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2CTileMap, typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch, typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
...@@ -57,7 +58,8 @@ __global__ void ...@@ -57,7 +58,8 @@ __global__ void
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const index_t batch_count, const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -88,7 +90,8 @@ __global__ void ...@@ -88,7 +90,8 @@ __global__ void
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map); block_2_ctile_map,
c0_matrix_mask);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
...@@ -106,6 +109,7 @@ __global__ void ...@@ -106,6 +109,7 @@ __global__ void
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
ignore = batch_count; ignore = batch_count;
ignore = compute_base_ptr_of_batch; ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
...@@ -168,6 +172,7 @@ template <typename ALayout, ...@@ -168,6 +172,7 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool MaskOutUpperTriangle,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
: public DeviceBatchedGemmSoftmaxGemmPermute<ALayout, : public DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
...@@ -194,9 +199,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -194,9 +199,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{ GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock}; MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
// FIXME: pad K
static_assert(!matrix_padder.PadK, "KPadding is currently not supported");
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{ {
const auto a_grid_desc_mraw_kraw = [&]() { const auto a_grid_desc_mraw_kraw = [&]() {
...@@ -398,6 +400,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -398,6 +400,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N({}, {})); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N({}, {}));
using CGridDesc_G_M_N = decltype(MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(MakeCGridDescriptor_G_M_N({}, {}));
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
struct C0MatrixMask
{
C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
return n >= NRaw_;
}
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
{
return IsUpperTriangle(m, n) || IsNOutOfBound(n);
}
private:
// index_t MRaw_;
index_t NRaw_;
};
struct ComputeBasePtrOfStridedBatch struct ComputeBasePtrOfStridedBatch
{ {
ComputeBasePtrOfStridedBatch(index_t BatchStrideA, ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
...@@ -498,7 +523,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -498,7 +523,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
matrix_padder.PadN>; matrix_padder.PadN,
MaskOutUpperTriangle>;
// Argument // Argument
// FIXME: constness // FIXME: constness
...@@ -548,6 +574,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -548,6 +574,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
batch_count_(Batch), batch_count_(Batch),
compute_base_ptr_of_batch_{ compute_base_ptr_of_batch_{
BatchStrideA, BatchStrideB, BatchStrideB1, c_grid_desc_g_m_n_}, BatchStrideA, BatchStrideB, BatchStrideB1, c_grid_desc_g_m_n_},
c0_matrix_mask_{NRaw},
raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw}, raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw},
c_extent_lowest_{c_gs_ms_gemm1ns_lengths.back()}, c_extent_lowest_{c_gs_ms_gemm1ns_lengths.back()},
c_stride_lowest_{c_gs_ms_gemm1ns_strides.back()} c_stride_lowest_{c_gs_ms_gemm1ns_strides.back()}
...@@ -585,6 +612,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -585,6 +612,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
index_t batch_count_; index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
// check C0 masking and padding
C0MatrixMask c0_matrix_mask_;
// For robust IsSupportedArgument() check // For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_m_n_k_o_; std::vector<index_t> raw_lengths_m_n_k_o_;
index_t c_extent_lowest_; index_t c_extent_lowest_;
...@@ -632,6 +662,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -632,6 +662,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch, ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_>; has_main_k_block_loop_>;
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
...@@ -654,7 +685,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -654,7 +685,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_, arg.block_2_ctile_map_,
arg.batch_count_, arg.batch_count_,
arg.compute_base_ptr_of_batch_); arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_);
}; };
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
......
...@@ -35,6 +35,7 @@ template <typename GridwiseGemm, ...@@ -35,6 +35,7 @@ template <typename GridwiseGemm,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2CTileMap, typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch, typename ComputeBasePtrOfStridedBatch,
typename C0MatrixMask,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
...@@ -57,7 +58,8 @@ __global__ void ...@@ -57,7 +58,8 @@ __global__ void
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const index_t batch_count, const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -88,7 +90,8 @@ __global__ void ...@@ -88,7 +90,8 @@ __global__ void
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map); block_2_ctile_map,
c0_matrix_mask);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
...@@ -106,6 +109,7 @@ __global__ void ...@@ -106,6 +109,7 @@ __global__ void
ignore = block_2_ctile_map; ignore = block_2_ctile_map;
ignore = batch_count; ignore = batch_count;
ignore = compute_base_ptr_of_batch; ignore = compute_base_ptr_of_batch;
ignore = c0_matrix_mask;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
...@@ -177,6 +181,7 @@ template <typename ALayout, ...@@ -177,6 +181,7 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
bool MaskOutUpperTriangle,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
: public DeviceBatchedGemmSoftmaxGemm<ALayout, : public DeviceBatchedGemmSoftmaxGemm<ALayout,
...@@ -203,9 +208,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -203,9 +208,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{ GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock}; MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
// FIXME: pad K
static_assert(!matrix_padder.PadK, "KPadding is currently not supported");
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
{ {
const auto a_grid_desc_mraw_kraw = [&]() { const auto a_grid_desc_mraw_kraw = [&]() {
...@@ -313,6 +315,29 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -313,6 +315,29 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
} }
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
struct C0MatrixMask
{
C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
return n >= NRaw_;
}
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
{
return IsUpperTriangle(m, n) || IsNOutOfBound(n);
}
private:
// index_t MRaw_;
index_t NRaw_;
};
struct ComputeBasePtrOfStridedBatch struct ComputeBasePtrOfStridedBatch
{ {
ComputeBasePtrOfStridedBatch(index_t BatchStrideA, ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
...@@ -418,7 +443,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -418,7 +443,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched, LoopSched,
matrix_padder.PadN>; matrix_padder.PadN,
MaskOutUpperTriangle>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -463,6 +489,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -463,6 +489,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
c_element_op_{c_element_op}, c_element_op_{c_element_op},
batch_count_(Batch), batch_count_(Batch),
compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC}, compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC},
c0_matrix_mask_{NRaw},
raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw} raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw}
{ {
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
...@@ -497,6 +524,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -497,6 +524,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
index_t batch_count_; index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
// check C0 masking and padding
C0MatrixMask c0_matrix_mask_;
// For robust IsSupportedArgument() check // For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_m_n_k_o_; std::vector<index_t> raw_lengths_m_n_k_o_;
}; };
...@@ -542,6 +572,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -542,6 +572,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch, ComputeBasePtrOfStridedBatch,
C0MatrixMask,
has_main_k_block_loop_>; has_main_k_block_loop_>;
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
...@@ -564,7 +595,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -564,7 +595,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_, arg.block_2_ctile_map_,
arg.batch_count_, arg.batch_count_,
arg.compute_base_ptr_of_batch_); arg.compute_base_ptr_of_batch_,
arg.c0_matrix_mask_);
}; };
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
......
...@@ -320,10 +320,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -320,10 +320,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -357,12 +353,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -357,12 +353,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched>;
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype( using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; // block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -475,10 +478,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -475,10 +478,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_; ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map // block-to-e-tile map
Block2ETileMap block_2_etile_map_; Block2ETileMap block_2_etile_map_;
...@@ -535,9 +537,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -535,9 +537,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
CDEElementwiseOperation, CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2ETileMap, DeviceOp::Block2ETileMap,
has_main_loop>; has_main_loop>;
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
......
...@@ -222,14 +222,9 @@ struct DeviceElementwise ...@@ -222,14 +222,9 @@ struct DeviceElementwise
} }
}; };
bool IsSupportedArgument(const BaseArgument* p_arg) override static bool IsSupportedArgument(const Argument& arg)
{ {
const Argument* pArg = dynamic_cast<const Argument*>(p_arg); if(arg.lengths_.back() % MPerThread != 0)
if(pArg == nullptr)
return false;
if(pArg->lengths_.back() % MPerThread != 0)
return false; return false;
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths, auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
...@@ -247,19 +242,40 @@ struct DeviceElementwise ...@@ -247,19 +242,40 @@ struct DeviceElementwise
bool valid = true; bool valid = true;
static_for<0, NumInput, 1>{}([&](auto I) { static_for<0, NumInput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid( if(!IsScalarPerVectorValid(
pArg->lengths_, pArg->inStridesArray_[I.value], InScalarPerVectorSeq::At(I))) arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
valid = false; valid = false;
}); });
static_for<0, NumOutput, 1>{}([&](auto I) { static_for<0, NumOutput, 1>{}([&](auto I) {
if(!IsScalarPerVectorValid( if(!IsScalarPerVectorValid(
pArg->lengths_, pArg->outStridesArray_[I.value], OutScalarPerVectorSeq::At(I))) arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
valid = false; valid = false;
}); });
return valid; return valid;
}; };
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto
MakeArgument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op)
{
return Argument{lengths,
inStridesArray,
outStridesArray,
in_dev_buffers,
out_dev_buffers,
elementwise_op};
}
std::unique_ptr<BaseArgument> std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, NumDim> lengths, MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray, const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
......
...@@ -237,10 +237,6 @@ struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -237,10 +237,6 @@ struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
......
...@@ -234,6 +234,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -234,6 +234,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
Number<NumDTensor>{}); Number<NumDTensor>{});
} }
// desc for problem definition
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>; using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
...@@ -250,10 +251,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -250,10 +251,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -287,10 +284,19 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -287,10 +284,19 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched>;
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype( using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -326,7 +332,10 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -326,7 +332,10 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op} cde_element_op_{cde_element_op},
MRaw_{MRaw},
NRaw_{NRaw},
KRaw_{KRaw}
{ {
// populate pointer, desc for Ds // populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
...@@ -383,18 +392,22 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -383,18 +392,22 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_; ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map // block-to-e-tile map
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_; Block2ETileMap block_2_etile_map_;
// element-wise op // element-wise op
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
// for checking vector load/store
index_t MRaw_;
index_t NRaw_;
index_t KRaw_;
}; };
// Invoker // Invoker
...@@ -429,9 +442,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -429,9 +442,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
CDEElementwiseOperation, CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2ETileMap, DeviceOp::Block2ETileMap,
has_main_loop>; has_main_loop>;
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
...@@ -480,6 +493,86 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout, ...@@ -480,6 +493,86 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return false; return false;
} }
// check vector load/store
{
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// check vector load of A
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
{
// FIXME: not rigorous
if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector laod of B
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
{
// FIXME: not rigorous
if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector load of Ds
// only support RowMajor for now
bool all_valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
if constexpr(!is_same_v<DLayout, Row>)
{
all_valid = false;
}
});
if(!all_valid)
{
return false;
}
// check vector store of E
// only support RowMajor for now
if constexpr(is_same_v<ELayout, Row>)
{
if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else
{
return false;
}
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_, arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_, arg.ds_grid_desc_m_n_,
......
...@@ -365,10 +365,6 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle ...@@ -365,10 +365,6 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -402,17 +398,21 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle ...@@ -402,17 +398,21 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched>;
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype( using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
struct GroupedContractionBlock2ETileMap struct GroupedContractionBlock2ETileMap
{ {
static_assert( // block-to-e-tile map
std::is_same<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{})), using Block2ETileMap =
typename GridwiseGemm::DefaultBlock2ETileMap>::value, remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
"Wrong! Should be the same type name");
GroupedContractionBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, GroupedContractionBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n,
ck::index_t BlockStart) ck::index_t BlockStart)
...@@ -441,7 +441,7 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle ...@@ -441,7 +441,7 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
return default_block_2_etile_map_.CheckValidity(e_grid_desc_m_n); return default_block_2_etile_map_.CheckValidity(e_grid_desc_m_n);
} }
typename GridwiseGemm::DefaultBlock2ETileMap default_block_2_etile_map_; Block2ETileMap default_block_2_etile_map_;
ck::index_t block_start_; ck::index_t block_start_;
}; };
...@@ -456,10 +456,9 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle ...@@ -456,10 +456,9 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_; ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
e_grid_desc_mblock_mperblock_nblock_nperblock_;
// lock-to-e-tile map // lock-to-e-tile map
GroupedContractionBlock2ETileMap block_2_etile_map_; GroupedContractionBlock2ETileMap block_2_etile_map_;
......
...@@ -34,11 +34,13 @@ struct DeviceGroupedConvFwdMultipleD : public BaseOperator ...@@ -34,11 +34,13 @@ struct DeviceGroupedConvFwdMultipleD : public BaseOperator
{ {
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer( virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
const void* p_a, const void* p_a, // input image
const void* p_b, const void* p_b, // weight
const std::array<const void*, NumDTensor>& p_ds, const std::array<const void*, NumDTensor>& p_ds,
void* p_e, void* p_e, // output image
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides, const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <cmath>
#include <memory>
#include <type_traits>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t NumDim, typename InDataType, typename OutDataType, typename ElementwiseOperation>
struct DevicePermute : BaseOperator
{
using Lengths = std::array<index_t, NumDim>;
using Strides = Lengths;
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const Lengths& in_lengths,
const Strides& in_strides,
const Lengths& out_lengths,
const Strides& out_strides,
const void* in_dev_buffer,
void* out_dev_buffer,
ElementwiseOperation elementwise_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment