Commit 41a1466a authored by Jing Zhang

change m_loops to tile_loops

parent 36a527df
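What changed, in short: each workgroup used to run a do/while over m_loops = ceil(M / MPerBlock), with the per-group grid sized as if M0 == 1; after this change the per-group grid covers the full M0 * N0 * KBatch tile space and each workgroup strides through flat tile ids in steps of grid_size_grp until it falls off local_grid_size. Below is a minimal standalone sketch of the new scheduling arithmetic; ceil_div, the sample sizes, and main() are illustrative stand-ins rather than the CK API, and only the loop shape and the grid-size formulas mirror the diff.

// Standalone sketch of the tile-loop scheduling after this commit.
#include <cstdio>

static int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int MPerBlock = 128, NPerBlock = 128, KBatch = 1;
    const int N = 256;

    // Host side: the grid for one group is sized from the average M
    // (AverM = sum_of_m / group_count) known when the argument is built.
    const int AverM         = 256;
    const int grid_size_grp = ceil_div(AverM, MPerBlock) * ceil_div(N, NPerBlock) * KBatch;

    // Device side: local_grid_size is recomputed from the actual M of the
    // group, which in the fixed-NK case may exceed AverM.
    const int M               = 640;
    const int local_grid_size = ceil_div(M, MPerBlock) * ceil_div(N, NPerBlock) * KBatch;

    // New loop: a block walks flat tile ids in strides of grid_size_grp
    // (the old code looped m_id from 0 to m_loops over the M dimension only).
    const int block_in_group = 1; // stands in for get_block_1d_id() - BlockStart
    for(int id_off = 0; block_in_group + id_off < local_grid_size; id_off += grid_size_grp)
    {
        printf("block %d processes flat tile %d\n", block_in_group, block_in_group + id_off);
    }
    return 0;
}

With these numbers grid_size_grp is 4 while local_grid_size is 10, so block 1 covers flat tiles 1, 5 and 9.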
@@ -55,7 +55,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
 //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
+//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
+//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
+        < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<2, 0, 1, 3>, S<2, 0, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
 // clang-format on

 struct ProblemSize final
...
@@ -82,6 +82,8 @@ __global__ void
     const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch};

+    const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n);
+
     constexpr auto NumDTensor = DsDataType::Size();

     using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
@@ -94,13 +96,12 @@ __global__ void
        p_ds_grid_(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
    });

-   auto m_loops = local_b2e_tile_map.CalculateMLoops();
-   index_t m_id = 0;
-   do
+   index_t id_off = 0;
+   while((get_block_1d_id() - BlockStart + id_off) < local_grid_size)
    {
        const auto block_2_etile_map =
-           GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, m_id);
+           GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);

        GridwiseGemm::
            template Run<HasMainKBlockLoop, GemmSpec, ALayout, BLayout, DsLayout, ELayout>(
@@ -122,9 +123,8 @@ __global__ void
                KBatch,
                block_2_etile_map);

-       m_id += 1;
-   } while(m_id < m_loops);
+       id_off += grid_size_grp;
+   }
#else
    ignore = grid_size_grp;
    ignore = gemm_descs_const;
@@ -201,82 +201,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};

-    static constexpr auto matrix_padder =
-        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
-
-    static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
-    {
-        const auto a_grid_desc_mraw_kraw = [&]() {
-            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
-            {
-                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
-                                                    make_tuple(StrideA, I1));
-            }
-            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
-            {
-                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
-                                                    make_tuple(I1, StrideA));
-            }
-        }();
-
-        return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
-    }
-
-    static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
-    {
-        const auto b_grid_desc_nraw_kraw = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
-                                                    make_tuple(I1, StrideB));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
-                                                    make_tuple(StrideB, I1));
-            }
-        }();
-
-        return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
-    }
-
-    template <typename ELay>
-    static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
-    {
-        const auto e_grid_desc_mraw_nraw = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
-                                                    make_tuple(StrideE, I1));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
-                                                    make_tuple(I1, StrideE));
-            }
-        }();
-
-        return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
-    }
-
-    static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
-                                         const std::array<index_t, NumDTensor>& NRaws,
-                                         const std::array<index_t, NumDTensor>& DsStride)
-    {
-        return generate_tuple(
-            [&](auto i) {
-                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
-
-                return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
-            },
-            Number<NumDTensor>{});
-    }
-
-    using AGridDesc_M_K  = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
-    using BGridDesc_N_K  = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
-    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
-    using EGridDesc_M_N  = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
-
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
@@ -321,40 +245,26 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         CDEBlockTransferScalarPerVector_NPerBlock,
         LoopSched>;

-#if 0
-    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
-    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
-    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
-    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
-#endif
-
     template <typename UnderlyingBlockToCTileMap>
     struct OffsettedBlockToCTileMapMLoops
     {
         using underlying_type = UnderlyingBlockToCTileMap;

-        __host__ __device__
-        OffsettedBlockToCTileMapMLoops(UnderlyingBlockToCTileMap block_to_ctile_map,
-                                       index_t block_start,
-                                       index_t mblock_id_off = 0)
+        __host__ __device__ OffsettedBlockToCTileMapMLoops(
+            UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0)
         {
             block_to_ctile_map_ = block_to_ctile_map;
             block_start_ = block_start;
-            mblock_id_off_ = mblock_id_off;
+            id_off_ = id_off;
         }

         template <typename TopIdx>
         __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
         {
             auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
-                make_multi_index(idx_top[Number<0>{}] - block_start_));
-            return make_tuple(
-                idx_bot[Number<0>{}], idx_bot[Number<1>{}] + mblock_id_off_, idx_bot[Number<2>{}]);
+                make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_));
+            return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]);
         }

         template <typename CTileIdx, typename CTileDim>
@@ -378,7 +288,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         UnderlyingBlockToCTileMap block_to_ctile_map_;
         index_t block_start_;
-        index_t mblock_id_off_;
+        index_t id_off_;
     };

     template <index_t MPerBlock_, index_t NPerBlock_>
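Note on the two hunks above: the offset moved from the output side of the tile map to its input side. mblock_id_off_ used to be added to the M coordinate after the underlying map had run, so repeated iterations could only walk down M; id_off_ is now added to the flat block id before the map runs, so a stride can carry across the N and K-batch coordinates too. A toy contrast follows; ToyMap is a hypothetical row-major stand-in for the real Block2ETileMap, and the two offsets are deliberately in different units (whole M-tiles before, flat tile ids after), mirroring the diff.

// Toy contrast of offset-after-mapping vs offset-before-mapping.
#include <cstdio>

struct ToyMap // flat id -> (k-batch, m-tile, n-tile), row-major stand-in
{
    int M0, N0, K0;
    void bottom(int id, int out[3]) const
    {
        out[0] = id / (M0 * N0); // k-batch slice
        out[1] = (id / N0) % M0; // m-tile
        out[2] = id % N0;        // n-tile
    }
};

int main()
{
    const ToyMap map{/*M0=*/2, /*N0=*/2, /*K0=*/2};
    const int block_id = 3, block_start = 0;

    // Old behaviour: map first, then bump only the m coordinate.
    int mblock_id_off = 1; // measured in whole M-tiles
    int a[3];
    map.bottom(block_id - block_start, a);
    printf("old: k=%d m=%d n=%d\n", a[0], a[1] + mblock_id_off, a[2]);

    // New behaviour: bump the flat id first, then map, so the shift can
    // carry into the n and k-batch coordinates as well.
    int id_off = 4; // measured in flat tile ids (a grid_size_grp stride)
    int b[3];
    map.bottom(block_id - block_start + id_off, b);
    printf("new: k=%d m=%d n=%d\n", b[0], b[1], b[2]);
    return 0;
}

Running this prints old: k=0 m=2 n=1 (the m bump can run past M0) versus new: k=1 m=1 n=1 (the flat shift wrapped into the next k-batch slice).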
@@ -414,21 +324,17 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
        {
        }

-       __host__ __device__ constexpr index_t CalculateMLoops() const
-       {
-           return math::integer_divide_ceil(M_, MPerBlock_);
-       }
-
-       __host__ constexpr index_t CalculateGridSize(index_t /*M*/, index_t N) const
+       __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
        {
-           const auto M0 = 1; // math::integer_divide_ceil(M, MPerBlock);
+           const auto M0 = math::integer_divide_ceil(M, MPerBlock);
            const auto N0 = math::integer_divide_ceil(N, NPerBlock);

            return M0 * N0 * KBatch_;
        }

        template <typename CGridDesc_M_N>
-       __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
+       __host__ __device__ constexpr index_t
+       CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
        {
            return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
        }
@@ -444,7 +350,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
        {
            auto block_1d_id = idx_top[I0];

-           const auto M0 = 1; // math::integer_divide_ceil(M_, MPerBlock_);
+           const auto M0 = math::integer_divide_ceil(M_, MPerBlock_);
            const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);

            block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups
@@ -495,24 +401,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
        index_t StrideA_, StrideB_;
        std::array<index_t, NumDTensor> StrideDs_;
        index_t StrideE_;
-
-#if 0
-       // tensor descriptors for problem definiton
-       AGridDesc_M_K a_grid_desc_m_k_;
-       BGridDesc_N_K b_grid_desc_n_k_;
-       DsGridDesc_M_N ds_grid_desc_m_n_;
-       EGridDesc_M_N e_grid_desc_m_n_;
-
-       // tensor descriptors for block/thread-wise copy
-       AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
-       BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
-       DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-           ds_grid_desc_mblock_mperblock_nblock_nperblock_;
-       EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
-
-       // block-to-e-tile map
-       Block2ETileMap block_2_etile_map_;
-#endif
    };

    // Argument
@@ -561,13 +449,19 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
            index_t group_id = 0;

+           sum_of_m = gemm_descs[0].M_;
+           const index_t AverM = sum_of_m / group_count_;
+           const index_t N = gemm_descs[0].N_;
+           const index_t K = gemm_descs[0].K_;
+
            for(std::size_t i = 0; i < gemm_descs.size(); i++)
            {
-               const index_t M = gemm_descs[i].M_;
-               const index_t N = gemm_descs[i].N_;
-               const index_t K = gemm_descs[i].K_;
+               if(sum_of_m != gemm_descs[i].M_ || N != gemm_descs[i].N_ || K != gemm_descs[i].K_)
+               {
+                   throw std::runtime_error("wrong! M/N/K is not identical");
+               }

-               a_mtx_mraw_kraw_.emplace_back(M, K);
+               a_mtx_mraw_kraw_.emplace_back(sum_of_m, K);
                b_mtx_nraw_kraw_.emplace_back(N, K);

                const index_t StrideA = gemm_descs[i].stride_A_;
@@ -584,12 +478,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                        static_cast<const DDataType*>(p_Ds.size() == 0 ? nullptr : p_Ds[i][j]);
                });

-               // tensor descriptors for problem definiton
-               // const auto a_grid_desc_m_k = DeviceOp::MakeAGridDescriptor_M_K(M, K, StrideA);
-               // const auto b_grid_desc_n_k = DeviceOp::MakeBGridDescriptor_N_K(K, N, StrideB);
-
-               // DsGridDesc_M_N ds_grid_desc_m_n;
-
                std::array<index_t, NumDTensor> StrideDs;

                static_for<0, NumDTensor, 1>{}([&](auto j) {
@@ -602,27 +490,20 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                    }

                    StrideDs[j] = gemm_descs[i].stride_Ds_[j];
-
-                   // ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
-                   //     M, N, gemm_descs[i].stride_Ds_[j]);
                });

-#if 0
-               // tensor descriptors for block/thread-wise copy
-               const auto a_grid_desc_ak0_m_ak1 =
-                   GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
-
-               const auto b_grid_desc_bk0_n_bk1 =
-                   GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
-#endif
-
                const auto e_grid_desc_m_n =
-                   DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
+                   GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
+                       AverM, N, StrideE);

                // block-to-e-tile map
                const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch};

                grid_size_grp = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);

-               // std::cout << "group_id: " << group_id << " grid_size_grp: " << grid_size_grp
-               //<< std::endl;
-
                if(group_id * grid_size_grp != grid_size_)
                {
                    throw std::runtime_error("wrong! grid_size_grp is not identical!");
@@ -638,7 +519,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                if(!GridwiseGemm::
                       template CheckValidity<ALayout, BLayout, DsLayout, ELayout, GemmSpec>(
-                          M, N, K, StrideA, StrideB, StrideDs, StrideE, 1))
+                          AverM, N, K, StrideA, StrideB, StrideDs, StrideE, 1))
                {
                    throw std::runtime_error(
                        "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
@@ -649,7 +530,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                    p_Bs.size() == 0 ? nullptr : p_Bs[i],
                    p_ds_grid,
                    p_Es[i],
-                   M,
+                   AverM,
                    N,
                    K,
                    StrideA,
@@ -677,6 +558,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
        index_t grid_size_;

        index_t grid_size_grp;
+       index_t sum_of_m;
    };

    // Invoker
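Note on the argument changes above: as read from this diff, every gemm_desc in the fixed-NK case now carries the same M_ (stored as sum_of_m), N_ and K_, and all descriptors and validity checks are built from the even split AverM = sum_of_m / group_count_. A tiny sketch of that host-side contract (the Desc struct and the sample sizes are assumptions for illustration):

// Sketch of the new host-side contract inferred from this hunk.
#include <stdexcept>
#include <vector>

struct Desc { int M, N, K; }; // stand-in for the op's gemm descriptor

int main()
{
    // Every descriptor repeats the same M/N/K; M is the total across groups.
    std::vector<Desc> descs(4, Desc{1024, 256, 64});
    const int group_count = static_cast<int>(descs.size());

    const int sum_of_m = descs[0].M;             // total M, repeated in every desc
    const int AverM    = sum_of_m / group_count; // per-group M used for descriptors
    for(const auto& d : descs)
        if(d.M != sum_of_m || d.N != descs[0].N || d.K != descs[0].K)
            throw std::runtime_error("wrong! M/N/K is not identical");

    (void)AverM; // 256 here: each of the 4 groups is sized for M = 256
    return 0;
}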
@@ -735,38 +617,9 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                CDEElementwiseOperation,
                has_main_k_block_loop_>;

-           const void* kernel_args_dev = nullptr;
-
-           if(arg.grouped_gemm_kernel_args_dev != nullptr)
-           {
-               kernel_args_dev = arg.grouped_gemm_kernel_args_dev;
-           }
-           else
+           if(arg.grouped_gemm_kernel_args_dev == nullptr)
            {
-               for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
-               {
-                   if(arg.gemm_desc_kernel_arg_[i].a_ptr_ == nullptr ||
-                      arg.gemm_desc_kernel_arg_[i].b_ptr_ == nullptr ||
-                      arg.gemm_desc_kernel_arg_[i].e_ptr_ == nullptr)
-                   {
-                       throw std::runtime_error("wrong! p_a/b/c_grid is nullptr");
-                   }
-               }
-
-               if(arg.p_workspace_ == nullptr)
-               {
-                   throw std::runtime_error("wrong! arg.p_workspace_ == nullptr");
-               }
-
-               hipGetErrorString(
-                   hipMemcpyWithStream(arg.p_workspace_,
-                                       grouped_gemm_kernel_args.data(),
-                                       grouped_gemm_kernel_args.size() *
-                                           sizeof(GroupedGemmKernelArgument<NumDTensor>),
-                                       hipMemcpyHostToDevice,
-                                       stream_config.stream_id_));
-
-               kernel_args_dev = arg.p_workspace_;
+               throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullptr");
            }

            return launch_and_time_kernel(
@@ -775,7 +628,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                dim3(arg.grid_size_),
                dim3(BlockSize),
                0,
-               cast_pointer_to_constant_address_space(kernel_args_dev),
+               cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
                arg.gemm_desc_kernel_arg_.size(),
                arg.grid_size_grp,
                k_batch,
...
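Note on the invoker hunks: the lazy host-to-device copy of the kernel-argument block (via p_workspace_ and hipMemcpyWithStream) is gone; Run now throws unless grouped_gemm_kernel_args_dev was staged beforehand. A rough sketch of what the caller now owns; ArgStub and the commented-out assignment are placeholders, the real argument type is GroupedGemmKernelArgument<NumDTensor> and the pointer is set through whatever setter the op exposes:

// Sketch of the caller-side staging that replaces the removed fallback.
#include <hip/hip_runtime.h>
#include <cstddef>
#include <vector>

struct ArgStub // placeholder for GroupedGemmKernelArgument<NumDTensor>
{
    const void* a_ptr;
    const void* b_ptr;
    void* e_ptr;
    int M;
};

int main()
{
    std::vector<ArgStub> host_args(8); // one entry per GEMM in the group

    // The invoker used to copy this block lazily; now the caller allocates
    // and fills device memory before Run() is invoked.
    void* args_dev = nullptr;
    (void)hipMalloc(&args_dev, host_args.size() * sizeof(ArgStub));
    (void)hipMemcpy(args_dev,
                    host_args.data(),
                    host_args.size() * sizeof(ArgStub),
                    hipMemcpyHostToDevice);

    // arg.grouped_gemm_kernel_args_dev = args_dev; // via the op's setter;
    // Run() now throws if this pointer is still nullptr.

    (void)hipFree(args_dev);
    return 0;
}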