Commit 91075f0f authored by Jing Zhang

clean deviceop

parent c0264b8f
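Summary of the diff below: the host-side grid-descriptor types and the descriptor members of the problem Argument are compiled out with #if 0, so GemmBiasTransKernelArg now carries only raw sizes (M, N, K) and strides; the per-group grid size is computed once and cached in Argument::grid_size_grp instead of being recomputed inside the Invoker; and the main-K-block-loop consistency check is derived from the stored K_ via GridwiseGemm::CalculateKPadded rather than from the removed a_grid_desc_ak0_m_ak1_ descriptor.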
@@ -317,6 +317,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         CDEBlockTransferScalarPerVector_NPerBlock,
         LoopSched>;
 
+#if 0
     using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
         GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
     using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
@@ -325,6 +326,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
     using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
+#endif
 
     template <typename UnderlyingBlockToCTileMap>
     struct OffsettedBlockToCTileMapMLoops
@@ -483,6 +485,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         std::array<index_t, NumDTensor> StrideDs_;
         index_t StrideE_;
 
+#if 0
         // tensor descriptors for problem definiton
         AGridDesc_M_K a_grid_desc_m_k_;
         BGridDesc_N_K b_grid_desc_n_k_;
@@ -498,6 +501,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
 
         // block-to-e-tile map
         Block2ETileMap block_2_etile_map_;
+#endif
     };
 
     // Argument
@@ -591,12 +595,14 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                         M, N, gemm_descs[i].stride_Ds_[j]);
                 });
 
+#if 0
             // tensor descriptors for block/thread-wise copy
             const auto a_grid_desc_ak0_m_ak1 =
                 GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
 
             const auto b_grid_desc_bk0_n_bk1 =
                 GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
+#endif
 
             const auto e_grid_desc_m_n =
                 DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideC);
@@ -604,7 +610,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
 
             // block-to-e-tile map
             const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n};
-            const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
+            grid_size_grp = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
 
             if(group_id * grid_size_grp != grid_size_)
             {
@@ -619,41 +625,24 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                                             e_grid_desc_m_n,
                                             local_b2c_tile_map))
             {
-                // tensor descriptors for block/thread-wise copy
-                DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-                    ds_grid_desc_mblock_mperblock_nblock_nperblock;
-
-                static_for<0, NumDTensor, 1>{}([&](auto j) {
-                    ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
-                        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                            ds_grid_desc_m_n[j]);
-                });
-
-                const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
-                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                        e_grid_desc_m_n);
-
-                gemm_desc_kernel_arg_.push_back(
-                    GemmBiasTransKernelArg{p_As.size() == 0 ? nullptr : p_As[i],
-                                           p_Bs.size() == 0 ? nullptr : p_Bs[i],
-                                           p_ds_grid,
-                                           p_Es[i],
-                                           M,
-                                           N,
-                                           K,
-                                           StrideA,
-                                           StrideB,
-                                           StrideDs,
-                                           StrideC,
-                                           a_grid_desc_m_k,
-                                           b_grid_desc_n_k,
-                                           ds_grid_desc_m_n,
-                                           e_grid_desc_m_n,
-                                           a_grid_desc_ak0_m_ak1,
-                                           b_grid_desc_bk0_n_bk1,
-                                           ds_grid_desc_mblock_mperblock_nblock_nperblock,
-                                           e_grid_desc_mblock_mperblock_nblock_nperblock,
-                                           local_b2c_tile_map});
+                gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{
+                    p_As.size() == 0 ? nullptr : p_As[i],
+                    p_Bs.size() == 0 ? nullptr : p_Bs[i],
+                    p_ds_grid,
+                    p_Es[i],
+                    M,
+                    N,
+                    K,
+                    StrideA,
+                    StrideB,
+                    StrideDs,
+                    StrideC,
+                });
+            }
+            else
+            {
+                throw std::runtime_error(
+                    "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
             }
 
             group_id++;
@@ -674,6 +663,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         const void* grouped_gemm_kernel_args_dev;
 
         index_t grid_size_;
+        index_t grid_size_grp;
     };
 
     // Invoker
@@ -691,51 +681,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
             for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
             {
-#if DEBUG_LOG
-                std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{"
-                          << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0)
-                          << ", "
-                          << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I1)
-                          << ", "
-                          << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2)
-                          << "}";
-
-                std::cout << ", arg.b_grid_desc_bk0_n_bk1_{"
-                          << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I0)
-                          << ", "
-                          << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I1)
-                          << ", "
-                          << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I2)
-                          << "}";
-
-                static_for<0, NumDTensor, 1>{}([&](auto j) {
-                    std::cout << ", arg.d" << i << "_grid_desc_m_n_{"
-                              << arg.gemm_desc_kernel_arg_[i].ds_grid_desc_m_n_[j].GetLength(I0)
-                              << ", "
-                              << arg.gemm_desc_kernel_arg_[i].ds_grid_desc_m_n_[j].GetLength(I1)
-                              << "}";
-                });
-
-                std::cout << ", arg.e_grid_desc_m_n_{ "
-                          << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
-                          << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
-                          << std::endl;
-#endif
-
-                if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_m_k_,
-                                                arg.gemm_desc_kernel_arg_[i].b_grid_desc_n_k_,
-                                                arg.gemm_desc_kernel_arg_[i].ds_grid_desc_m_n_,
-                                                arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_,
-                                                arg.gemm_desc_kernel_arg_[i].block_2_etile_map_))
-                {
-                    throw std::runtime_error(
-                        "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
-                }
-
-                const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
-                               arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
-
-                if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
+                const auto KPad =
+                    GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K_, 1);
+
+                if(GridwiseGemm::CalculateHasMainKBlockLoop(KPad) != has_main_k_block_loop)
                 {
                     throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
                 }
@@ -773,8 +722,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                     CDEElementwiseOperation,
                     has_main_k_block_loop_>;
 
-                const index_t grid_size_grp = arg.grid_size_ / arg.group_count_;
-
                 const void* kernel_args_dev = nullptr;
 
                 if(arg.grouped_gemm_kernel_args_dev != nullptr)
@@ -817,7 +764,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                     0,
                     cast_pointer_to_constant_address_space(kernel_args_dev),
                     arg.gemm_desc_kernel_arg_.size(),
-                    grid_size_grp,
+                    arg.grid_size_grp,
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_);
...
@@ -200,6 +200,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                                     c_block_size * sizeof(CShuffleDataType));
     }
 
+#if 0
     // A desc for source in blockwise copy
     template <typename AGridDesc_M_K>
     __host__ __device__ static constexpr auto
@@ -233,6 +234,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                        make_tuple(Sequence<1>{}, Sequence<0>{}),
                        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
     }
+#endif
 
     __host__ __device__ static auto CalculateMPadded(index_t M)
     {
...
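For readers skimming the hunks above, here is a minimal, self-contained C++ sketch of the pattern the commit moves toward: the host-side argument keeps only sizes and strides, the per-group grid size is cached once, and the invoker re-derives has_main_k_block_loop from the stored K. This is an illustration only; KPerBlock, the padding formula, and the main-loop predicate are assumed stand-ins for GridwiseGemm::CalculateKPadded and GridwiseGemm::CalculateHasMainKBlockLoop, whose real definitions are not shown in this diff.

// Hedged sketch; all formulas below are assumptions for illustration.
#include <cstdint>
#include <stdexcept>
#include <vector>

using index_t = std::int32_t;

struct GemmKernelArg // stand-in for GemmBiasTransKernelArg after this commit
{
    index_t M, N, K;
    index_t StrideA, StrideB, StrideE;
};

struct Argument
{
    std::vector<GemmKernelArg> gemm_desc_kernel_arg_;
    index_t grid_size_    = 0;
    index_t grid_size_grp = 0; // computed once per group, read by the Invoker
};

constexpr index_t KPerBlock = 32; // illustrative tile size, not from the source

// Assumed padding rule: split K across k_batch, round each slice up to KPerBlock.
index_t CalculateKPadded(index_t K, index_t KBatch)
{
    const index_t KPerBatch = (K + KBatch - 1) / KBatch;
    return ((KPerBatch + KPerBlock - 1) / KPerBlock) * KPerBlock;
}

bool CalculateHasMainKBlockLoop(index_t KPad)
{
    return KPad / KPerBlock > 1; // assumed predicate, for illustration only
}

// Mirrors the new invoker-side check: every group must agree with the
// compile-time has_main_k_block_loop flag, now judged from the stored K.
void ValidateMainKBlockLoop(const Argument& arg, bool has_main_k_block_loop)
{
    for(const auto& g : arg.gemm_desc_kernel_arg_)
    {
        const index_t KPad = CalculateKPadded(g.K, /*KBatch=*/1);
        if(CalculateHasMainKBlockLoop(KPad) != has_main_k_block_loop)
            throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
    }
}

int main()
{
    Argument arg;
    arg.gemm_desc_kernel_arg_.push_back({128, 128, 256, 256, 256, 128});
    ValidateMainKBlockLoop(arg, /*has_main_k_block_loop=*/true); // K=256 -> 8 K-blocks
    return 0;
}

Dropping the descriptor members keeps each per-group kernel argument small and uniform, which is what makes a single device-side argument array and one cached grid_size_grp per group workable.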