Commit 326d6bc6 authored by Jing Zhang

move all arguments into device

parent 03d3395b
@@ -66,7 +66,7 @@ else()
         -Wunreachable-code
         -Wunused
         -Wno-reserved-identifier
-        -Werror
+        #-Werror
         -Wno-option-ignored
         -Wsign-compare
         -Wno-extra-semi-stmt
@@ -46,7 +46,7 @@ using AElementOp = PassThrough;
 using BElementOp = PassThrough;
 using CDEElementOp = PassThrough;

-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;

 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl
     // clang-format off
@@ -59,4 +59,40 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl
 #include "run_grouped_gemm_example.inc"

-int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
+// int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
+
+int main(int argc, char* argv[])
+{
+    ProblemSize problem_size;
+    ExecutionConfig config;
+
+    problem_size.group_count = 16;
+
+    problem_size.Ms = {
+        167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
+
+    for(int i = 0; i < problem_size.group_count; i++)
+    {
+        problem_size.Ns.push_back(768);
+        problem_size.Ks.push_back(4608);
+
+        problem_size.stride_As.push_back(problem_size.Ks[i]);
+        problem_size.stride_Bs.push_back(problem_size.Ks[i]);
+        problem_size.stride_Cs.push_back(problem_size.Ns[i]);
+    }
+
+    if(argc == 4)
+    {
+        config.do_verification = std::stoi(argv[1]);
+        config.init_method     = std::stoi(argv[2]);
+        config.time_kernel     = std::stoi(argv[3]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
+        exit(0);
+    }
+
+    return !run_grouped_gemm(problem_size, config);
+}
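Note: the fields filled by the new main() above come from host-side helper structs that are not shown in this diff (they presumably live in run_grouped_gemm_example.inc). The sketch below is a hedged reconstruction of their shape, inferred only from how main() assigns them; the real member types may differ.

// Hedged sketch (not CK code): plausible shapes of the helper structs used by the new main().
#include <cstdio>
#include <vector>

struct ProblemSize
{
    std::vector<int> Ms, Ns, Ks;                      // per-group GEMM extents
    std::vector<int> stride_As, stride_Bs, stride_Cs; // per-group leading dimensions
    int group_count = 0;
};

struct ExecutionConfig
{
    bool do_verification = true;  // argv[1]
    int init_method      = 1;     // argv[2]
    bool time_kernel     = false; // argv[3]
};

int main()
{
    ProblemSize ps;
    ps.group_count = 1;
    ps.Ms          = {167};
    ps.Ns.push_back(768);
    ps.Ks.push_back(4608);
    ps.stride_As.push_back(ps.Ks[0]); // row-major A: lda = K
    ps.stride_Bs.push_back(ps.Ks[0]); // column-major B: ldb = K
    ps.stride_Cs.push_back(ps.Ns[0]); // row-major E: ldc = N
    std::printf("group 0: M=%d N=%d K=%d\n", ps.Ms[0], ps.Ns[0], ps.Ks[0]);
}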
@@ -24,6 +24,8 @@ namespace device {
 template <typename GridwiseGemm,
           typename GemmDesc,
+          GemmSpecialization GemmSpec,
+          typename Block2ETileMapKSplit,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CDEElementwiseOperation,
@@ -34,6 +36,7 @@ __global__ void
 #endif
         kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                 const index_t group_count,
+                                const index_t grid_size_grp,
                                 const AElementwiseOperation a_element_op,
                                 const BElementwiseOperation b_element_op,
                                 const CDEElementwiseOperation c_element_op)
@@ -47,6 +50,7 @@ __global__ void
     const auto gemm_desc_ptr =
         reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));

+#if 0
     index_t left     = 0;
     index_t right    = group_count;
     index_t group_id = index_t((left + right) / 2);
@@ -64,7 +68,14 @@ __global__ void
         }
         group_id = index_t((left + right) / 2);
     }
+#endif
+
+    const index_t group_id = block_id / grid_size_grp;
+
+    if(group_id >= group_count)
+        return;

+#if 0
     GridwiseGemm::template Run<HasMainKBlockLoop>(
         gemm_desc_ptr[group_id].a_ptr_,
         gemm_desc_ptr[group_id].b_ptr_,
@@ -79,6 +90,56 @@ __global__ void
         gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
         gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
         gemm_desc_ptr[group_id].block_2_etile_map_);
+#else
+    const index_t M = gemm_desc_ptr[group_id].M;
+    const index_t N = gemm_desc_ptr[group_id].N;
+    const index_t K = gemm_desc_ptr[group_id].K;
+
+    if(M == 0 || N == 0 || K == 0)
+        return;
+
+    const index_t StrideA    = K;
+    const index_t StrideB    = K;
+    const index_t StrideDs[] = {};
+    const index_t StrideE    = N;
+
+    using Row = ck::tensor_layout::gemm::RowMajor;
+    using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+    using ALayout  = Row;
+    using BLayout  = Col;
+    using DsLayout = ck::Tuple<>;
+    using ELayout  = Row;
+
+    const auto e_grid_desc_m_n =
+        GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
+
+    const index_t BlockStart = group_id * grid_size_grp;
+
+    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
+
+    const auto local_b2e_tile_map = Block2ETileMapKSplit{e_grid_desc_m_n};
+    const auto block_2_etile_map  = GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart);
+
+    GridwiseGemm::template Run<HasMainKBlockLoop, GemmSpec, ALayout, BLayout, DsLayout, ELayout>(
+        gemm_desc_ptr[group_id].a_ptr_,
+        gemm_desc_ptr[group_id].b_ptr_,
+        gemm_desc_ptr[group_id].ds_ptr_,
+        gemm_desc_ptr[group_id].e_ptr_,
+        p_shared,
+        a_element_op,
+        b_element_op,
+        c_element_op,
+        M,
+        N,
+        K,
+        StrideA,
+        StrideB,
+        StrideDs,
+        StrideE,
+        block_2_etile_map);
+#endif
 #else
     ignore = gemm_descs_const;
     ignore = group_count;
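Note: the disabled binary search above located a block's group by walking the per-group [BlockStart_, BlockEnd_) ranges, while the replacement assumes every group is assigned the same number of blocks (grid_size_grp) so that a plain division recovers the group index. A stand-alone host-side sketch of the two mappings (toy code, not CK; names are illustrative) agreeing under that uniform-grid assumption:

// Stand-alone sketch: two ways to map a flat block id to its group.
#include <cassert>
#include <cstdio>
#include <vector>

// Old scheme: group g owns blocks [starts[g], starts[g + 1]); search for the containing range.
int group_by_search(const std::vector<int>& starts, int block_id)
{
    int left = 0, right = static_cast<int>(starts.size()) - 1;
    int g    = (left + right) / 2;
    while(!(block_id >= starts[g] && block_id < starts[g + 1]))
    {
        if(block_id < starts[g])
            right = g;
        else
            left = g;
        g = (left + right) / 2;
    }
    return g;
}

// New scheme: every group gets exactly grid_size_grp blocks, so division suffices.
int group_by_division(int grid_size_grp, int block_id) { return block_id / grid_size_grp; }

int main()
{
    const int grid_size_grp       = 4;              // uniform blocks per group (assumption of the new scheme)
    const std::vector<int> starts = {0, 4, 8, 12};  // equivalent prefix ranges for 3 groups
    for(int b = 0; b < 12; ++b)
        assert(group_by_search(starts, b) == group_by_division(grid_size_grp, b));
    std::puts("both mappings agree");
}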
@@ -281,46 +342,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
     using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;

-    struct GroupedGemmBlock2ETileMap
-    {
-        using Block2ETileMap =
-            remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
-
-        GroupedGemmBlock2ETileMap()
-        {
-            block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{});
-            BlockStart_        = -1;
-        }
-
-        GroupedGemmBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, ck::index_t BlockStart)
-        {
-            block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
-            BlockStart_        = BlockStart;
-        }
-
-        template <typename TopIdx>
-        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
-        {
-            return block_2_etile_map_.CalculateBottomIndex(
-                make_multi_index(idx_top[I0] - BlockStart_));
-        }
-
-        // it's actually E-Tile
-        template <typename CTileIdx, typename CTileDim>
-        __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
-                                                 const CTileDim& c_tile_dim) const
-        {
-            return block_2_etile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
-        }
-
-        __host__ bool CheckValidity(const EGridDesc_M_N& e_grid_desc_m_n) const
-        {
-            return block_2_etile_map_.CheckValidity(e_grid_desc_m_n);
-        }
-
-        Block2ETileMap block_2_etile_map_;
-        ck::index_t BlockStart_;
-    };
+    using Block2ETileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>;
+    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;

     struct GemmBiasTransKernelArg
     {
@@ -330,6 +353,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
         typename GridwiseGemm::DsGridPointer ds_ptr_;
         EDataType* e_ptr_;

+        index_t M, N, K;
+
         // tensor descriptors for problem definition
         AGridDesc_M_K a_grid_desc_m_k_;
         BGridDesc_N_K b_grid_desc_n_k_;
@@ -344,7 +369,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
         EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;

         // block-to-e-tile map
-        GroupedGemmBlock2ETileMap block_2_etile_map_;
+        Block2ETileMap block_2_etile_map_;
         ck::index_t BlockStart_, BlockEnd_;
     };
@@ -374,7 +399,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
         gemm_desc_kernel_arg_.reserve(group_count_);

-        skipped_group_count_ = 0;
+        index_t group_id = 0;

         for(std::size_t i = 0; i < gemm_descs.size(); i++)
         {
@@ -385,12 +410,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
             a_mtx_mraw_kraw_.emplace_back(M, K);
             b_mtx_nraw_kraw_.emplace_back(N, K);

-            if(M == 0)
-            {
-                skipped_group_count_++;
-                continue;
-            }
-
             const index_t StrideA = gemm_descs[i].stride_A_;
             const index_t StrideB = gemm_descs[i].stride_B_;
             const index_t StrideC = gemm_descs[i].stride_C_;
@@ -427,24 +446,23 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
             const auto b_grid_desc_bk0_n_bk1 =
                 GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);

-            const index_t grid_size_grp =
-                GroupedGemmBlock2ETileMap(e_grid_desc_m_n, 0)
-                    .block_2_etile_map_.CalculateGridSize(e_grid_desc_m_n);
+            // block-to-e-tile map
+            const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n};
+
+            const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
+
+            std::cout << "grp id: " << group_id << " grid_size: " << grid_size_grp << std::endl;

             const index_t BlockStart = grid_size_;
             const index_t BlockEnd   = grid_size_ + grid_size_grp;

             grid_size_ += grid_size_grp;

-            // block-to-e-tile map
-            const auto block_2_etile_map =
-                GroupedGemmBlock2ETileMap(e_grid_desc_m_n, BlockStart);
-
             if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
                                            b_grid_desc_n_k,
                                            ds_grid_desc_m_n,
                                            e_grid_desc_m_n,
-                                           block_2_etile_map))
+                                           local_b2c_tile_map))
             {
                 // tensor descriptors for block/thread-wise copy
                 DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
@@ -465,6 +483,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                     static_cast<const BDataType*>(p_Bs[i]),
                     p_ds_grid,
                     static_cast<EDataType*>(p_Es[i]),
+                    M,
+                    N,
+                    K,
                     a_grid_desc_m_k,
                     b_grid_desc_n_k,
                     ds_grid_desc_m_n,
@@ -473,16 +494,17 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                     b_grid_desc_bk0_n_bk1,
                     ds_grid_desc_mblock_mperblock_nblock_nperblock,
                     e_grid_desc_mblock_mperblock_nblock_nperblock,
-                    block_2_etile_map,
+                    local_b2c_tile_map,
                     BlockStart,
                     BlockEnd});
             }
+
+            group_id++;
         }
     }

     // private:
     index_t group_count_;
-    index_t skipped_group_count_;

     AElementwiseOperation a_element_op_;
     BElementwiseOperation b_element_op_;
@@ -560,11 +582,16 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
             auto launch_kernel = [&](auto has_main_k_block_loop_) {
                 const auto kernel = kernel_grouped_gemm_xdl<GridwiseGemm,
                                                             GemmBiasTransKernelArg,
+                                                            GemmSpec,
+                                                            Block2ETileMap,
                                                             AElementwiseOperation,
                                                             BElementwiseOperation,
                                                             CDEElementwiseOperation,
                                                             has_main_k_block_loop_>;

+                const index_t grid_size_grp = arg.gemm_desc_kernel_arg_[0].BlockEnd_ -
+                                              arg.gemm_desc_kernel_arg_[0].BlockStart_;
+
                 return launch_and_time_kernel(
                     stream_config,
                     kernel,
@@ -573,6 +600,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                     0,
                     cast_pointer_to_constant_address_space(arg.p_workspace_),
                     arg.gemm_desc_kernel_arg_.size(),
+                    grid_size_grp,
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_);
@@ -600,8 +628,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if((ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) +
-            arg.skipped_group_count_) != arg.group_count_)
+        if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
         {
             return false;
         }
@@ -585,7 +585,8 @@ struct OffsettedBlockToCTileMap
 {
     using underlying_type = UnderlyingBlockToCTileMap;

-    OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start)
+    __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
+                                                 index_t block_start)
     {
         block_to_ctile_map_ = block_to_ctile_map;
         block_start_        = block_start;
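Note: OffsettedBlockToCTileMap now plays the role the deleted hand-written GroupedGemmBlock2ETileMap used to fill: it wraps a per-group tile map and shifts the incoming grid-wide block index by the group's block start before delegating, as the removed CalculateBottomIndex showed. A simplified stand-alone sketch of that wrapper pattern (toy types, not the CK implementation) is:

// Simplified sketch of the offset-wrapper pattern behind OffsettedBlockToCTileMap.
// IdentityTileMap and the 1-D index are stand-ins; the real map yields an (m, n) tile index.
#include <cstdio>

struct IdentityTileMap
{
    int CalculateBottomIndex(int local_block_id) const { return local_block_id; } // toy mapping
};

template <typename UnderlyingMap>
struct OffsettedMap
{
    OffsettedMap(UnderlyingMap map, int block_start) : map_(map), block_start_(block_start) {}

    // Shift the grid-wide block id into the group-local range, then delegate.
    int CalculateBottomIndex(int grid_block_id) const
    {
        return map_.CalculateBottomIndex(grid_block_id - block_start_);
    }

    UnderlyingMap map_;
    int block_start_;
};

int main()
{
    OffsettedMap<IdentityTileMap> group1_map(IdentityTileMap{}, /*block_start=*/8);
    std::printf("grid block 10 -> local tile %d\n", group1_map.CalculateBottomIndex(10)); // prints 2
}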
@@ -15,6 +15,9 @@
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

+#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+
 namespace ck {

 // GEMM:
@@ -331,6 +334,97 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
     using DsGridPointer = decltype(MakeDsGridPointer());

+    using Row = ck::tensor_layout::gemm::RowMajor;
+    using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+    using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
+
+    template <typename ALayout, GemmSpecialization GemmSpec>
+    __host__ __device__ static auto
+    MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
+    {
+        constexpr auto matrix_padder =
+            ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
+                MPerBlock, NPerBlock, KPerBlock};
+
+        const auto a_grid_desc_mraw_kraw = [&]() {
+            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
+                                                    make_tuple(StrideA, I1));
+            }
+            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
+                                                    make_tuple(I1, StrideA));
+            }
+        }();
+
+        return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
+    }
+
+    template <typename BLayout, GemmSpecialization GemmSpec>
+    __host__ __device__ static auto
+    MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
+    {
+        constexpr auto matrix_padder =
+            ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
+                MPerBlock, NPerBlock, KPerBlock};
+
+        const auto b_grid_desc_nraw_kraw = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                    make_tuple(I1, StrideB));
+            }
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                    make_tuple(StrideB, I1));
+            }
+        }();
+
+        return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
+    }
+
+    template <typename ELayout, GemmSpecialization GemmSpec>
+    __host__ __device__ static auto
+    MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
+    {
+        constexpr auto matrix_padder =
+            ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
+                MPerBlock, NPerBlock, KPerBlock};
+
+        const auto e_grid_desc_mraw_nraw = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
+                                                    make_tuple(StrideE, I1));
+            }
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
+                                                    make_tuple(I1, StrideE));
+            }
+        }();
+
+        return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
+    }
+
+    template <typename DsLayout, GemmSpecialization GemmSpec>
+    __host__ __device__ static auto
+    MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
+                             const std::array<index_t, NumDTensor>& NRaws,
+                             const std::array<index_t, NumDTensor>& DsStride)
+    {
+        return generate_tuple(
+            [&](auto i) {
+                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
+
+                return MakeEGridDescriptor_M_N<DLayout, GemmSpec>(MRaws[i], NRaws[i], DsStride[i]);
+            },
+            Number<NumDTensor>{});
+    }
+
     template <bool HasMainKBlockLoop,
               typename AGridDesc_AK0_M_AK1,
               typename BGridDesc_BK0_N_BK1,
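Note: the example switches GemmDefault to GemmSpecialization::MNPadding, and the per-group M values it uses (167, 183, ...) are not multiples of typical tile sizes, so the MatrixPadder calls above round the descriptor lengths up to the next tile boundary. A stand-alone sketch of that round-up arithmetic (tile sizes below are made up for illustration, not taken from this diff) is:

// Stand-alone sketch of the M/N padding arithmetic implied by MNPadding.
#include <cstdio>

int pad_up(int raw, int per_block) { return (raw + per_block - 1) / per_block * per_block; }

int main()
{
    const int MPerBlock = 128, NPerBlock = 128; // illustrative tile sizes (assumption)
    const int Ms[]      = {167, 183, 177, 181}; // irregular per-group M from the example
    const int N         = 768;

    for(int m : Ms)
        std::printf("M %3d -> padded %3d, N %d -> padded %d\n",
                    m, pad_up(m, MPerBlock), N, pad_up(N, NPerBlock));
}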
@@ -760,6 +854,95 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
             });
         }
     }

+    template <bool HasMainKBlockLoop,
+              GemmSpecialization GemmSpec,
+              typename ALayout,
+              typename BLayout,
+              typename DsLayout,
+              typename ELayout,
+#if 0
+              typename AGridDesc_AK0_M_AK1,
+              typename BGridDesc_BK0_N_BK1,
+              typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+              typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+#endif
+              typename Block2ETileMap>
+    __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
+                               const ABDataType* __restrict__ p_b_grid,
+                               DsGridPointer p_ds_grid,
+                               EDataType* __restrict__ p_e_grid,
+                               void* __restrict__ p_shared,
+                               const AElementwiseOperation& a_element_op,
+                               const BElementwiseOperation& b_element_op,
+                               const CDEElementwiseOperation& cde_element_op,
+                               const index_t M,
+                               const index_t N,
+                               const index_t K,
+                               const index_t StrideA,
+                               const index_t StrideB,
+                               const index_t StrideDs[],
+                               const index_t StrideE,
+#if 0
+                               const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
+                               const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
+                               const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
+                                   ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                               const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
+                                   e_grid_desc_mblock_mperblock_nblock_nperblock,
+#endif
+                               const Block2ETileMap& block_2_etile_map)
+    {
+        // tensor descriptors for problem definition
+        const auto a_grid_desc_m_k = MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
+        const auto b_grid_desc_n_k = MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
+
+        using DsGridDesc_M_N =
+            remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
+
+        DsGridDesc_M_N ds_grid_desc_m_n;
+
+        static_for<0, NumDTensor, 1>{}([&](auto j) {
+            using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
+
+            ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
+        });
+
+        const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
+
+        // tensor descriptors for block/thread-wise copy
+        const auto a_grid_desc_ak0_m_ak1 = MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
+        const auto b_grid_desc_bk0_n_bk1 = MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
+
+        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
+            MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
+
+        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
+
+        static_for<0, NumDTensor, 1>{}([&](auto j) {
+            ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
+                MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
+        });
+
+        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
+            MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
+
+        Run<HasMainKBlockLoop>(p_a_grid,
+                               p_b_grid,
+                               p_ds_grid,
+                               p_e_grid,
+                               p_shared,
+                               a_element_op,
+                               b_element_op,
+                               cde_element_op,
+                               a_grid_desc_ak0_m_ak1,
+                               b_grid_desc_bk0_n_bk1,
+                               ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                               e_grid_desc_mblock_mperblock_nblock_nperblock,
+                               block_2_etile_map);
+    }
 };

 } // namespace ck
@@ -11,9 +11,11 @@ cmake
 -D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker \
 -save-temps=$PWD" \
 -D CMAKE_BUILD_TYPE=Release \
--D BUILD_DEV=ON \
--D GPU_TARGETS="gfx908;gfx90a;gfx940" \
+-D BUILD_DEV=OFF \
+-D GPU_TARGETS="gfx90a" \
 -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
 -D USE_BITINT_EXTENSION_INT4=OFF \
 ${MY_PROJECT_SOURCE}
+#-D GPU_TARGETS="gfx908;gfx90a;gfx940" \