Commit 8e20f747 authored by Paul

Format

parent 6e4a1075
@@ -498,94 +498,99 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
         }
     };
 
-    static bool IsSupportedArgument(const Argument& arg)
-    {
-        if(!ck::is_xdl_supported())
-        {
-            return false;
-        }
-
-        // check vector load/store
-        {
-            using Row = ck::tensor_layout::gemm::RowMajor;
-            using Col = ck::tensor_layout::gemm::ColumnMajor;
-
-            // check vector load of A
-            if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
-            {
-                if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
-                {
-                    return false;
-                }
-            }
-            else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
-            {
-                // FIXME: not rigorous
-                if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
-                {
-                    return false;
-                }
-            }
-            else
-            {
-                return false;
-            }
-
-            // check vector load of B
-            if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
-            {
-                if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
-                {
-                    return false;
-                }
-            }
-            else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
-            {
-                // FIXME: not rigorous
-                if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
-                {
-                    return false;
-                }
-            }
-            else
-            {
-                return false;
-            }
-
-            // check vector load of Ds
-            // only support RowMajor for now
-            bool all_valid = true;
-
-            static_for<0, NumDTensor, 1>{}([&](auto i) {
-                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
-
-                if constexpr(!is_same_v<DLayout, Row>)
-                {
-                    all_valid = false;
-                }
-            });
-
-            if(!all_valid)
-            {
-                return false;
-            }
-
-            // check vector store of E
-            // only support RowMajor for now
-            if constexpr(is_same_v<ELayout, Row>)
-            {
-                if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
-                {
-                    return false;
-                }
-            }
-            else
-            {
-                return false;
-            }
-        }
-
-        return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
+    static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
+    {
+        // check vector load/store
+        using Row = ck::tensor_layout::gemm::RowMajor;
+        using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+        // check vector load of A
+        if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
+        {
+            if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
+        {
+            // FIXME: not rigorous
+            if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            return false;
+        }
+
+        // check vector load of B
+        if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
+        {
+            if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
+        {
+            // FIXME: not rigorous
+            if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            return false;
+        }
+
+        // check vector load of Ds
+        // only support RowMajor for now
+        bool all_valid = true;
+
+        static_for<0, NumDTensor, 1>{}([&](auto i) {
+            using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
+
+            if constexpr(!is_same_v<DLayout, Row>)
+            {
+                all_valid = false;
+            }
+        });
+
+        if(!all_valid)
+        {
+            return false;
+        }
+
+        // check vector store of E
+        // only support RowMajor for now
+        if constexpr(is_same_v<ELayout, Row>)
+        {
+            if(NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
+            {
+                return false;
+            }
+        }
+        else
+        {
+            return false;
+        }
+
+        return true;
+    }
+
+    static bool IsSupportedArgument(const Argument& arg)
+    {
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+             ck::get_device_name() == "gfx942"))
+        {
+            return false;
+        }
+
+        return IsSupported(arg.MRaw_, arg.NRaw_, arg.KRaw_) and
+               GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                            arg.b_grid_desc_n_k_,
                                            arg.ds_grid_desc_m_n_,
                                            arg.e_grid_desc_m_n_,
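Note: the hunk above splits the old runtime-only `IsSupportedArgument(const Argument&)` into a `constexpr` core, `IsSupported(MRaw_, NRaw_, KRaw_)`, plus a thin runtime wrapper that adds the device-name check. A minimal sketch of what the `constexpr` form enables; the fully instantiated `GemmInstance` alias and the concrete sizes below are illustrative assumptions, not part of this commit:

```cpp
// Compile-time screening of a problem shape against the vector load/store
// constraints, possible now that IsSupported is constexpr and no longer reads
// sizes out of a runtime Argument. GemmInstance stands for some concrete
// DeviceGemmMultipleD_Xdl_CShuffle<...> instantiation (hypothetical here).
template <class GemmInstance>
inline constexpr bool supports_1024x1024x512 =
    GemmInstance::IsSupported(/*MRaw_=*/1024, /*NRaw_=*/1024, /*KRaw_=*/512);

// e.g. static_assert(supports_1024x1024x512<GemmInstance>,
//                    "shape fails the vector load/store checks");
```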
@@ -708,6 +713,171 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
         return str.str();
     }
 
+    template <class ADesc, class BDesc, class DsDesc, class EDesc>
+    struct Descriptor
+    {
+        static constexpr auto ds_tuple()
+        {
+            return transform_tuples(
+                [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
+                DsDesc{});
+        }
+
+        using AGridDesc_AK0_M_AK1 =
+            remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
+                DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{})))>;
+        using BGridDesc_BK0_N_BK1 =
+            remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
+                DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{})))>;
+        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
+            decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                ds_tuple()))>;
+        using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
+            decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
+        using Block2ETileMap = remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(
+            DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
+
+        // tensor descriptors for problem definition
+        AGridDesc_M_K a_grid_desc_m_k;
+        BGridDesc_N_K b_grid_desc_n_k;
+        DsGridDesc_M_N ds_grid_desc_m_n;
+        EGridDesc_M_N e_grid_desc_m_n;
+
+        // tensor descriptors for block/thread-wise copy
+        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
+        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
+        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
+        EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;
+
+        // block-to-e-tile map
+        Block2ETileMap block_2_etile_map;
+
+        // element-wise op
+        AElementwiseOperation a_element_op;
+        BElementwiseOperation b_element_op;
+        CDEElementwiseOperation cde_element_op;
+
+        // for checking vector load/store
+        index_t MRaw;
+        index_t NRaw;
+        index_t KRaw;
+
+        bool has_main_k_block_loop = true;
+
+        constexpr Descriptor(ADesc a,
+                             BDesc b,
+                             DsDesc ds,
+                             EDesc e,
+                             AElementwiseOperation a_element_op_,
+                             BElementwiseOperation b_element_op_,
+                             CDEElementwiseOperation cde_element_op_)
+            : a_grid_desc_m_k{DeviceOp::matrix_padder.PadADescriptor_M_K(a)},
+              b_grid_desc_n_k{DeviceOp::matrix_padder.PadBDescriptor_N_K(b)},
+              ds_grid_desc_m_n{transform_tuples(
+                  [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
+                  ds)},
+              e_grid_desc_m_n{DeviceOp::matrix_padder.PadCDescriptor_M_N(e)},
+              a_grid_desc_ak0_m_ak1{
+                  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k)},
+              b_grid_desc_bk0_n_bk1{
+                  GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k)},
+              ds_grid_desc_mblock_mperblock_nblock_nperblock{
+                  GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                      transform_tuples(
+                          [&](auto d) constexpr {
+                              return DeviceOp::matrix_padder.PadCDescriptor_M_N(d);
+                          },
+                          ds))},
+              e_grid_desc_mblock_mperblock_nblock_nperblock{
+                  GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                      e_grid_desc_m_n)},
+              block_2_etile_map{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n)},
+              has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
+                  a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
+              a_element_op{a_element_op_},
+              b_element_op{b_element_op_},
+              cde_element_op{cde_element_op_},
+              MRaw{e.GetLength(I0)},
+              NRaw{e.GetLength(I1)},
+              KRaw{a.GetLength(I1)}
+        {
+        }
+
+        constexpr bool IsValid() const
+        {
+            return GridwiseGemm::CheckValidity(a_grid_desc_m_k,
+                                               b_grid_desc_n_k,
+                                               ds_grid_desc_m_n,
+                                               e_grid_desc_m_n,
+                                               block_2_etile_map) and
+                   IsSupported(MRaw, NRaw, KRaw);
+        }
+
+        constexpr index_t GetBlockSize() const { return BlockSize; }
+
+        constexpr index_t GetGridSize() const
+        {
+            return block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);
+        }
+    };
+
+    template <class ADesc, class BDesc, class DsDesc, class EDesc>
+    static constexpr auto
+    make_descriptor(ADesc a,
+                    BDesc b,
+                    DsDesc ds,
+                    EDesc e,
+                    AElementwiseOperation a_element_op     = AElementwiseOperation{},
+                    BElementwiseOperation b_element_op     = BElementwiseOperation{},
+                    CDEElementwiseOperation cde_element_op = CDEElementwiseOperation{})
+    {
+        return Descriptor<ADesc, BDesc, DsDesc, EDesc>(
+            a, b, ds, e, a_element_op, b_element_op, cde_element_op);
+    }
+
+    template <class Desc, class DsPointer>
+    __device__ static void Run(const Desc& desc,
+                               const ADataType* __restrict__ p_a_grid,
+                               const BDataType* __restrict__ p_b_grid,
+                               DsPointer p_ds_grid,
+                               EDataType* __restrict__ p_e_grid)
+    {
+        __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+
+        assert(desc.IsValid());
+
+        if(desc.has_main_k_block_loop)
+        {
+            GridwiseGemm::template Run<true>(p_a_grid,
+                                             p_b_grid,
+                                             p_ds_grid,
+                                             p_e_grid,
+                                             p_shared_block,
+                                             desc.a_element_op,
+                                             desc.b_element_op,
+                                             desc.cde_element_op,
+                                             desc.a_grid_desc_ak0_m_ak1,
+                                             desc.b_grid_desc_bk0_n_bk1,
+                                             desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                                             desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
+                                             desc.block_2_etile_map);
+        }
+        else
+        {
+            GridwiseGemm::template Run<false>(p_a_grid,
+                                              p_b_grid,
+                                              p_ds_grid,
+                                              p_e_grid,
+                                              p_shared_block,
+                                              desc.a_element_op,
+                                              desc.b_element_op,
+                                              desc.cde_element_op,
+                                              desc.a_grid_desc_ak0_m_ak1,
+                                              desc.b_grid_desc_bk0_n_bk1,
+                                              desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                                              desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
+                                              desc.block_2_etile_map);
+        }
+    }
 };
 
 } // namespace device
...
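Note: the hunk above adds a compile-time `Descriptor` (built by `make_descriptor`) and a static, descriptor-driven `Run` entry point. A hedged sketch of how the pieces fit together; the kernel wrapper and the host-side descriptor objects (`a_desc`, `b_desc`, `ds_descs`, `e_desc`) are illustrative assumptions, and only `make_descriptor`, `IsValid`, `GetBlockSize`, `GetGridSize`, and `Run` come from this commit:

```cpp
// Hypothetical thin kernel over the new static DeviceOp::Run; the descriptor
// is passed by value, so all grid descriptors travel as kernel arguments.
template <class DeviceOp, class Desc, class ADataType, class BDataType,
          class DsPointer, class EDataType>
__global__ void kernel_gemm_multiple_d(const Desc desc,
                                       const ADataType* __restrict__ p_a,
                                       const BDataType* __restrict__ p_b,
                                       DsPointer p_ds,
                                       EDataType* __restrict__ p_e)
{
    DeviceOp::Run(desc, p_a, p_b, p_ds, p_e);
}

// Host side (sketch): build and validate the descriptor at compile time,
// then launch with the geometry it reports.
//
//   constexpr auto desc = DeviceOp::make_descriptor(a_desc, b_desc, ds_descs, e_desc);
//   static_assert(desc.IsValid(), "tile configuration rejected for these descriptors");
//   kernel_gemm_multiple_d<DeviceOp><<<desc.GetGridSize(), desc.GetBlockSize()>>>(
//       desc, p_a, p_b, p_ds, p_e);
```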
@@ -24,10 +24,10 @@ struct BlockToCTileMap_M00_N0_M01
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};
 
-    __host__ __device__ BlockToCTileMap_M00_N0_M01() = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01() = default;
 
-    __host__ __device__ BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n,
-                                                   index_t M01 = 1)
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n,
+                                                             index_t M01 = 1)
         : M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01))
     {
     }
@@ -51,8 +51,8 @@ struct BlockToCTileMap_M00_N0_M01
     }
 
     template <typename CTileIdx, typename CTileDim>
-    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
-                                             const CTileDim& c_tile_dim) const
+    __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& c_tile_idx,
+                                                       const CTileDim& c_tile_dim) const
     {
         if constexpr(DeviceCTileIndexCheck)
             return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
@@ -60,7 +60,7 @@ struct BlockToCTileMap_M00_N0_M01
             return true;
     }
 
-    __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
+    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
     {
         if constexpr(DeviceCTileIndexCheck)
             return true; // validity check moved to kernel
@@ -120,25 +120,27 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
 
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;
 
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt&) =
-        default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
+        const BlockToCTileMap_M00_N0_M01Adapt&) = default;
 
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt&&) =
-        default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
+        BlockToCTileMap_M00_N0_M01Adapt&&) = default;
 
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
     operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
 
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
     operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;
 
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
+    __host__
+    __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
         : M_(M), N_(N), M01_(M01)
     {
     }
 
     template <typename CGridDesc_M_N>
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
-                                                        index_t M01 = 8)
+    __host__
+    __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
+                                                         index_t M01 = 8)
         : BlockToCTileMap_M00_N0_M01Adapt(
               c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
     {
@@ -232,8 +234,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
     }
 
     template <typename CTileIdx, typename CTileDim>
-    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
-                                             const CTileDim& /* c_tile_dim */) const
+    __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
+                                                       const CTileDim& /* c_tile_dim */) const
     {
         return true; // always valid provided that user gets grid size from CalculateGridSize()
    }
...
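Note: the `constexpr` qualifiers added throughout the block-to-C-tile maps are what allow them to be constructed in constant-evaluated contexts such as the `Descriptor` above. A small sketch, under the assumption that `CalculateGridSize` on this specialization is likewise `constexpr`-evaluable; the 128x128 tile and problem sizes are made up for illustration:

```cpp
// With constexpr constructors, a tile map for a statically known problem can
// be built and queried entirely at compile time.
using TileMap = ck::BlockToCTileMap_M00_N0_M01Adapt</*MPerBlock=*/128,
                                                    /*NPerBlock=*/128,
                                                    void>;

constexpr TileMap tile_map(/*M=*/1024, /*N=*/2048);  // 8 x 16 tiles of E
constexpr ck::index_t grid_size = tile_map.CalculateGridSize(1024, 2048);
static_assert(grid_size == 128, "one workgroup per 128x128 tile of E");
```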