Commit 91075f0f authored by Jing Zhang
Browse files

clean deviceop

parent c0264b8f
......@@ -317,6 +317,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
#if 0
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
......@@ -325,6 +326,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
#endif
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMapMLoops
......@@ -483,6 +485,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
std::array<index_t, NumDTensor> StrideDs_;
index_t StrideE_;
#if 0
// tensor descriptors for problem definition
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
......@@ -498,6 +501,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
#endif
};
// Argument
......@@ -591,12 +595,14 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
M, N, gemm_descs[i].stride_Ds_[j]);
});
#if 0
// tensor descriptors for block/thread-wise copy
const auto a_grid_desc_ak0_m_ak1 =
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
const auto b_grid_desc_bk0_n_bk1 =
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
#endif
const auto e_grid_desc_m_n =
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideC);
......@@ -604,7 +610,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
// block-to-e-tile map
const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
grid_size_grp = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
if(group_id * grid_size_grp != grid_size_)
{
......@@ -619,41 +625,24 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
e_grid_desc_m_n,
local_b2c_tile_map))
{
// tensor descriptors for block/thread-wise copy
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock;
static_for<0, NumDTensor, 1>{}([&](auto j) {
ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n[j]);
gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{
p_As.size() == 0 ? nullptr : p_As[i],
p_Bs.size() == 0 ? nullptr : p_Bs[i],
p_ds_grid,
p_Es[i],
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideC,
});
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n);
gemm_desc_kernel_arg_.push_back(
GemmBiasTransKernelArg{p_As.size() == 0 ? nullptr : p_As[i],
p_Bs.size() == 0 ? nullptr : p_Bs[i],
p_ds_grid,
p_Es[i],
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideC,
a_grid_desc_m_k,
b_grid_desc_n_k,
ds_grid_desc_m_n,
e_grid_desc_m_n,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
local_b2c_tile_map});
}
else
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
group_id++;
......@@ -674,6 +663,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
const void* grouped_gemm_kernel_args_dev;
index_t grid_size_;
index_t grid_size_grp;
};
// Invoker
......@@ -691,51 +681,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
#if DEBUG_LOG
std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{"
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0)
<< ", "
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I1)
<< ", "
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2)
<< "}";
std::cout << ", arg.b_grid_desc_bk0_n_bk1_{"
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I0)
<< ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I1)
<< ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I2)
<< "}";
const auto KPad =
GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K_, 1);
static_for<0, NumDTensor, 1>{}([&](auto j) {
std::cout << ", arg.d" << i << "_grid_desc_m_n_{"
<< arg.gemm_desc_kernel_arg_[i].ds_grid_desc_m_n_[j].GetLength(I0)
<< ", "
<< arg.gemm_desc_kernel_arg_[i].ds_grid_desc_m_n_[j].GetLength(I1)
<< "}";
});
std::cout << ", arg.e_grid_desc_m_n_{ "
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;
#endif
if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_m_k_,
arg.gemm_desc_kernel_arg_[i].b_grid_desc_n_k_,
arg.gemm_desc_kernel_arg_[i].ds_grid_desc_m_n_,
arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_,
arg.gemm_desc_kernel_arg_[i].block_2_etile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
if(GridwiseGemm::CalculateHasMainKBlockLoop(KPad) != has_main_k_block_loop)
{
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
}
......@@ -773,8 +722,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
CDEElementwiseOperation,
has_main_k_block_loop_>;
const index_t grid_size_grp = arg.grid_size_ / arg.group_count_;
const void* kernel_args_dev = nullptr;
if(arg.grouped_gemm_kernel_args_dev != nullptr)
......@@ -817,7 +764,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
0,
cast_pointer_to_constant_address_space(kernel_args_dev),
arg.gemm_desc_kernel_arg_.size(),
grid_size_grp,
arg.grid_size_grp,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
......
......@@ -200,6 +200,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
c_block_size * sizeof(CShuffleDataType));
}
#if 0
// A desc for source in blockwise copy
template <typename AGridDesc_M_K>
__host__ __device__ static constexpr auto
......@@ -233,6 +234,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
#endif
__host__ __device__ static auto CalculateMPadded(index_t M)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment