Commit 17cf8179 authored by Jun Liu

Merge branch 'amd-develop-0605' into amd-master

parents 6b9a4bd5 e4112de7
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// 1d
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NWGK_GKXC_NWGC()
{
return is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NWGK>;
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_GNWK_GKXC_GNWC()
{
return is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
is_same_v<OutLayout, tensor_layout::convolution::GNWK>;
}
// 2d
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NHWGK_GKYXC_NHWGC()
{
return is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NHWGK>;
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_GNHWK_GKYXC_GNHWC()
{
return is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
}
// 3d
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NDHWGK_GKZYXC_NDHWGC()
{
return is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NDHWGK>;
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
{
return is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>;
}
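// These traits let grouped-convolution device ops distinguish, per spatial rank (1d/2d/3d),
// between layouts that place the group dimension just before the channels
// (NWGC/NHWGC/NDHWGC inputs with NWGK/NHWGK/NDHWGK outputs) and layouts that lead with the
// group dimension (GNWC/GNHWC/GNDHWC inputs with GNWK/GNHWK/GNDHWK outputs), both paired
// with GK[Z][Y]XC weights.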
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
struct ComputePtrOffsetOfStridedBatch
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename GemmDesc,
GemmSpecialization GemmSpec,
typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename Block2ETileMap,
typename GroupedGemmBlock2ETileMap,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count,
const index_t grid_size_grp,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t KBatch = 1;
const index_t block_id = get_block_1d_id();
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
const index_t group_id = block_id / grid_size_grp;
if(group_id >= group_count)
return;
const index_t M = gemm_desc_ptr[group_id].M;
const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0)
return;
const auto StrideAs = gemm_desc_ptr[group_id].StrideAs;
const auto StrideBs = gemm_desc_ptr[group_id].StrideBs;
const auto StrideDs = gemm_desc_ptr[group_id].StrideDs;
const auto StrideE = gemm_desc_ptr[group_id].StrideE;
const auto e_grid_desc_m_n =
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
const index_t BlockStart = group_id * grid_size_grp;
const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch};
const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n);
constexpr auto NumATensor = GridwiseGemm::AsGridPointer::Size();
constexpr auto NumBTensor = GridwiseGemm::BsGridPointer::Size();
constexpr auto NumDTensor = GridwiseGemm::DsGridPointer::Size();
typename GridwiseGemm::AsGridPointer p_as_grid_;
typename GridwiseGemm::BsGridPointer p_bs_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType = remove_cvref_t<decltype(p_as_grid_(i))>;
p_as_grid_(i) = static_cast<ADataType>(gemm_desc_ptr[group_id].p_as_grid[i]);
});
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType = remove_cvref_t<decltype(p_bs_grid_(i))>;
p_bs_grid_(i) = static_cast<BDataType>(gemm_desc_ptr[group_id].p_bs_grid[i]);
});
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<decltype(p_ds_grid_(i))>;
p_ds_grid_(i) = static_cast<DDataType>(gemm_desc_ptr[group_id].p_ds_grid[i]);
});
index_t id_off = 0;
index_t id_local = get_block_1d_id() - BlockStart;
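// Persistent-block style loop: this group is assigned grid_size_grp blocks, so each block
// owns every grid_size_grp-th tile of the group's E grid. Bumping id_off/id_local by
// grid_size_grp re-runs the GEMM body on the block's next tile until the group's
// local_grid_size tiles are exhausted.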
while(id_local < local_grid_size)
{
const auto block_2_etile_map =
GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);
GridwiseGemm::
template Run<HasMainKBlockLoop, GemmSpec, AsLayout, BsLayout, DsLayout, ELayout>(
p_as_grid_,
p_bs_grid_,
p_ds_grid_,
gemm_desc_ptr[group_id].p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
M,
N,
K,
StrideAs,
StrideBs,
StrideDs,
StrideE,
block_2_etile_map);
id_off += grid_size_grp;
id_local += grid_size_grp;
}
#else
ignore = gemm_descs_const;
ignore = group_count;
ignore = grid_size_grp;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
#endif
}
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t NumPrefetch,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename ComputeType = EDataType,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
: public DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK;
static constexpr index_t NumATensor = AsDataType::Size();
static constexpr index_t NumBTensor = BsDataType::Size();
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t NumGemmKPrefetchStage = 1;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle<
AsDataType,
BsDataType,
ComputeType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
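// Wraps an underlying block-to-C-tile map so that a group's blocks, whose global ids start
// at block_start_, can be re-targeted to later tiles of the same group by advancing id_off_
// (see the tile loop in kernel_grouped_gemm_xdl_fixed_nk above).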
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMapMLoops
{
using underlying_type = UnderlyingBlockToCTileMap;
__host__ __device__ OffsettedBlockToCTileMapMLoops(
UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0)
{
block_to_ctile_map_ = block_to_ctile_map;
block_start_ = block_start;
id_off_ = id_off;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_));
return make_tuple(
// idx_bot[Number<0>{}],
idx_bot[Number<1>{}],
idx_bot[Number<2>{}]);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
}
template <typename CGridDesc_M_N>
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
}
UnderlyingBlockToCTileMap block_to_ctile_map_;
index_t block_start_;
index_t id_off_;
};
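// Block-to-C-tile map that decomposes a flat block id into (k-split, m-tile, n-tile)
// coordinates. Within each panel of up to M01 (default 8) rows of M-tiles, consecutive block
// ids step through the M direction first, so neighbouring workgroups tend to reuse the same
// N/B tile; the remainder panel (M0 % M01 rows) falls back to a narrower M01_adapt swizzle.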
template <index_t MPerBlock_, index_t NPerBlock_>
struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default;
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M,
index_t N,
index_t KBatch,
index_t M01 = 8)
: M_(M), N_(N), KBatch_(KBatch), M01_(M01)
{
}
template <typename CGridDesc_M_N>
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8)
: BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01)
{
}
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
{
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
return M0 * N0 * KBatch_;
}
template <typename CGridDesc_M_N>
__host__ __device__ constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto block_1d_id = idx_top[I0];
const auto M0 = math::integer_divide_ceil(M_, MPerBlock_);
const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);
block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups
const index_t idx_ksplit = block_1d_id / (M0 * N0);
block_1d_id = block_1d_id % (M0 * N0);
index_t idx_N0 = block_1d_id % N0;
index_t idx_M0 = block_1d_id / N0;
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
index_t idx_M00 = idx_M0 / M01_;
index_t idx_M01 = idx_M0 % M01_;
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
return make_tuple(idx_ksplit,
idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that the user gets the grid size from CalculateGridSize()
}
private:
index_t M_;
index_t N_;
index_t KBatch_;
index_t M01_;
};
using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops<Block2ETileMap>;
struct GemmBiasTransKernelArg
{
// pointers
std::array<const void*, NumATensor> as_ptr_;
std::array<const void*, NumBTensor> bs_ptr_;
std::array<const void*, NumDTensor> ds_ptr_;
void* e_ptr_;
index_t M_, N_, K_;
std::array<index_t, NumATensor> StrideAs_;
std::array<index_t, NumBTensor> StrideBs_;
std::array<index_t, NumDTensor> StrideDs_;
index_t StrideE_;
};
// Argument
struct Argument : public BaseArgument
{
void UpdateKBatch(index_t) {}
Argument(std::vector<std::array<const void*, NumATensor>>&,
std::vector<std::array<const void*, NumBTensor>>&,
std::vector<std::array<const void*, NumDTensor>>&,
std::vector<void*>&,
std::vector<GemmMultiABDDesc>& gemm_descs,
AElementwiseOperation a_element_op = AElementwiseOperation{},
BElementwiseOperation b_element_op = BElementwiseOperation{},
CDEElementwiseOperation c_element_op = CDEElementwiseOperation{})
: a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}
{
grid_size_ = 0;
k_batch_ = 1;
grouped_gemm_kernel_args_dev = nullptr;
group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
gemm_desc_kernel_arg_.reserve(group_count_);
index_t group_id = 0;
sum_of_m = gemm_descs[0].M_;
const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_);
const index_t N = gemm_descs[0].N_;
const index_t K = gemm_descs[0].K_;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
if(sum_of_m != gemm_descs[i].M_ || N != gemm_descs[i].N_ || K != gemm_descs[i].K_)
{
throw std::runtime_error("wrong! M/N/K is not identical");
}
a_mtx_mraw_kraw_.emplace_back(sum_of_m, K);
b_mtx_nraw_kraw_.emplace_back(N, K);
// pointer
std::array<const void*, NumATensor> p_as_grid;
std::array<const void*, NumBTensor> p_bs_grid;
std::array<const void*, NumDTensor> p_ds_grid;
static_for<0, NumATensor, 1>{}([&](auto j) { p_as_grid[j] = nullptr; });
static_for<0, NumBTensor, 1>{}([&](auto j) { p_bs_grid[j] = nullptr; });
static_for<0, NumDTensor, 1>{}([&](auto j) { p_ds_grid[j] = nullptr; });
std::array<index_t, NumATensor> StrideAs;
std::array<index_t, NumBTensor> StrideBs;
std::array<index_t, NumDTensor> StrideDs;
const index_t StrideE = gemm_descs[i].stride_C_;
if(gemm_descs[i].stride_As_.size() != NumATensor)
{
throw std::runtime_error(
"wrong! gemm_descs[i].stride_As_.size() does not match NumATensor");
}
static_for<0, NumATensor, 1>{}(
[&](auto j) { StrideAs[j] = gemm_descs[i].stride_As_[j]; });
if(gemm_descs[i].stride_Bs_.size() != NumBTensor)
{
throw std::runtime_error(
"wrong! gemm_descs[i].stride_Bs_.size() does not match NumBTensor");
}
static_for<0, NumBTensor, 1>{}(
[&](auto j) { StrideBs[j] = gemm_descs[i].stride_Bs_[j]; });
if(gemm_descs[i].stride_Ds_.size() != NumDTensor)
{
throw std::runtime_error(
"wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor");
}
static_for<0, NumDTensor, 1>{}(
[&](auto j) { StrideDs[j] = gemm_descs[i].stride_Ds_[j]; });
const auto e_grid_desc_m_n =
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
AverM, N, StrideE);
// block-to-e-tile map
const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_};
grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
if(group_id * grid_size_grp_ != grid_size_)
{
throw std::runtime_error("wrong! grid_size_grp_ is not identical!");
}
grid_size_ += grid_size_grp_;
// check block-to-E-tile
if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n))
{
throw std::runtime_error("wrong! block_2_etile_map validation failed");
}
gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{
p_as_grid,
p_bs_grid,
p_ds_grid,
nullptr,
AverM,
N,
K,
StrideAs,
StrideBs,
StrideDs,
StrideE,
});
group_id++;
}
const auto e_grid_desc_sum_m_n =
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_);
const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1};
barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n);
}
// private:
index_t group_count_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation c_element_op_;
std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
const void* grouped_gemm_kernel_args_dev;
index_t grid_size_;
index_t grid_size_grp_;
index_t barrier_size_grp_;
index_t sum_of_m;
index_t k_batch_ = 1;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
bool has_main_k_block_loop = true;
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) !=
has_main_k_block_loop)
{
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
}
}
if(arg.grouped_gemm_kernel_args_dev == nullptr)
{
throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr");
}
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) {
const auto kernel = kernel_grouped_gemm_xdl_fixed_nk<
GridwiseGemm,
GroupedGemmMultiABDKernelArgument<NumATensor, NumBTensor, NumDTensor>,
GemmSpec,
AsLayout,
BsLayout,
DsLayout,
ELayout,
Block2ETileMap,
GroupedGemmBlock2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
e_global_memory_operation_,
has_main_k_block_loop_>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
arg.gemm_desc_kernel_arg_.size(),
arg.grid_size_grp_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
};
constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd;
constexpr auto Set = InMemoryDataOperationEnum::Set;
if(arg.k_batch_ > 1)
{
if(has_main_k_block_loop)
{
ave_time =
launch_kernel(integral_constant<bool, true>{},
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
}
else
{
ave_time =
launch_kernel(integral_constant<bool, false>{},
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
}
}
else
{
if(has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<InMemoryDataOperationEnum, Set>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<InMemoryDataOperationEnum, Set>{});
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
{
return false;
}
bool supported = true;
// If we use padding, we do not support vector loads for dimensions that are not divisible
// by the vector load size.
if constexpr(GemmSpec != GemmSpecialization::Default)
{
// The [A|B]BlockTransferSrcVectorDim values define the vector dimension in the block
// {K0,M,K1} layout, so we have to adapt them to the {M,K} or {N,K} layout.
const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
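// For example (illustrative): if ABlockTransferSrcVectorDim selects the K1 dimension,
// a_raw_vector_dim is 1 and each group's raw K extent must then be a multiple of
// ABlockTransferSrcScalarPerVector for the vectorized loads to stay in bounds.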
for(index_t i = 0; i < arg.group_count_; ++i)
{
const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});
supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0);
supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0);
}
}
return supported;
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(std::vector<std::array<const void*, NumATensor>>& p_As,
std::vector<std::array<const void*, NumBTensor>>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmMultiABDDesc> gemm_descs,
AElementwiseOperation a_element_op = AElementwiseOperation{},
BElementwiseOperation b_element_op = BElementwiseOperation{},
CDEElementwiseOperation c_element_op = CDEElementwiseOperation{})
{
return Argument{
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<std::array<const void*, NumATensor>>& p_As,
std::vector<std::array<const void*, NumBTensor>>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmMultiABDDesc>& gemm_descs,
AElementwiseOperation a_element_op = AElementwiseOperation{},
BElementwiseOperation b_element_op = BElementwiseOperation{},
CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) override
{
return std::make_unique<Argument>(
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedGemm_Xdl_Fixed_NK"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec)
<< ">";
// clang-format on
return str.str();
}
static void SetElementwiseOps(Argument& arg,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation c_element_op)
{
arg.a_element_op_ = a_element_op;
arg.b_element_op_ = b_element_op;
arg.c_element_op_ = c_element_op;
}
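// The kernel-argument array (one GroupedGemmMultiABDKernelArgument per group, sized by
// GetDeviceKernelArgSize) must be copied to device memory by the caller and registered here
// before Invoker::Run is called; Run throws if this pointer is still null.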
static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args)
{
arg.grouped_gemm_kernel_args_dev = kernel_args;
}
// polymorphic
void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override
{
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), kernel_args);
}
void SetElementwiseOps(BaseArgument* p_arg,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation c_element_op) const override
{
SetElementwiseOps(
*dynamic_cast<Argument*>(p_arg), a_element_op, b_element_op, c_element_op);
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
auto arg = *dynamic_cast<const Argument*>(p_arg);
return arg.group_count_ *
sizeof(GroupedGemmMultiABDKernelArgument<NumATensor, NumBTensor, NumDTensor>);
}
#if 0
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto arg = *dynamic_cast<const Argument*>(p_arg);
return arg.group_count_ * arg.barrier_size_grp_ * sizeof(uint32_t);
}
void SetWorkSpacePointer(BaseArgument* p_arg,
void* p_workspace,
const StreamConfig& stream_config = StreamConfig{}) const override
{
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
p_arg_->p_workspace_ = p_workspace;
hip_check_error(
hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_));
}
#endif
static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }
// polymorphic
void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override
{
return SetKBatch(*dynamic_cast<Argument*>(p_arg), k_batch);
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -553,24 +553,29 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}"
<< std::endl;
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", "
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1)
<< ", "
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2)
<< "}" << std::endl;
std::cout << ", arg.b_grid_desc_k0_n_k1_{"
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}"
<< std::endl;
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0)
<< ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1)
<< ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2)
<< "}" << std::endl;
std::cout << ", arg.e_grid_desc_m_n_{ "
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;
#endif
}
if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
......@@ -668,7 +673,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
}
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_navi2_supported() || ck::is_navi3_supported())
ck::is_gfx103_supported() || ck::is_gfx11_supported())
{
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
......
......@@ -19,7 +19,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
......@@ -252,7 +252,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
Sequence<0, 1>,
ElementwiseInputSequence,
ck::Sequence<CDEShuffleBlockTransferScalarPerVector_NPerBlock>,
true>;
I1,
I1>;
// Block2CTileMap configuration parameter.
static constexpr index_t B2E_M01 = 8;
......@@ -336,6 +337,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
elementwise_d_grid_descs_m_n_.reserve(group_count_);
ds_grid_pointer_.reserve(group_count_);
group_grid_size_.reserve(group_count_);
e_ptrs_.reserve(group_count_);
for(std::size_t i = 0; i < gemm_descs.size(); ++i)
{
......@@ -379,7 +381,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
const index_t block_end = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp;
group_grid_size_[i] = grid_size_grp;
group_grid_size_.push_back(grid_size_grp);
// block-to-e-tile map
auto grouped_block_2_ctile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
......@@ -420,9 +422,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
elementwise_c_grid_descs_m_n_.push_back(c_grid_desc_m_n);
elementwise_d_grid_descs_m_n_.push_back(ds_grid_desc_m_n);
ds_grid_pointer_.push_back(p_ds_grid);
}
// Store a copy of E pointers for elementwise kernel destination
e_ptrs_ = p_Es;
e_ptrs_.push_back(p_Es[i]);
}
}
/**
......@@ -466,7 +468,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
gemm_kernel_args_[i].block_start_ = block_start;
gemm_kernel_args_[i].block_end_ = block_end;
#if DEBUG_LOG
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
index_t tiles = (block_end - block_start) / K_BATCH;
std::cout << "block_start: " << block_start << "\n"
<< "block_end: " << block_end << "\n"
......@@ -477,7 +480,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
<< "K0Padded: " << karg.K0Padded << std::endl
<< "KBatch: " << karg.k_batch << std::endl
<< "grid_size_: " << karg.KPadded << std::endl;
#endif
}
}
}
......@@ -492,12 +495,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
arg.karg_.p_c_grid = p_workspace + offset;
index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
offset += tiles * MPerBlock * NPerBlock;
#if DEBUG_LOG
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "block_start: " << arg.block_start_ << "\n"
<< "block_end: " << arg.block_end_ << "\n"
<< "tiles: " << tiles << "\n"
<< "offset: " << offset << std::endl;
#endif
}
}
}
......@@ -771,13 +775,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(dev_gemm_args),
arg.group_count_,
arg.gemm_kernel_args_.size(),
arg.a_element_op_,
arg.b_element_op_,
PassThrough{});
// Elementwise kernels
for(int i = 0; i < arg.group_count_; ++i)
for(size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
time += launch_and_time_kernel(
stream_config,
......@@ -815,11 +819,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
arg.skipped_group_count_) != arg.group_count_)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "The group count is not equal to sum of skipped groups "
"and kernel args size!"
<< std::endl;
#endif // DEBUG_LOG
}
return false;
}
......@@ -831,11 +836,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg);
if(not group_arg_valid)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[" << __func__ << "] group id: " << i
<< " has invalid GridwiseGemm settings!" << std::endl;
gemm_arg.Print();
#endif // DEBUG_LOG
}
}
supported = supported && group_arg_valid;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <tuple>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/host_utility/stream_utility.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
///
/// @brief Entry point kernel for device-wide Grouped GEMM operation.
///
/// @param[in] gemm_descs_const The pointer to the array of GEMM descriptor structures.
/// @param[in] group_count The number of GEMMs processed together.
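/// @param[in] a_element_op The element-wise operation applied to the A input.
/// @param[in] b_element_op The element-wise operation applied to the B input.
/// @param[in] cde_element_op The element-wise operation that fuses the GEMM result with the
/// D tensors to produce the E output.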
///
/// @tparam GridwiseGemm The specific GridwiseGEMM algorithm implementation.
/// @tparam GemmDesc The structure holding all necessary descriptors and
/// other data needed for grouped gemm calculation and work
/// distribution.
/// @tparam LocalBlock2ETileMap The structure providing mapping between workgroup ids,
/// the data tiles to process and the output tiles.
///
template <typename GridwiseGemm,
typename GemmDesc,
GemmSpecialization GemmSpec,
typename DsDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename OffsettedBlockToCTileMap,
typename LocalBlock2ETileMap,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_multiple_d_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
constexpr auto NumDTensor = DsDataType::Size();
index_t tile_id = get_block_1d_id();
index_t tile_offset = 0;
index_t group_id = -1;
index_t group_offset = 0;
index_t grid_size_grp = 0;
index_t gemm_tile_id_start = 0;
index_t gemm_tile_id_end = 0;
using AGridDescMK =
remove_cvref_t<decltype(GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(
1, 1, 1))>;
using BGridDescNK =
remove_cvref_t<decltype(GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(
1, 1, 1))>;
using EGridDescMN =
remove_cvref_t<decltype(GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
1, 1, 1))>;
using DsGridDescMN =
remove_cvref_t<decltype(GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>(
{}, {}, {}))>;
index_t M = 0, N = 0, K = 0;
index_t StrideA, StrideB, StrideE;
std::array<index_t, NumDTensor> StrideDs;
AGridDescMK a_grid_desc_mk;
BGridDescNK b_grid_desc_nk;
EGridDescMN e_grid_desc_mn;
DsGridDescMN ds_grid_desc_mn;
auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1);
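// Tile-loop work distribution: every workgroup walks the global list of output tiles with a
// stride equal to the launched grid size. The inner while() advances group_id until the
// current tile_id falls inside that group's [gemm_tile_id_start, gemm_tile_id_end) range,
// and the tile is then computed with that group's descriptors.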
do
{
// Find corresponding GEMM group for our tile
while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end) &&
group_id < group_count)
{
group_offset += grid_size_grp;
group_id++;
if(group_id >= group_count)
return;
M = gemm_desc_ptr[group_id].M;
N = gemm_desc_ptr[group_id].N;
K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0)
{
grid_size_grp = 0;
continue;
}
b2c_tile_map =
OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N), group_offset, tile_offset);
grid_size_grp = b2c_tile_map.CalculateGridSize(M, N);
gemm_tile_id_start = group_offset;
gemm_tile_id_end = group_offset + grid_size_grp;
}
StrideA = gemm_desc_ptr[group_id].StrideA;
StrideB = gemm_desc_ptr[group_id].StrideB;
StrideDs = gemm_desc_ptr[group_id].StrideDs;
StrideE = gemm_desc_ptr[group_id].StrideE;
a_grid_desc_mk =
GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
b_grid_desc_nk =
GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
e_grid_desc_mn =
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
ds_grid_desc_mn(j) = GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(
M, N, StrideDs[j]);
});
using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
DsGridPointer p_ds_grid;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
});
bool has_main_kblock_loop =
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_mk.GetLength(Number<1>{}));
// Update the tile offset if we have moved within the group
b2c_tile_map.UpdateTileOffset(tile_offset);
if(has_main_kblock_loop)
{
GridwiseGemm::template Run<true>(gemm_desc_ptr[group_id].p_a_grid,
gemm_desc_ptr[group_id].p_b_grid,
p_ds_grid,
gemm_desc_ptr[group_id].p_e_grid,
static_cast<void*>(p_shared),
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_mk,
b_grid_desc_nk,
ds_grid_desc_mn,
e_grid_desc_mn,
b2c_tile_map);
}
else
{
GridwiseGemm::template Run<false>(gemm_desc_ptr[group_id].p_a_grid,
gemm_desc_ptr[group_id].p_b_grid,
p_ds_grid,
gemm_desc_ptr[group_id].p_e_grid,
static_cast<void*>(p_shared),
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_mk,
b_grid_desc_nk,
ds_grid_desc_mn,
e_grid_desc_mn,
b2c_tile_map);
}
tile_id += get_grid_size();
tile_offset += get_grid_size();
} while(group_id < group_count);
#else
ignore = gemm_descs_const;
ignore = group_count;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
#endif // end of #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
}
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t NumGemmKPrefetchStage,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
PipelineVersion PipelineVer = PipelineVersion::v1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
typename ComputeDataType = EDataType>
struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
: public DeviceGroupedGemmTileLoop<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGroupedGemmMultipleDXdlCShuffleTileLoop;
static constexpr index_t NumDTensor = DsDataType::Size();
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
ADataType,
BDataType,
ComputeDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer>;
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMap
{
using underlying_type = UnderlyingBlockToCTileMap;
__host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
index_t group_offset,
index_t tile_offset)
: block_to_ctile_map_{block_to_ctile_map},
group_offset_{group_offset},
tile_offset_{tile_offset}
{
}
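// The global block id (idx_top) is translated into a group-local tile index: subtract the
// group's starting block id (group_offset_) and add the running tile_offset_, which
// UpdateTileOffset() advances as the kernel loops over additional tiles.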
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
return block_to_ctile_map_.CalculateBottomIndex(
make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_));
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
}
template <typename CGridDesc_M_N>
__host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
}
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
{
return block_to_ctile_map_.CalculateGridSize(M, N);
}
__device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; }
UnderlyingBlockToCTileMap block_to_ctile_map_;
index_t group_offset_;
index_t tile_offset_;
};
using KernelArguments = GroupedGemmTileLoopKernelArguments<NumDTensor>;
using Block2ETileMap = BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock>;
using OffsetedLocalBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
// Argument
struct Argument : public BaseArgument
{
Argument(std::vector<const void*>& /* p_As */,
std::vector<const void*>& /* p_Bs */,
std::vector<std::array<const void*, NumDTensor>>& /* p_Ds */,
std::vector<void*>& /* p_Es */,
const std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
int occupancy_num_blocks,
int gpu_cu_count)
: group_count_{static_cast<index_t>(gemm_descs.size())},
occupancy_num_blocks_{occupancy_num_blocks},
gpu_cu_count_{gpu_cu_count},
gemm_descs_{gemm_descs},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
tile_count_{0}
{
for(const auto& desc : gemm_descs)
{
const auto M = desc.M_;
const auto N = desc.N_;
const auto b2c_tile_map = Block2ETileMap(M, N);
tile_count_ += b2c_tile_map.CalculateGridSize(M, N);
}
}
index_t group_count_;
const void* p_dev_gemm_args_;
int occupancy_num_blocks_;
int gpu_cu_count_;
const std::vector<GemmDesc>& gemm_descs_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
index_t tile_count_;
};
struct KernelConfig
{
// The oversubscription factor for the number of blocks that can simultaneously reside on
// the GPU.
static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1;
static constexpr int BLOCK_WAVES = BlockSize / get_warp_size();
static constexpr int CU_SIMDS = 4;
// Assume we want to have at most 2 waves per SIMD
static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
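// For example (a sketch, assuming BlockSize = 256 and a 64-wide wavefront):
// BLOCK_WAVES = 256 / 64 = 4, so CU_BLOCKS = floor(2 * 4 / 4) = 2 blocks per CU at most.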
};
// Invoker
struct Invoker : public BaseInvoker
{
///
/// @brief Launch Grouped Gemm kernel.
///
/// @note This function overload uses a user-provided device buffer for kernel
/// arguments.
///
/// @param[in] arg The structure containing kernel arguments (in host
/// memory).
/// @param[in] dev_gemm_args The pointer to device memory with kernel arguments.
/// @param[in] stream_config The device stream configuration.
///
/// @return The average kernel execution time (if time measurement is enabled).
///
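/// A minimal host-side sketch (an assumption-laden example, not part of this header; it
/// presumes the caller has already filled the kernel-argument buffer and copied it to the
/// device as dev_gemm_args_ptr):
/// @code
/// auto invoker  = DeviceOp::MakeInvoker();
/// float time_ms = invoker.Run(arg, dev_gemm_args_ptr, StreamConfig{nullptr, true});
/// @endcode
///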
float Run(const Argument& arg,
const void* dev_gemm_args,
const StreamConfig& stream_config = StreamConfig{})
{
if(dev_gemm_args == nullptr)
{
std::ostringstream err;
err << "The gemm arguments device buffer is not allocated!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
float ave_time = 0;
ave_time = DispatchKernel(arg, dev_gemm_args, stream_config);
return ave_time;
}
///
/// @brief Launch Grouped Gemm kernel.
///
/// @note This function overload uses device buffers (for kernel arguments and
/// for the kernel's auxiliary workspace) provided with the argument. The user should
/// call @see GetDeviceKernelArgSize and @see SetDeviceKernelArgs on the arg
/// parameter to properly allocate those buffers.
///
/// @param[in] arg The structure containing kernel arguments (in host memory).
/// @param[in] stream_config The device stream configuration.
///
/// @return The average kernel execution time (if time measurement is enabled).
///
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(arg.p_dev_gemm_args_ == nullptr)
{
std::ostringstream err;
err << "The gemm arguments device buffer is not allocated!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
return Run(arg, arg.p_dev_gemm_args_, stream_config);
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
private:
float DispatchKernel(const Argument& arg,
const void* dev_gemm_args,
const StreamConfig& stream_config) const
{
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
KernelArguments,
GemmSpec,
DsDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
OffsetedLocalBlock2ETileMap,
Block2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>;
return LaunchKernel(kernel, arg, dev_gemm_args, stream_config);
}
template <typename KernelFunction>
int CalculateMaxOccupancyGridSize(const KernelFunction& kernel,
const StreamConfig& stream_config) const
{
// Calculate max number of workgroups that can simultaneously reside on the CU.
int occ_num_blocks = 0;
size_t dyn_shared_mem_per_blk = 0;
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&occ_num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk));
int cu_count = getAvailableComputeUnitCount(stream_config);
if(stream_config.log_level_ > 0)
{
std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks
<< ", available CUs count: " << cu_count << ", occup. grid size: "
<< ck::math::min(occ_num_blocks, KernelConfig::CU_BLOCKS) * cu_count
<< std::endl;
}
return cu_count * ck::math::min(occ_num_blocks, KernelConfig::CU_BLOCKS);
}
template <typename KernelFunction>
float LaunchKernel(const KernelFunction& kernel,
const Argument& arg,
const void* dev_gemm_args,
const StreamConfig& stream_config) const
{
int grid_size = CalculateMaxOccupancyGridSize(kernel, stream_config);
if(stream_config.log_level_ > 0)
{
std::cout << "grid_size: " << grid_size << " tile_count: " << arg.tile_count_
<< std::endl;
}
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(dev_gemm_args),
arg.group_count_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
using DsGridDescMN = remove_cvref_t<
decltype(GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>(
{}, {}, {}))>;
bool supported = true;
for(const auto& gdesc : arg.gemm_descs_)
{
const auto M = gdesc.M_;
const auto N = gdesc.N_;
const auto K = gdesc.K_;
const auto StrideA = gdesc.stride_A_;
const auto StrideB = gdesc.stride_B_;
const auto StrideE = gdesc.stride_C_;
const auto& StrideDs = gdesc.stride_Ds_;
// If the M dimension is unknown at launch time, then validate just N and K.
// If the N or K dim is zero (or unknown), then the responsibility for vector loads lies
// with the user.
if(N * K == 0)
continue;
const auto a_grid_desc_mk =
GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
const auto b_grid_desc_nk =
GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
const auto e_grid_desc_mn =
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
DsGridDescMN ds_grid_desc_mn;
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
ds_grid_desc_mn(j) =
GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(
M, N, StrideDs[j]);
});
const auto b2c_tile_map = Block2ETileMap(M, N);
if(!(GridwiseGemm::template CheckValidity(a_grid_desc_mk,
b_grid_desc_nk,
ds_grid_desc_mn,
e_grid_desc_mn,
b2c_tile_map) &&
GridwiseGemm::template CheckTensorTransfersValidity<ALayout, BLayout, ELayout>(
M, N, K)))
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << ","
<< K << "] are not supported by current template parameters!"
<< " In " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__;
}
supported = false;
}
}
return supported;
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_elementwise_op,
BElementwiseOperation b_elementwise_op,
CDEElementwiseOperation cde_elementwise_op)
{
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
KernelArguments,
GemmSpec,
DsDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
OffsetedLocalBlock2ETileMap,
Block2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>;
int occupancy, num_cu;
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
num_cu = dev_prop.multiProcessorCount;
return Argument{p_As,
p_Bs,
p_Ds,
p_Es,
gemm_descs,
a_elementwise_op,
b_elementwise_op,
cde_elementwise_op,
occupancy,
num_cu};
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation a_elementwise_op,
BElementwiseOperation b_elementwise_op,
CDEElementwiseOperation cde_elementwise_op) override
{
const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
KernelArguments,
GemmSpec,
DsDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
OffsetedLocalBlock2ETileMap,
Block2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>;
int occupancy, num_cu;
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
num_cu = dev_prop.multiProcessorCount;
return std::make_unique<Argument>(p_As,
p_Bs,
p_Ds,
p_Es,
gemm_descs,
a_elementwise_op,
b_elementwise_op,
cde_elementwise_op,
occupancy,
num_cu);
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::ostringstream();
// clang-format off
str << "DeviceGroupedGemmMultipleDXdlCShuffleTileLoop"
<< "<"
<< std::string(ALayout::name)[0] << ","
<< std::string(BLayout::name)[0] << ","
<< std::string(ELayout::name)[0] << ","
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec) << ", "
<< PipelineVer << ", "
<< LoopSched
<< ">";
// clang-format on
return str.str();
}
void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const
{
arg.p_dev_gemm_args_ = p_dev_kernel_args;
}
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
{
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(KernelArguments);
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -514,7 +514,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{"
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0)
<< ", "
......@@ -535,7 +536,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;
#endif
}
if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_m_k_,
arg.gemm_desc_kernel_arg_[i].b_grid_desc_n_k_,
......
......@@ -529,11 +529,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
arg.skipped_group_count_) != arg.group_count_)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "The group count is not equal to sum of skipped groups "
"and kernel args size!"
<< std::endl;
#endif // DEBUG_LOG
}
return false;
}
......@@ -544,11 +545,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
if(not group_arg_valid)
{
#if DEBUG_LOG
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[" << __func__ << "] group id: " << i
<< " has invalid GridwiseGemm settings!" << std::endl;
a.Print();
#endif // DEBUG_LOG
}
}
supported = supported && group_arg_valid;
}
......
......@@ -596,7 +596,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
static bool IsSupportedArgument(const RawArg& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
......@@ -958,7 +958,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
#if 0
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -8,10 +8,13 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/stream_utility.hpp"
......@@ -36,9 +39,10 @@ struct DeviceMaxPoolBwdImpl : public DeviceMaxPoolBwd<DOutDataType, IndexDataTyp
using UnaryConvert = ck::tensor_operation::element_wise::UnaryConvert;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
template <typename Desc_M>
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t loop_step)
static auto PadDescriptor_M_1d(Desc_M& desc_m, index_t loop_step)
{
const auto m = desc_m.GetLength(I0);
const auto pad = math::integer_least_multiple(m, loop_step) - m;
......@@ -56,7 +60,18 @@ struct DeviceMaxPoolBwdImpl : public DeviceMaxPoolBwd<DOutDataType, IndexDataTyp
return PadDescriptor_M_1d(desc_m, loop_step);
}
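// Note (editorial, not part of this change): ExpendDescFirstDim below prepends a unit
// dimension, turning a length-M 1-D descriptor into a (1, M) 2-D descriptor so the
// cast step can reuse the 2-D element-wise kernel.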
template <typename Desc_M>
static auto ExpendDescFirstDim(Desc_M desc_m)
{
return transform_tensor_descriptor(
desc_m,
make_tuple(make_unmerge_transform(make_tuple(I1, desc_m.GetLength(I0)))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{}));
}
using InOutGrid1dDesc = decltype(MakeDescriptor_M(1, 1));
using InOutGrid2dDesc = decltype(ExpendDescFirstDim(InOutGrid1dDesc{}));
using GridwisePutElementSet = GridwisePutElement_1D<InOutGrid1dDesc,
DOutDataType,
......@@ -74,14 +89,30 @@ struct DeviceMaxPoolBwdImpl : public DeviceMaxPoolBwd<DOutDataType, IndexDataTyp
InMemoryDataOperationEnum::AtomicAdd,
InOutVectorSize>;
using GridwiseCasting = GridwiseElementwise_1D<Tuple<InOutGrid1dDesc>,
Tuple<InOutGrid1dDesc>,
static constexpr index_t BlockSize = 256;
static constexpr index_t MPerThread = 1;
static constexpr index_t NPerThread = InOutVectorSize;
static constexpr index_t MPerBlock = 1;
static constexpr index_t NPerBlock = BlockSize * NPerThread;
using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
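// With MPerBlock = 1 and NPerBlock = BlockSize * NPerThread, each workgroup covers one
// (1 x NPerBlock) tile of the (1, M) view; Block2TileMap maps the 1-D block id onto these tiles.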
using GridwiseCasting = GridwiseElementwise<Tuple<InOutGrid2dDesc>,
Tuple<InOutGrid2dDesc>,
Tuple<const DInDataType_AutomicAddPreCast*>,
Tuple<DInDataType*>,
Block2TileMap,
UnaryConvert,
InOutVectorSize,
BlockSize,
MPerBlock,
NPerBlock,
MPerThread,
NPerThread,
Sequence<0, 1>,
Sequence<InOutVectorSize>,
Sequence<InOutVectorSize>,
Sequence<InOutVectorSize>>;
I1,
I1>;
struct Argument : public BaseArgument
{
......@@ -98,7 +129,7 @@ struct DeviceMaxPoolBwdImpl : public DeviceMaxPoolBwd<DOutDataType, IndexDataTyp
p_din_{p_din},
dout_length_raw_{dout_length},
din_length_raw_{din_length},
blockSize_{256},
blockSize_{BlockSize},
windowOverlap_{false}
{
for(size_t i = 0; i < window_lengths.size(); ++i)
......@@ -195,11 +226,12 @@ struct DeviceMaxPoolBwdImpl : public DeviceMaxPoolBwd<DOutDataType, IndexDataTyp
PassThrough>;
const auto cast_kernel =
kernel_elementwise_1d<GridwiseCasting,
Tuple<InOutGrid1dDesc>,
Tuple<InOutGrid1dDesc>,
kernel_elementwise<GridwiseCasting,
Tuple<InOutGrid2dDesc>,
Tuple<InOutGrid2dDesc>,
Tuple<const DInDataType_AutomicAddPreCast*>,
Tuple<DInDataType*>,
Block2TileMap,
UnaryConvert>;
float elapsed_time = launch_and_time_kernel(
......@@ -214,16 +246,25 @@ struct DeviceMaxPoolBwdImpl : public DeviceMaxPoolBwd<DOutDataType, IndexDataTyp
static_cast<DInDataType_AutomicAddPreCast*>(arg.p_workspace_),
PassThrough{});
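// Second pass: convert the accumulated workspace (DInDataType_AutomicAddPreCast) back to
// DInDataType via the 2-D element-wise kernel launched below.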
InOutGrid2dDesc din_grid_desc_2d = ExpendDescFirstDim(din_grid_desc);
const index_t M = din_grid_desc_2d.GetLength(I0);
const index_t N = din_grid_desc_2d.GetLength(I1);
const auto block_2_tile_map = Block2TileMap(M, N);
const auto cast_kernel_grid_size =
block_2_tile_map.CalculateGridSize(din_grid_desc_2d);
elapsed_time += launch_and_time_kernel(
stream_config,
cast_kernel,
dim3(gridSize),
dim3(cast_kernel_grid_size),
dim3(arg.blockSize_),
0,
ck::make_tuple(din_grid_desc),
ck::make_tuple(din_grid_desc),
static_cast<DInDataType_AutomicAddPreCast*>(arg.p_workspace_),
arg.p_din_,
ck::make_tuple(din_grid_desc_2d),
ck::make_tuple(din_grid_desc_2d),
ck::make_tuple(
static_cast<const DInDataType_AutomicAddPreCast*>(arg.p_workspace_)),
ck::make_tuple(arg.p_din_),
block_2_tile_map,
UnaryConvert{});
return elapsed_time;
......
......@@ -594,7 +594,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
static bool IsSupportedArgument(const RawArg& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
......@@ -950,7 +950,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
#if 0
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_navi3_supported())
if(ck::is_gfx11_supported())
{
if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
{
......
......@@ -11,7 +11,6 @@
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
namespace ck {
......
......@@ -4,7 +4,7 @@
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
......@@ -179,6 +179,16 @@ struct Multiply
y = ck::type_convert<bhalf_t>(y_tmp);
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const int8_t& x0, const bhalf_t& x1) const
{
const float x1_tmp = ck::type_convert<float>(x0);
const float x2_tmp = ck::type_convert<float>(x1);
const float y_tmp = x1_tmp * x2_tmp;
y = ck::type_convert<bhalf_t>(y_tmp);
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
......@@ -485,6 +495,19 @@ struct AddFastGelu
e = type_convert<half_t>(x1_f);
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const
{
const float x0_f = type_convert<float>(c) + type_convert<float>(d);
float x1_f = 0;
FastGelu{}.template operator()<float, float>(x1_f, x0_f);
e = type_convert<bhalf_t>(x1_f);
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
......@@ -499,6 +522,71 @@ struct AddFastGelu
}
};
// E = FastGelu(C * D)
struct MultiplyFastGelu
{
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
template <>
__host__ __device__ constexpr void
operator()<float, float, float>(float& e, const float& c, const float& d) const
{
const float x = c * d;
FastGelu{}.template operator()<float, float>(e, x);
}
template <>
__host__ __device__ constexpr void
operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
{
const half_t x = c * d;
ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
}
template <>
__host__ __device__ constexpr void
operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
{
const float x0_f = c * d;
float x1_f = 0;
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
x0_f);
e = type_convert<half_t>(x1_f);
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const
{
const float x0_f = type_convert<float>(c) * type_convert<float>(d);
float x1_f = 0;
FastGelu{}.template operator()<float, float>(x1_f, x0_f);
e = type_convert<bhalf_t>(x1_f);
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
{
const float x0_f = c * type_convert<float>(d);
float x1_f = 0;
FastGelu{}.template operator()<float, float>(x1_f, x0_f);
e = type_convert<bhalf_t>(x1_f);
}
};
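// Note: every bhalf_t variant above promotes to float, applies FastGelu in float, and
// converts back, mirroring the bhalf_t path added to AddFastGelu in this change.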
// E = Silu(C + D)
struct AddSilu
{
......
......@@ -14,6 +14,8 @@ namespace element_wise {
template <typename... UnaryOpsSet>
struct UnaryCombinedOp
{
__host__ __device__ UnaryCombinedOp() : unary_ops_() {}
__host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops) : unary_ops_(unary_ops...) {}
template <typename Y, typename X>
......@@ -32,6 +34,8 @@ struct UnaryCombinedOp
template <typename BinaryOp, typename UnaryOp0, typename UnaryOp1>
struct BinaryWithUnaryCombinedOp
{
__host__ __device__ BinaryWithUnaryCombinedOp() : binary_op_(), unary_op0_(), unary_op1_() {}
__host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op,
UnaryOp0 unary_op0,
UnaryOp1 unary_op1)
......@@ -63,6 +67,11 @@ template <typename BinaryOp0,
typename UnaryOp2>
struct TrinaryWithUnaryCombinedOp
{
__host__ __device__ TrinaryWithUnaryCombinedOp()
: binary_op0_(), binary_op1_(), unary_op0_(), unary_op1_(), unary_op2_()
{
}
__host__ __device__ TrinaryWithUnaryCombinedOp(BinaryOp0 binary_op0,
BinaryOp0 binary_op1,
UnaryOp0 unary_op0,
......
......@@ -221,6 +221,15 @@ struct MultiplyAdd
e = y;
}
template <>
__host__ __device__ void operator()<bhalf_t, float, bhalf_t, bhalf_t>(bhalf_t& e,
const float& c,
const bhalf_t& d0,
const bhalf_t& d1) const
{
// promote to float: bhalf_t is a raw 16-bit storage type with no meaningful arithmetic
const float y = c * type_convert<float>(d0) + type_convert<float>(d1);
e = type_convert<bhalf_t>(y);
}
template <>
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
const float& c,
const half_t& d0,
......@@ -240,6 +249,26 @@ struct MultiplyAdd
}
};
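// E = FastGelu(C * D0 + D1)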
struct MultiplyAddFastGelu
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<ck::bhalf_t, float, ck::bhalf_t, ck::bhalf_t>(
ck::bhalf_t& e, const float& c, const ck::bhalf_t& d0, const ck::bhalf_t& d1) const
{
const float x0_f = c * ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
float x1_f = 0;
FastGelu{}.template operator()<float, float>(x1_f, x0_f);
e = ck::type_convert<ck::bhalf_t>(x1_f);
}
};
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
......@@ -499,6 +528,26 @@ struct UnaryTypeConvert<ck::bhalf_t, float>
}
};
struct ConvInvscale
{
/// @brief Op to multiply convolution results by inverted scale factors
/// @param e Output after scaling
/// @param c Convolution result
/// @param d0 Input scale factor
/// @param d1 Weights scale factor
/// @param d2 Output scale factor
template <typename E, typename C, typename D0, typename D1, typename D2>
__host__ __device__ void
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
template <>
__host__ __device__ void operator()<f8_t, float, float, float, float>(
f8_t& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
e = type_convert<f8_t>(c / d0 / d1 / d2);
};
};
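// Illustrative example (values assumed, not from this change): with c = 12.f and scale
// factors d0 = 2.f, d1 = 3.f, d2 = 1.f, the op stores e = type_convert<f8_t>(12.f / 2.f / 3.f / 1.f) = 2.f.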
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
......@@ -22,6 +22,7 @@ struct PassThroughPack2
auto t = type_convert<float2_t>(x);
y = type_convert<half2_t>(t);
}
constexpr const static bool is_pack2_invocable = true;
};
struct PassThrough
......@@ -131,12 +132,24 @@ struct PassThrough
y = type_convert<int8_t>(x);
}
template <>
__host__ __device__ void operator()<int32_t, int8_t>(int32_t& y, const int8_t& x) const
{
y = type_convert<int32_t>(x);
}
template <>
__host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
{
y = type_convert<int8_t>(x);
}
template <>
__host__ __device__ void operator()<float, int8_t>(float& y, const int8_t& x) const
{
y = type_convert<float>(x);
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
__host__ __device__ void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
......@@ -275,10 +288,13 @@ struct ConvertF8RNE
struct Scale
{
__host__ __device__ Scale(float scale) : scale_(scale) {}
__host__ __device__ Scale(float scale = 1.f) : scale_(scale) {}
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
__host__ __device__ void operator()(Y& y, const X& x) const
{
y = ck::type_convert<Y>(ck::type_convert<float>(x) * scale_);
}
template <>
__host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
......@@ -487,6 +503,46 @@ struct FastGelu
y = type_convert<half_t>(y_f);
}
template <>
__host__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<bhalf_t>(y_f);
}
template <>
__device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<bhalf_t>(y_f);
}
template <>
__device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<bhalf_t>(y_f);
}
template <>
__host__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<bhalf_t>(y_f);
}
};
// https://paperswithcode.com/method/gelu
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -151,7 +151,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
{
}
__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
......@@ -259,49 +259,23 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlo
BlockToCTileMap_M00_N0_M01Adapt;
};
// Rows of column-vectors
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt;
// Grouped Rows of column-vectors WGP mapping
// Optimized for gfx94x-like multiple-die chips
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default;
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(
const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default;
__host__ __device__
BlockToCTileMap_Grouped_M00_N0_M01Adapt(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default;
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt&
operator=(const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default;
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt&
operator=(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default;
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M,
index_t N,
index_t M01 = 8)
: M_(M), N_(N), M01_(M01)
{
#if 0
if(get_thread_global_1d_id()==0){
printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_);
}
#endif
}
template <typename CGridDesc_M_N>
__host__ __device__
BlockToCTileMap_Grouped_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8)
: BlockToCTileMap_Grouped_M00_N0_M01Adapt(
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
{
}
__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
......@@ -309,12 +283,6 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, v
return M0 * N0;
}
template <typename CGridDesc_M_N>
__host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
......@@ -329,11 +297,25 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, v
const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
const auto N0 = math::integer_divide_ceil(N_, NPerBlock);
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
if(M0 == 1)
{
return make_tuple(0, block_1d_id);
}
else if(N0 == 1)
{
return make_tuple(block_1d_id, 0);
}
// block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
else
{
const auto group_size = math::integer_divide_ceil(M0 * N0, GroupNum);
auto group_id = block_1d_id % GroupNum;
auto remap_block_1d_id = group_id * group_size + block_1d_id / GroupNum;
const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
auto group_id_x = block_1d_id % GroupNum;
auto group_id_y = block_1d_id / GroupNum;
auto remap_block_1d_id =
group_id_x <= big_group_num
? group_id_x * group_size + group_id_y
: group_id_x * group_size + big_group_num - group_id_x + group_id_y;
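// Illustrative worked example of the remapping above: with GroupNum = 8 and M0 * N0 = 20
// tiles, group_size = ceil(20 / 8) = 3 and big_group_num = 8 - (3 * 8 - 20) = 4, so groups
// 0..3 own 3 tiles each and groups 4..7 own 2 tiles each. block_1d_id = 13 gives
// group_id_x = 5, group_id_y = 1 and remap_block_1d_id = 5 * 3 + 4 - 5 + 1 = 15.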
index_t idx_N0 = remap_block_1d_id % N0;
index_t idx_M0 = remap_block_1d_id / N0;
......@@ -391,6 +373,7 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, v
return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
}
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
......@@ -405,15 +388,6 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, v
index_t M01_;
};
// keep the redundant type argument for backward compatibility
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
: BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>
{
using BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>::
BlockToCTileMap_Grouped_M00_N0_M01Adapt;
};
// columns of row-vectors
// This C-tile map dynamically adjusts N01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
......@@ -454,7 +428,7 @@ struct BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock, void>
{
}
__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
......@@ -926,6 +900,11 @@ struct OffsettedBlockToCTileMap
return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
}
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
{
return block_to_ctile_map_.CalculateGridSize(M, N);
}
UnderlyingBlockToCTileMap block_to_ctile_map_;
index_t block_start_;
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseElementwise1dFunctor,
typename InGrid1dDescTuple,
typename OutGrid1dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation>
__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple,
const OutGrid1dDescTuple out_grid_1d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op)
{
GridwiseElementwise1dFunctor::Run(in_grid_1d_desc_tuple,
out_grid_1d_desc_tuple,
p_in_global_tuple,
p_out_global_tuple,
elementwise_op);
}
template <typename InGrid1dDescTuple,
typename OutGrid1dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation,
index_t MPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct GridwiseElementwise_1D
{
static constexpr index_t NumInput = InDataTypePointerTuple::Size();
static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size() &&
NumInput == InGrid1dDescTuple::Size() &&
NumOutput == OutGrid1dDescTuple::Size(),
"Tuple size is inconsistent with the number of in/out!");
static constexpr auto I0 = Number<0>{};
static constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
__device__ static void Run(const InGrid1dDescTuple in_grid_1d_desc_tuple,
const OutGrid1dDescTuple out_grid_1d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op)
{
const index_t thread_global_id = get_thread_global_1d_id();
auto in_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread, true>{};
},
Number<NumInput>{});
auto out_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return StaticBuffer<AddressSpaceEnum::Vgpr, DataType, MPerThread, true>{};
},
Number<NumOutput>{});
auto in_global_buf_tuple = generate_tuple(
[&](auto I) {
static_assert(in_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumInput>{});
auto out_global_buf_tuple = generate_tuple(
[&](auto I) {
static_assert(out_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumOutput>{});
const auto thread_global_offset = make_multi_index(thread_global_id * MPerThread);
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const auto M = in_grid_1d_desc_tuple[I0].GetLength(I0);
const index_t loop_step = blockPerGrid * blockSize * MPerThread;
const auto loop_step_index = make_multi_index(loop_step);
auto in_global_load_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return ThreadwiseTensorSliceTransfer_v2<DataType,
DataType,
decltype(in_grid_1d_desc_tuple[I]),
decltype(thread_buffer_desc_m),
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
InScalarPerVectorSeq::At(
I), // ScalarPerVector
1, // SrcScalarStrideInVector
false>{in_grid_1d_desc_tuple[I],
thread_global_offset};
},
Number<NumInput>{});
auto out_global_store_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return ThreadwiseTensorSliceTransfer_v1r3<DataType,
DataType,
decltype(thread_buffer_desc_m),
decltype(out_grid_1d_desc_tuple[I]),
PassThroughOp,
Sequence<MPerThread>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
OutScalarPerVectorSeq::At(I),
InMemoryDataOperationEnum::Set,
1,
false>(
out_grid_1d_desc_tuple[I], thread_global_offset, PassThroughOp{});
},
Number<NumOutput>{});
index_t num_iter = M / (loop_step);
do
{
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).Run(in_grid_1d_desc_tuple[I],
in_global_buf_tuple[I],
thread_buffer_desc_m,
make_tuple(I0),
in_thread_buf_tuple(I));
in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_1d_desc_tuple[I],
loop_step_index);
});
static_for<0, MPerThread, 1>{}([&](auto iM) {
// get reference to in data
const auto in_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM); },
Number<NumInput>{});
// get reference to dst data
auto out_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& { return out_thread_buf_tuple(I)(iM); },
Number<NumOutput>{});
unpack2(elementwise_op, out_data_refs, in_data_refs);
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).Run(thread_buffer_desc_m,
make_tuple(I0),
out_thread_buf_tuple[I],
out_grid_1d_desc_tuple[I],
out_global_buf_tuple(I));
out_global_store_tuple(I).MoveDstSliceWindow(out_grid_1d_desc_tuple[I],
loop_step_index);
});
} while(--num_iter);
}
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
#include "ck/utility/common_header.hpp"
namespace ck {
template <typename GridwiseElementwise2dFunctor,
typename InGrid2dDescTuple,
typename OutGrid2dDescTuple,
template <typename GridwiseElementwiseFunctor,
typename InGridDescTuple,
typename OutGridDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename Block2TileMap,
typename ElementwiseOperation>
__global__ void kernel_elementwise_2d(const InGrid2dDescTuple in_grid_2d_desc_tuple,
const OutGrid2dDescTuple out_grid_2d_desc_tuple,
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_elementwise(const InGridDescTuple in_grid_desc_tuple,
const OutGridDescTuple out_grid_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const index_t num_threads_m,
const index_t num_threads_n)
const Block2TileMap block_2_tile_map,
const ElementwiseOperation elementwise_op)
{
GridwiseElementwise2dFunctor::Run(in_grid_2d_desc_tuple,
out_grid_2d_desc_tuple,
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
out_grid_desc_tuple,
p_in_global_tuple,
p_out_global_tuple,
elementwise_op,
num_threads_m,
num_threads_n);
block_2_tile_map,
elementwise_op);
}
template <typename InGrid2dDescTuple,
typename OutGrid2dDescTuple,
template <typename GridwiseElementwiseFunctor,
typename InGridDescTuple,
typename OutGridDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename Block2TileMap,
typename ElementwiseOperation,
index_t MPerThread,
index_t NPerThread,
index_t NumInputs,
index_t NumOutputs>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple,
const OutGridDescTuple out_grid_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const Block2TileMap block_2_tile_map,
const ElementwiseOperation elementwise_op,
const index_t batch_count,
const std::array<index_t, NumInputs> input_batch_strides,
const std::array<index_t, NumOutputs> output_batch_strides)
{
static_assert(InGridDescTuple::Size() == NumInputs &&
InDataTypePointerTuple::Size() == NumInputs);
static_assert(OutGridDescTuple::Size() == NumOutputs &&
OutDataTypePointerTuple::Size() == NumOutputs);
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
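// Each batch is assigned grid_size / batch_count consecutive workgroups; g_idx selects the
// batch this workgroup belongs to, and the per-tensor batch strides below offset the global
// pointers into that batch.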
InDataTypePointerTuple p_in_global_with_offset_tuple;
OutDataTypePointerTuple p_out_global_with_offset_tuple;
static_for<0, InDataTypePointerTuple::Size(), 1>{}([&](auto i) {
p_in_global_with_offset_tuple(i) = p_in_global_tuple.At(i) + input_batch_strides[i] * g_idx;
});
static_for<0, OutDataTypePointerTuple::Size(), 1>{}([&](auto i) {
p_out_global_with_offset_tuple(i) =
p_out_global_tuple.At(i) + output_batch_strides[i] * g_idx;
});
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
out_grid_desc_tuple,
p_in_global_with_offset_tuple,
p_out_global_with_offset_tuple,
block_2_tile_map,
elementwise_op);
}
template <typename InGridDescTuple,
typename OutGridDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename Block2TileMap,
typename ElementwiseOperation,
index_t BlockSize,
index_t M0PerBlock,
index_t M1PerBlock,
index_t M0PerThread,
index_t M1PerThread,
typename ThreadClusterArrangeOrder,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct GridwiseElementwise_2D
typename OutScalarPerVectorSeq,
index_t SrcVectorDim,
index_t DstVectorDim>
struct GridwiseElementwise
{
static constexpr index_t NumInput = InDataTypePointerTuple::Size();
static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size() &&
NumInput == InGrid2dDescTuple::Size() &&
NumOutput == OutGrid2dDescTuple::Size(),
NumInput == InGridDescTuple::Size() && NumOutput == OutGridDescTuple::Size(),
"Tuple size is inconsistent with the number of in/out!");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto thread_buffer_desc_mn =
make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}, Number<NPerThread>{}));
static_assert((SrcVectorDim == I0 || SrcVectorDim == I1) &&
(DstVectorDim == I0 || DstVectorDim == I1),
"Vector dim must be equal to 0 or 1.");
using PassThroughOp = tensor_operation::element_wise::PassThrough;
__device__ static void Run(const InGrid2dDescTuple in_grid_2d_desc_tuple,
const OutGrid2dDescTuple out_grid_2d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const index_t num_threads_m,
const index_t num_threads_n)
__device__ static void Run(const InGridDescTuple& in_grid_desc_tuple,
const OutGridDescTuple& out_grid_desc_tuple,
const InDataTypePointerTuple& p_in_global_tuple,
const OutDataTypePointerTuple& p_out_global_tuple,
const Block2TileMap& block_2_tile_map,
const ElementwiseOperation& elementwise_op)
{
auto in_thread_buf_tuple = generate_tuple(
constexpr auto src_datas = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return StaticBuffer<AddressSpaceEnum::Vgpr,
DataType,
MPerThread * NPerThread,
true>{};
return DataType{};
},
Number<NumInput>{});
auto out_thread_buf_tuple = generate_tuple(
constexpr auto dst_datas = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return StaticBuffer<AddressSpaceEnum::Vgpr,
DataType,
MPerThread * NPerThread,
true>{};
return DataType{};
},
Number<NumOutput>{});
auto in_global_buf_tuple = generate_tuple(
const auto in_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global_tuple[I], in_grid_2d_desc_tuple[I].GetElementSpaceSize());
p_in_global_tuple[I], in_grid_desc_tuple[I].GetElementSpaceSize());
},
Number<NumInput>{});
auto out_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global_tuple[I], out_grid_2d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumOutput>{});
const auto M = in_grid_2d_desc_tuple[I0].GetLength(I0);
const auto N = in_grid_2d_desc_tuple[I0].GetLength(I1);
const index_t loop_step_m = num_threads_m * MPerThread;
const index_t loop_step_n = num_threads_n * NPerThread;
const index_t thread_1d_id = get_thread_global_1d_id();
index_t tid_m = thread_1d_id / num_threads_n;
index_t tid_n = thread_1d_id % num_threads_n;
const auto thread_global_offset = make_multi_index(tid_m * MPerThread, tid_n * NPerThread);
auto in_global_load_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return ThreadwiseTensorSliceTransfer_v2<
DataType,
DataType,
decltype(in_grid_2d_desc_tuple[I]),
decltype(thread_buffer_desc_mn),
Sequence<MPerThread, NPerThread>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
0, // SrcVectorDim
InScalarPerVectorSeq::At(I), // ScalarPerVector
1, // SrcScalarStrideInVector
true>{in_grid_2d_desc_tuple[I], thread_global_offset};
},
Number<NumInput>{});
auto out_global_store_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return ThreadwiseTensorSliceTransfer_v1r3<
DataType,
DataType,
decltype(thread_buffer_desc_mn),
decltype(out_grid_2d_desc_tuple[I]),
PassThroughOp,
Sequence<MPerThread, NPerThread>, // SliceLengths
Sequence<0, 1>, // DimAccessOrder
1, // SrcVectorDim
1, // OutScalarPerVectorSeq::At(I),
InMemoryDataOperationEnum::Set,
1,
true>(out_grid_2d_desc_tuple[I], thread_global_offset, PassThroughOp{});
p_out_global_tuple[I], out_grid_desc_tuple[I].GetElementSpaceSize());
},
Number<NumOutput>{});
index_t num_iter_m = M / (loop_step_m);
do
{
index_t num_iter_n = N / (loop_step_n);
do
{
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).Run(in_grid_2d_desc_tuple[I],
in_global_buf_tuple[I],
thread_buffer_desc_mn,
make_tuple(I0, I0),
in_thread_buf_tuple(I));
in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_2d_desc_tuple[I],
make_multi_index(0, loop_step_n));
});
const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
static_for<0, MPerThread, 1>{}([&](auto iM) {
static_for<0, NPerThread, 1>{}([&](auto iN) {
constexpr auto offset =
thread_buffer_desc_mn.CalculateOffset(make_tuple(iM, iN));
// get reference to in data
const auto in_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> const auto& {
return in_thread_buf_tuple(I)(Number<offset>{});
const index_t m0_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock);
const index_t m1_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * M1PerBlock);
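// The (m0, m1) block tile origin computed above is shared by all inputs and outputs; the
// generate_tuple calls below replicate it as the starting grid offset of each tensor transfer.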
const auto input_thread_grid_offset = generate_tuple(
[&](auto) {
return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
},
Number<NumInput>{});
// get reference to dst data
auto out_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& {
return out_thread_buf_tuple(I)(Number<offset>{});
const auto output_thread_grid_offset = generate_tuple(
[&](auto) {
return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
},
Number<NumOutput>{});
unpack2(elementwise_op, out_data_refs, in_data_refs);
});
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).Run(thread_buffer_desc_mn,
make_tuple(I0, I0),
out_thread_buf_tuple[I],
out_grid_2d_desc_tuple[I],
out_global_buf_tuple(I));
out_global_store_tuple(I).MoveDstSliceWindow(out_grid_2d_desc_tuple[I],
make_multi_index(0, loop_step_n));
});
} while(--num_iter_n);
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).MoveSrcSliceWindow(
in_grid_2d_desc_tuple[I],
make_multi_index(loop_step_m, -(N / loop_step_n) * loop_step_n));
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).MoveDstSliceWindow(
out_grid_2d_desc_tuple[I],
make_multi_index(loop_step_m, -(N / loop_step_n) * loop_step_n));
});
} while(--num_iter_m);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// If src and dst share the same vector dim:
//     M0 dim - for both src vector load and dst vector store
// else:
//     M0 dim - for dst vector store
//     M1 dim - for src vector load
using SrcDimAccessOrder =
std::conditional_t<SrcVectorDim == I1, Sequence<0, 1>, Sequence<1, 0>>;
using DstDimAccessOrder =
std::conditional_t<DstVectorDim == I1, Sequence<0, 1>, Sequence<1, 0>>;
using ThreadClusterLengths =
Sequence<Number<M0PerBlock / M0PerThread>{}, Number<M1PerBlock / M1PerThread>{}>;
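// For illustration only (concrete values are supplied by the caller): with BlockSize = 256,
// M0PerBlock = 1, M1PerBlock = 1024, M0PerThread = 1 and M1PerThread = 4 (the max-pool
// backward cast path when InOutVectorSize = 4), ThreadClusterLengths is Sequence<1, 256>,
// i.e. all 256 threads line up along the M1 (vector) dimension.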
auto global_to_global_transfer = ThreadGroupTensorSliceTransfer_v4r2<
ThisThreadBlock,
ElementwiseOperation,
uniform_sequence_gen_t<NumOutput, static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
Sequence<M0PerBlock, M1PerBlock>,
ThreadClusterLengths,
ThreadClusterArrangeOrder,
decltype(src_datas),
decltype(dst_datas),
InGridDescTuple,
OutGridDescTuple,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
InScalarPerVectorSeq,
OutScalarPerVectorSeq,
uniform_sequence_gen_t<NumInput, 1>,
uniform_sequence_gen_t<NumOutput, 1>,
uniform_sequence_gen_t<NumInput, false>,
uniform_sequence_gen_t<NumOutput, false>>{in_grid_desc_tuple,
input_thread_grid_offset,
out_grid_desc_tuple,
output_thread_grid_offset,
elementwise_op};
global_to_global_transfer.Run(
in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0);
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwiseElementwise3dFunctor,
typename InGrid3dDescTuple,
typename OutGrid3dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation>
__global__ void kernel_elementwise_3d(const InGrid3dDescTuple in_grid_3d_desc_tuple,
const OutGrid3dDescTuple out_grid_3d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const index_t num_threads_m,
const index_t num_threads_n,
const index_t num_threads_k)
{
GridwiseElementwise3dFunctor::Run(in_grid_3d_desc_tuple,
out_grid_3d_desc_tuple,
p_in_global_tuple,
p_out_global_tuple,
elementwise_op,
num_threads_m,
num_threads_n,
num_threads_k);
}
template <typename InGrid3dDescTuple,
typename OutGrid3dDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename ElementwiseOperation,
index_t MPerThread,
index_t NPerThread,
index_t KPerThread,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct GridwiseElementwise_3D
{
static constexpr index_t NumInput = InDataTypePointerTuple::Size();
static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size() &&
NumInput == InGrid3dDescTuple::Size() &&
NumOutput == OutGrid3dDescTuple::Size(),
"Tuple size is inconsistent with the number of in/out!");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto thread_buffer_desc_mnk = make_naive_tensor_descriptor_packed(
make_tuple(Number<MPerThread>{}, Number<NPerThread>{}, Number<KPerThread>{}));
using PassThroughOp = tensor_operation::element_wise::PassThrough;
__device__ static void Run(const InGrid3dDescTuple in_grid_3d_desc_tuple,
const OutGrid3dDescTuple out_grid_3d_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const ElementwiseOperation elementwise_op,
const index_t num_threads_m,
const index_t num_threads_n,
const index_t num_threads_k)
{
auto in_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return StaticBuffer<AddressSpaceEnum::Vgpr,
DataType,
MPerThread * NPerThread * KPerThread,
true>{};
},
Number<NumInput>{});
auto out_thread_buf_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return StaticBuffer<AddressSpaceEnum::Vgpr,
DataType,
MPerThread * NPerThread * KPerThread,
true>{};
},
Number<NumOutput>{});
auto in_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global_tuple[I], in_grid_3d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumInput>{});
auto out_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global_tuple[I], out_grid_3d_desc_tuple[I].GetElementSpaceSize());
},
Number<NumOutput>{});
const auto M = in_grid_3d_desc_tuple[I0].GetLength(I0);
const auto N = in_grid_3d_desc_tuple[I0].GetLength(I1);
const auto K = in_grid_3d_desc_tuple[I0].GetLength(I2);
const index_t loop_step_m = num_threads_m * MPerThread;
const index_t loop_step_n = num_threads_n * NPerThread;
const index_t loop_step_k = num_threads_k * KPerThread;
const index_t thread_1d_id = get_thread_global_1d_id();
const index_t tid_m = thread_1d_id / (num_threads_n * num_threads_k);
const index_t tid_nk = thread_1d_id % (num_threads_n * num_threads_k);
const index_t tid_n = tid_nk / num_threads_k;
const index_t tid_k = tid_nk % num_threads_k;
const auto thread_global_offset =
make_multi_index(tid_m * MPerThread, tid_n * NPerThread, tid_k * KPerThread);
auto in_global_load_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return ThreadwiseTensorSliceTransfer_v2<
DataType,
DataType,
decltype(in_grid_3d_desc_tuple[I]),
decltype(thread_buffer_desc_mnk),
Sequence<MPerThread, NPerThread, KPerThread>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
1,                           // SrcVectorDim
InScalarPerVectorSeq::At(I), // ScalarPerVector
1, // SrcScalarStrideInVector
true>{in_grid_3d_desc_tuple[I], thread_global_offset};
},
Number<NumInput>{});
auto out_global_store_tuple = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return ThreadwiseTensorSliceTransfer_v1r3<
DataType,
DataType,
decltype(thread_buffer_desc_mnk),
decltype(out_grid_3d_desc_tuple[I]),
PassThroughOp,
Sequence<MPerThread, NPerThread, KPerThread>, // SliceLengths
Sequence<0, 1, 2>, // DimAccessOrder
2, // SrcVectorDim
OutScalarPerVectorSeq::At(I), // ScalarPerVector
InMemoryDataOperationEnum::Set,
1,
true>(out_grid_3d_desc_tuple[I], thread_global_offset, PassThroughOp{});
},
Number<NumOutput>{});
index_t num_iter_m = M / (loop_step_m);
do
{
index_t num_iter_n = N / (loop_step_n);
do
{
index_t num_iter_k = K / (loop_step_k);
do
{
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).Run(in_grid_3d_desc_tuple[I],
in_global_buf_tuple[I],
thread_buffer_desc_mnk,
make_tuple(I0, I0, I0),
in_thread_buf_tuple(I));
in_global_load_tuple(I).MoveSrcSliceWindow(
in_grid_3d_desc_tuple[I], make_multi_index(0, 0, loop_step_k));
});
static_for<0, MPerThread, 1>{}([&](auto iM) {
static_for<0, NPerThread, 1>{}([&](auto iN) {
static_for<0, KPerThread, 1>{}([&](auto iK) {
constexpr auto offset =
thread_buffer_desc_mnk.CalculateOffset(make_tuple(iM, iN, iK));
// get reference to in data
const auto in_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> const auto& {
return in_thread_buf_tuple(I)(Number<offset>{});
},
Number<NumInput>{});
// get reference to dst data
auto out_data_refs = generate_tie(
// return type should be lvalue
[&](auto I) -> auto& {
return out_thread_buf_tuple(I)(Number<offset>{});
},
Number<NumOutput>{});
unpack2(elementwise_op, out_data_refs, in_data_refs);
});
});
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).Run(thread_buffer_desc_mnk,
make_tuple(I0, I0, I0),
out_thread_buf_tuple[I],
out_grid_3d_desc_tuple[I],
out_global_buf_tuple(I));
out_global_store_tuple(I).MoveDstSliceWindow(
out_grid_3d_desc_tuple[I], make_multi_index(0, 0, loop_step_k));
});
} while(--num_iter_k);
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).MoveSrcSliceWindow(
in_grid_3d_desc_tuple[I],
make_multi_index(0, loop_step_n, -(K / loop_step_k) * loop_step_k));
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).MoveDstSliceWindow(
out_grid_3d_desc_tuple[I],
make_multi_index(0, loop_step_n, -(K / loop_step_k) * loop_step_k));
});
} while(--num_iter_n);
static_for<0, NumInput, 1>{}([&](auto I) {
in_global_load_tuple(I).MoveSrcSliceWindow(
in_grid_3d_desc_tuple[I],
make_multi_index(loop_step_m,
-(N / loop_step_n) * loop_step_n,
-(K / loop_step_k) * loop_step_k));
});
static_for<0, NumOutput, 1>{}([&](auto I) {
out_global_store_tuple(I).MoveDstSliceWindow(
out_grid_3d_desc_tuple[I],
make_multi_index(loop_step_m,
-(N / loop_step_n) * loop_step_n,
-(K / loop_step_k) * loop_step_k));
});
} while(--num_iter_m);
}
};
} // namespace ck