Unverified Commit 1344a0f2 authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

Simplify kernel argument of device operator DeviceGemm_Xdl_CShuffle<> (#696)



* Remove M/N/KPad local variables

* Use M/N/KPad to name padded lengths

* Replace duplicated local variable by parameters

* Rename variables M/N/KRaw to M/N/K

* Move AK0/BK0 compute logic into GridwiseGemm

* Use macro to shorten code

* Move CalculateGridSize() logic into GridwiseGemm

* Add comment to credit the implementation source

* Reuse the existing implementation

* Remove no-longer used data members

* Remove elementwise-op objects from interfaces

* Reserve kernel arg as whole object in interfaces

* Remove redundant data member

* Make 3rd type parameter optional

* Remove unnesscary type parameters

* Remove no-longer used descriptor-creation methods

* Move kernel arg type definition into GridwiseGemm

* Add macro to switch between code sections

* Move argument field computing logic into device op side

* Make utility method 'static'

* Declare special methods

* Unify MakeArgument() usage

* Adapt the new GridwiseGemm interface

* Push-down class 'GridwiseGemm::Argument' fields

* Remove no-longer used methods

* Add unused parameters

* Force copying parameters in 'Embed' ctor

* Remove no-longer used descriptors

* Fallback change on BaseArgument

* Remove macro 'INTEGER_DIVIDE_CEIL'

* Make variable naming more consistent

* Make sure methods are only invoked on right place

* Remove tailing underscore in public attribute name

* Remove necessary methods

* Hide computing logic of derived attributes

* Make new 'Embed' ctor only available for device code

* Make sure 'Embed' type args are not references

* Move check for karg.K into CheckValidity()

* Remove more integer division logic form device code

* Undo changes on Embed

* Separate 'Problem' concept out from 'Argument'

* Share same name for kernel interfaces

* Reject unsupported argument

---------
Co-authored-by: default avatarzjing14 <zhangjing14@gmail.com>
parent 70e4eb56
...@@ -109,30 +109,37 @@ struct BlockToCTileMap_M00_N0_M01 ...@@ -109,30 +109,37 @@ struct BlockToCTileMap_M00_N0_M01
// Rows of column-vectors // Rows of column-vectors
// This C-tile map dynamically adjusts M01 when C-tile index is out of range // This C-tile map dynamically adjusts M01 when C-tile index is out of range
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N> template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
struct BlockToCTileMap_M00_N0_M01Adapt struct BlockToCTileMap_M00_N0_M01Adapt;
template <index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default; __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt&) =
index_t M01 = 8) default;
: M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n) __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt&&) =
default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
: M_(M), N_(N), M01_(M01)
{ {
} }
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{ {
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); const auto N0 = math::integer_divide_ceil(N, NPerBlock);
const index_t grid_size = M0 * N0;
return grid_size; return M0 * N0;
} }
template <typename TopIdx> template <typename TopIdx>
...@@ -140,8 +147,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt ...@@ -140,8 +147,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt
{ {
auto block_1d_id = idx_top[I0]; auto block_1d_id = idx_top[I0];
const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock); const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock); const auto N0 = math::integer_divide_ceil(N_, NPerBlock);
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
...@@ -209,11 +216,36 @@ struct BlockToCTileMap_M00_N0_M01Adapt ...@@ -209,11 +216,36 @@ struct BlockToCTileMap_M00_N0_M01Adapt
return true; // always valid provided that user gets grid size from CalculateGridSize() return true; // always valid provided that user gets grid size from CalculateGridSize()
} }
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
private: private:
index_t M_;
index_t N_;
index_t M01_; index_t M01_;
CGridDesc_M_N c_grid_desc_m_n_; };
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
{
using Parent = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>;
using Parent::I0;
using Parent::I1;
using Parent::Parent;
using Parent::operator=;
__host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
index_t M01 = 8)
: Parent(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
{
}
__host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
return Parent::CalculateGridSize(c_grid_desc_m_n.GetLength(I0),
c_grid_desc_m_n.GetLength(I1));
}
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
}; };
// 2D slices of column-vectors in 3D space // 2D slices of column-vectors in 3D space
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment