Commit cf95b944 authored by Chao Liu's avatar Chao Liu
Browse files

add ds

parent 0267bf47
...@@ -643,6 +643,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -643,6 +643,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_M_K, AGridDesc_M_K,
BGridDesc_N_K, BGridDesc_N_K,
StaticallyIndexedArray<EGridDesc_M_N, NumDTensor>,
EGridDesc_M_N, EGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
...@@ -761,12 +762,29 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -761,12 +762,29 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
block_2_etile_map_ = Block2ETileMap{e_grid_desc_m_n_}; block_2_etile_map_ = Block2ETileMap{e_grid_desc_m_n_};
if(GridwiseGemm::CheckValidity( if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
a_grid_desc_m_k_, b_grid_desc_n_k_, e_grid_desc_m_n_, block_2_etile_map_)) b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_))
{ {
e_grid_desc_mblock_mperblock_nblock_nperblock_ = e_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n_); e_grid_desc_m_n_);
// populate pointer and desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
ds_grid_desc_m_n_[i] = DeviceOp::MakeEGridDescriptor_M_N<ELayout>(
ds_n_k_wos_lengths[i], ds_n_k_wos_strides[i]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n_[i]);
});
} }
} }
...@@ -780,13 +798,11 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -780,13 +798,11 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// tensor descriptors // tensor descriptors
AGridDesc_M_K a_grid_desc_m_k_; AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_; BGridDesc_N_K b_grid_desc_n_k_;
// FIXME: don't assume D and E desc are the same type
StaticallyIndexedArray<EGridDesc_M_N, NumDTensor> ds_grid_desc_m_n_; StaticallyIndexedArray<EGridDesc_M_N, NumDTensor> ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
// FIXME: don't assume D and E desc are the same type
StaticallyIndexedArray< StaticallyIndexedArray<
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
NumDTensor> NumDTensor>
...@@ -833,6 +849,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -833,6 +849,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
#endif #endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_, arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_, arg.e_grid_desc_m_n_,
arg.block_2_etile_map_)) arg.block_2_etile_map_))
{ {
...@@ -1016,6 +1033,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS ...@@ -1016,6 +1033,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// check Gridwise GEMM // check Gridwise GEMM
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_, arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_, arg.e_grid_desc_m_n_,
arg.block_2_etile_map_); arg.block_2_etile_map_);
} }
......
...@@ -37,6 +37,7 @@ template <typename ABDataType, // FIXME: don't assume A/B have same datatype ...@@ -37,6 +37,7 @@ template <typename ABDataType, // FIXME: don't assume A/B have same datatype
InMemoryDataOperationEnum EGlobalMemoryDataOperation, InMemoryDataOperationEnum EGlobalMemoryDataOperation,
typename AGridDesc_M_K, typename AGridDesc_M_K,
typename BGridDesc_N_K, typename BGridDesc_N_K,
typename DsGridDesc_M_N,
typename EGridDesc_M_N, typename EGridDesc_M_N,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
...@@ -225,6 +226,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -225,6 +226,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
template <typename Block2ETileMap> template <typename Block2ETileMap>
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k, __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const BGridDesc_N_K& b_grid_desc_n_k, const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n, const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map) const Block2ETileMap& block_2_etile_map)
{ {
...@@ -236,11 +238,29 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -236,11 +238,29 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const auto N = b_grid_desc_n_k.GetLength(I0); const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1); const auto K = a_grid_desc_m_k.GetLength(I1);
// check consistency of desc
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1))) if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
{
return false;
}
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
N == ds_grid_desc_m_n[i].GetLength(I1));
});
if(!valid)
{
return false; return false;
}
// check tile size
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
return false; return false;
}
// check gridwise gemm pipeline // check gridwise gemm pipeline
const auto num_k_loop = K / KPerBlock; const auto num_k_loop = K / KPerBlock;
...@@ -250,6 +270,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -250,6 +270,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
return false; return false;
} }
// check block-to-E-tile
if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
{ {
return false; return false;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment