Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
241baec1
Commit
241baec1
authored
Dec 16, 2024
by
Adam Osewski
Browse files
Refactor universal gemm policy.
parent
77a38e02
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
198 additions
and
53 deletions
+198
-53
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+2
-2
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
...tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
+5
-6
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
+35
-19
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
...gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+132
-26
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
+24
-0
No files found.
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
241baec1
...
@@ -25,7 +25,7 @@ struct GemmKernel
...
@@ -25,7 +25,7 @@ struct GemmKernel
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
BDataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
BDataType
>
;
//
using CAccDataType = remove_cvref_t<typename GemmPipeline::CD
ata
T
ype
>;
//
Below type is actually accumulation d
ata
t
ype
- the output of block GEMM.
using
CDataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
__host__
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
,
index_t
KBatch
)
__host__
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
,
index_t
KBatch
)
...
@@ -238,7 +238,7 @@ struct GemmKernel
...
@@ -238,7 +238,7 @@ struct GemmKernel
const
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
kargs
.
K
);
const
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
kargs
.
K
);
// Run GEMM cooperatively by whole wo
k
rgroup.
// Run GEMM cooperatively by whole wor
k
group.
auto
c_block_tile
=
auto
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
View file @
241baec1
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag
mem_bgmem_creg_v1_default
_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_
universal_
pipeline_ag
_bg_cr
_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
...
@@ -37,7 +37,7 @@ struct BaseGemmPipelineAgBgCrCompV3
...
@@ -37,7 +37,7 @@ struct BaseGemmPipelineAgBgCrCompV3
// LocalPreFillStages: 1
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
// LocalSharedMemoryBuffer: 1
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineA
GmemBGmemCRegV1Default
Policy
>
template
<
typename
Problem
,
typename
Policy
=
Universal
GemmPipelineA
gBgCr
Policy
>
struct
GemmPipelineAgBgCrCompV3
:
public
BaseGemmPipelineAgBgCrCompV3
<
Problem
>
struct
GemmPipelineAgBgCrCompV3
:
public
BaseGemmPipelineAgBgCrCompV3
<
Problem
>
{
{
using
Base
=
BaseGemmPipelineAgBgCrCompV3
<
Problem
>
;
using
Base
=
BaseGemmPipelineAgBgCrCompV3
<
Problem
>
;
...
@@ -62,15 +62,14 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -62,15 +62,14 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
VectorSizeA
=
P
roblem
::
VectorSizeA
;
static
constexpr
index_t
VectorSizeA
=
P
olicy
::
template
GetVectorSizeA
<
Problem
>()
;
static
constexpr
index_t
VectorSizeB
=
P
roblem
::
VectorSizeB
;
static
constexpr
index_t
VectorSizeB
=
P
olicy
::
template
GetVectorSizeB
<
Problem
>()
;
static
constexpr
index_t
VectorSizeC
=
P
roblem
::
VectorSizeC
;
static
constexpr
index_t
VectorSizeC
=
P
olicy
::
template
GetVectorSizeC
<
Problem
>()
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadK
=
Problem
::
kPadK
;
static
constexpr
bool
kPadK
=
Problem
::
kPadK
;
// Where is the right place for HasHotLoop and TailNum ???
static
constexpr
bool
HasHotLoop
=
Problem
::
HasHotLoop
;
static
constexpr
bool
HasHotLoop
=
Problem
::
HasHotLoop
;
static
constexpr
auto
TailNum
=
Problem
::
TailNum
;
static
constexpr
auto
TailNum
=
Problem
::
TailNum
;
static
constexpr
auto
Scheduler
=
Problem
::
Scheduler
;
static
constexpr
auto
Scheduler
=
Problem
::
Scheduler
;
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
View file @
241baec1
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
#pragma once
#pragma once
#include "ck_tile/
ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler
.hpp"
#include "ck_tile/
core
.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -11,10 +11,10 @@ template <typename ADataType_,
...
@@ -11,10 +11,10 @@ template <typename ADataType_,
typename
BDataType_
,
typename
BDataType_
,
typename
CDataType_
,
typename
CDataType_
,
typename
BlockGemmShape_
,
typename
BlockGemmShape_
,
typename
TileGemm
Traits_
>
typename
Traits_
>
struct
GemmPipelineProblemBase
struct
GemmPipelineProblemBase
{
{
using
Gemm
Traits
=
remove_cvref_t
<
TileGemm
Traits_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
...
@@ -22,19 +22,19 @@ struct GemmPipelineProblemBase
...
@@ -22,19 +22,19 @@ struct GemmPipelineProblemBase
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
ALayout
=
remove_cvref_t
<
typename
Gemm
Traits
::
ALayout
>
;
using
ALayout
=
remove_cvref_t
<
typename
Traits
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
Gemm
Traits
::
BLayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
Traits
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
Gemm
Traits
::
CLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
Traits
::
CLayout
>
;
static
constexpr
index_t
VectorLoadSize
=
GemmTraits
::
_VectorSize
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadM
=
Gemm
Traits
::
kPadM
;
static
constexpr
bool
kPadM
=
Traits
::
kPadM
;
static
constexpr
bool
kPadN
=
Gemm
Traits
::
kPadN
;
static
constexpr
bool
kPadN
=
Traits
::
kPadN
;
static
constexpr
bool
kPadK
=
Gemm
Traits
::
kPadK
;
static
constexpr
bool
kPadK
=
Traits
::
kPadK
;
static
constexpr
auto
Scheduler
=
GemmPipelineScheduler
::
Default
;
static
constexpr
auto
Scheduler
=
GemmPipelineScheduler
::
Default
;
static
constexpr
index_t
VectorLoadSize
=
Traits
::
_VectorSize
;
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentA
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentA
()
{
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
...
@@ -128,27 +128,43 @@ template <typename ADataType_,
...
@@ -128,27 +128,43 @@ template <typename ADataType_,
typename
BDataType_
,
typename
BDataType_
,
typename
CDataType_
,
typename
CDataType_
,
typename
BlockGemmShape_
,
typename
BlockGemmShape_
,
typename
TileGemm
Traits_
>
typename
Traits_
>
using
GemmPipelineProblem
=
using
GemmPipelineProblem
=
GemmPipelineProblemBase
<
ADataType_
,
BDataType_
,
CDataType_
,
BlockGemmShape_
,
TileGemm
Traits_
>
;
GemmPipelineProblemBase
<
ADataType_
,
BDataType_
,
CDataType_
,
BlockGemmShape_
,
Traits_
>
;
template
<
typename
ADataType_
,
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
BDataType_
,
typename
CDataType_
,
typename
CDataType_
,
typename
BlockGemmShape_
,
typename
BlockGemmShape_
,
typename
TileGemm
Traits_
,
typename
Traits_
,
GemmPipelineScheduler
Scheduler_
=
GemmPipelineScheduler
::
Intrawave
,
GemmPipelineScheduler
Scheduler_
=
GemmPipelineScheduler
::
Intrawave
,
bool
HasHotLoop_
=
true
,
bool
HasHotLoop_
=
true
,
TailNumber
TailNum_
=
TailNumber
::
Full
>
TailNumber
TailNum_
=
TailNumber
::
Full
>
struct
UniversalGemmPipelineProblem
:
public
GemmPipelineProblemBase
<
ADataType_
,
struct
UniversalGemmPipelineProblem
BDataType_
,
CDataType_
,
BlockGemmShape_
,
TileGemmTraits_
>
{
{
using
Traits
=
remove_cvref_t
<
Traits_
>
;
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
ALayout
=
remove_cvref_t
<
typename
Traits
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
Traits
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
Traits
::
CLayout
>
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadM
=
Traits
::
kPadM
;
static
constexpr
bool
kPadN
=
Traits
::
kPadN
;
static
constexpr
bool
kPadK
=
Traits
::
kPadK
;
static
constexpr
auto
Scheduler
=
Scheduler_
;
static
constexpr
auto
Scheduler
=
Scheduler_
;
static
constexpr
auto
HasHotLoop
=
HasHotLoop_
;
static
constexpr
auto
HasHotLoop
=
HasHotLoop_
;
static
constexpr
auto
TailNum
=
TailNum_
;
static
constexpr
auto
TailNum
=
TailNum_
;
static
constexpr
bool
TransposeC
=
Traits
::
TransposeC
;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
View file @
241baec1
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -16,8 +17,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -16,8 +17,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I2
=
number
<
2
>
{};
static
constexpr
auto
I2
=
number
<
2
>
{};
static
constexpr
bool
TransposeC
=
true
;
template
<
typename
Problem
,
typename
DataType
,
index_t
MNPerBlock
>
template
<
typename
Problem
,
typename
DataType
,
index_t
MNPerBlock
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorLoadSize
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorLoadSize
()
{
{
...
@@ -25,6 +24,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -25,6 +24,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
elements_per_thread
=
MNPerBlock
*
KPerBlock
/
BlockSize
;
constexpr
index_t
elements_per_thread
=
MNPerBlock
*
KPerBlock
/
BlockSize
;
// Assume DataType is even!
if
constexpr
(
elements_per_thread
%
(
16
/
sizeof
(
DataType
))
==
0
)
if
constexpr
(
elements_per_thread
%
(
16
/
sizeof
(
DataType
))
==
0
)
{
{
return
(
16
/
sizeof
(
DataType
));
return
(
16
/
sizeof
(
DataType
));
...
@@ -49,6 +50,95 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -49,6 +50,95 @@ struct UniversalGemmPipelineAgBgCrPolicy
}
}
}
}
/// Vector size used for the A tensor.
///
/// Forwards the problem's A element type (cv/ref-stripped) and the block tile
/// length along M to GetVectorLoadSize and returns whatever it yields.
///
/// @tparam Problem GEMM pipeline problem; must expose `ADataType` and
///                 `BlockGemmShape::kM`.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
{
    constexpr index_t tile_len_m = Problem::BlockGemmShape::kM;
    using AElem                  = remove_cvref_t<typename Problem::ADataType>;

    return GetVectorLoadSize<Problem, AElem, tile_len_m>();
}
/// Vector size used for the B tensor.
///
/// Forwards the problem's B element type (cv/ref-stripped) and the block tile
/// length along N to GetVectorLoadSize and returns whatever it yields.
///
/// @tparam Problem GEMM pipeline problem; must expose `BDataType` and
///                 `BlockGemmShape::kN`.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
{
    constexpr index_t tile_len_n = Problem::BlockGemmShape::kN;
    using BElem                  = remove_cvref_t<typename Problem::BDataType>;

    return GetVectorLoadSize<Problem, BElem, tile_len_n>();
}
/**
 * @brief Get the vector store size for C tensor.
 *
 * @tparam Problem - Gemm pipeline problem class. Must expose `CLayout` and
 *                   `TransposeC`; the warp gemm is taken from GetBlockGemm<Problem>().
 *
 * @note The vector store size for the output C tensor depends on multiple factors,
 *       such as its data layout and the warp gemm C transposition. In general it is
 *       the number of consecutive elements in the contiguous C dimension held by
 *       a single thread.
 *
 * @return The vector store size for C tensor. Compilation fails with
 *         "Unsupported CLayout!" when CLayout is neither RowMajor nor ColumnMajor.
 */
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
    using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
    using WG        = typename BlockGemm::WarpGemm;

    constexpr bool TransposeC = Problem::TransposeC;
    using CLayout             = typename Problem::CLayout;
    using CWarpDstr           = typename WG::CWarpDstr;

    // N is contiguous dimension
    if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
    {
        if constexpr(TransposeC)
        {
            // In this case each thread has multiple consecutive elements in
            // N dimension, however consecutive threads' elements have stride.
            // static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
            //               c_warp_y_lengths.get(number<NDimY-1>{}));
            constexpr index_t NDimY = CWarpDstr::NDimY;
            constexpr auto c_warp_y_lengths =
                CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
            // Length of the last Y dimension of the warp-level C distribution.
            return c_warp_y_lengths.get(number<NDimY - 1>{});
        }
        else
        {
            // In this case each thread has just a single item in Ndim
            return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
        }
    }
    // M is contiguous dimension
    else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
    {
        if constexpr(TransposeC)
        {
            // In this case each thread has just a single item in Mdim
            return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
        }
        else
        {
            // In this case each thread has multiple consecutive elements in
            // M dimension, however consecutive threads' elements have stride.
            // static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
            //               c_warp_y_lengths.get(number<NDimY-1>{}));
            constexpr index_t NDimY = CWarpDstr::NDimY;
            constexpr auto c_warp_y_lengths =
                CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
            // Length of the last Y dimension of the warp-level C distribution.
            return c_warp_y_lengths.get(number<NDimY - 1>{});
        }
    }
    else
    {
        // The condition is spelled out instead of a literal `false` so that it depends
        // on the template parameter. A non-dependent static_assert(false) inside a
        // discarded `if constexpr` branch is ill-formed before CWG2518 / P2593 and is
        // rejected by older compilers even when a supported layout is used.
        static_assert(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> ||
                          std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>,
                      "Unsupported CLayout!");
    }
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
{
{
...
@@ -180,29 +270,28 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -180,29 +270,28 @@ struct UniversalGemmPipelineAgBgCrPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
{
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
{
constexpr
index_t
M1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
M1
=
GetVectorSizeA
<
Problem
>
(
);
constexpr
index_t
M0
=
MPerBlock
/
M1
;
constexpr
index_t
M0
=
MPerBlock
/
M1
;
constexpr
index_t
total_pixels
=
MPerBlock
*
KPerBlock
/
BlockSize
;
constexpr
index_t
elem_per_thr
=
MPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
total_pixels
%
M1
==
0
);
constexpr
index_t
K3
=
elem_per_thr
/
M1
;
// # of loads per thr
constexpr
index_t
K3
=
total_pixels
/
M1
;
constexpr
index_t
KPack
=
GetVectorSizeA
<
Problem
>
();
constexpr
index_t
KPack
=
GetVectorLoadSize
<
Problem
,
ADataType
,
MPerBlock
>
();
static_assert
(
KPack
%
K3
==
0
);
static_assert
(
KPack
%
K3
==
0
);
constexpr
index_t
K2
=
KPack
/
K3
;
constexpr
index_t
K2
=
KPack
/
K3
;
if
constexpr
(
get_warp_size
()
%
(
K2
*
M0
)
==
0
)
if
constexpr
(
get_warp_size
()
%
(
K2
*
M0
)
==
0
)
{
{
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
M0
);
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
M0
);
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
...
@@ -217,6 +306,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -217,6 +306,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
...
@@ -228,15 +318,21 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -228,15 +318,21 @@ struct UniversalGemmPipelineAgBgCrPolicy
}
}
else
else
{
{
constexpr
index_t
K1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
// In RowMajor scenario we usually want to read whole KPerBlock tile dim
constexpr
index_t
K1
=
GetVectorSizeA
<
Problem
>
();
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
// Coalesce reading for whole workgroup - workgroup raked pattern
if
constexpr
(
get_warp_size
()
%
(
M2
*
K0
)
==
0
)
if
constexpr
(
get_warp_size
()
%
(
M2
*
K0
)
==
0
)
{
{
constexpr
index_t
M1
=
BlockSize
/
get_warp_size
();
constexpr
index_t
M1
=
BlockSize
/
get_warp_size
();
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
M0
=
MPerBlock
/
(
M2
*
M1
);
constexpr
index_t
M0
=
MPerBlock
/
(
M2
*
M1
);
static_assert
(
M0
*
M1
*
M2
==
MPerBlock
,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!"
);
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
...
@@ -245,10 +341,15 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -245,10 +341,15 @@ struct UniversalGemmPipelineAgBgCrPolicy
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
sequence
<
0
,
1
>>
{});
}
}
// Coalesce reading for each wavefront - wavefront raked pattern
else
else
{
{
constexpr
index_t
M0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
M0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
M1
=
MPerBlock
/
(
M2
*
M0
);
constexpr
index_t
M1
=
MPerBlock
/
(
M2
*
M0
);
static_assert
(
M0
*
M1
*
M2
==
MPerBlock
,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!"
);
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
...
@@ -263,22 +364,20 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -263,22 +364,20 @@ struct UniversalGemmPipelineAgBgCrPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
{
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
if
constexpr
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
constexpr
index_t
N1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
N1
=
GetVectorSizeB
<
Problem
>
(
);
constexpr
index_t
N0
=
NPerBlock
/
N1
;
constexpr
index_t
N0
=
NPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
NPerBlock
*
KPerBlock
/
BlockSize
;
constexpr
index_t
elem_per_thr
=
NPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
static_assert
(
elem_per_thr
%
N1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
K3
=
elem_per_thr
/
N1
;
constexpr
index_t
KPack
=
GetVector
Load
Size
<
Problem
,
BDataType
,
NPerBlock
>
();
constexpr
index_t
KPack
=
GetVectorSize
B
<
Problem
>
();
static_assert
(
KPack
%
K3
==
0
);
static_assert
(
KPack
%
K3
==
0
);
constexpr
index_t
K2
=
KPack
/
K3
;
constexpr
index_t
K2
=
KPack
/
K3
;
if
constexpr
(
get_warp_size
()
%
(
K2
*
N0
)
==
0
)
if
constexpr
(
get_warp_size
()
%
(
K2
*
N0
)
==
0
)
...
@@ -286,6 +385,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -286,6 +385,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
...
@@ -300,6 +400,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -300,6 +400,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
...
@@ -312,7 +413,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -312,7 +413,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
else
else
{
{
constexpr
index_t
K1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
K1
=
GetVectorSizeB
<
Problem
>
(
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
// coalesce reading for each blocks
...
@@ -322,6 +423,9 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -322,6 +423,9 @@ struct UniversalGemmPipelineAgBgCrPolicy
static_assert
(
N2
!=
0
,
"N2 is zero, which will lead to a division by zero error."
);
static_assert
(
N2
!=
0
,
"N2 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"N1 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"N1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
N0
=
NPerBlock
/
(
N2
*
N1
);
constexpr
index_t
N0
=
NPerBlock
/
(
N2
*
N1
);
static_assert
(
N0
*
N1
*
N2
==
NPerBlock
,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!"
);
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tile_distribution_encoding
<
sequence
<
1
>
,
...
@@ -336,6 +440,9 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -336,6 +440,9 @@ struct UniversalGemmPipelineAgBgCrPolicy
{
{
constexpr
index_t
N0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
NPerBlock
/
(
N2
*
N0
);
constexpr
index_t
N1
=
NPerBlock
/
(
N2
*
N0
);
static_assert
(
N0
*
N1
*
N2
==
NPerBlock
,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!"
);
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tile_distribution_encoding
<
sequence
<
1
>
,
...
@@ -447,22 +554,21 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -447,22 +554,21 @@ struct UniversalGemmPipelineAgBgCrPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
{
{
using
AccDataType
=
float
;
using
BlockWarps
=
typename
Problem
::
BlockGemmShape
::
BlockWarps
;
using
BlockWarps
=
typename
Problem
::
BlockGemmShape
::
BlockWarps
;
using
WarpTile
=
typename
Problem
::
BlockGemmShape
::
WarpTile
;
using
WarpTile
=
typename
Problem
::
BlockGemmShape
::
WarpTile
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
BDataType
,
Acc
DataType
,
typename
Problem
::
C
DataType
,
WarpTile
::
at
(
I0
),
WarpTile
::
at
(
I0
),
WarpTile
::
at
(
I1
),
WarpTile
::
at
(
I1
),
WarpTile
::
at
(
I2
),
WarpTile
::
at
(
I2
),
TransposeC
>
;
Problem
::
TransposeC
>
;
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1CustomPolicy
<
typename
Problem
::
ADataType
,
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1CustomPolicy
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
typename
Problem
::
CDataType
,
BlockWarps
,
BlockWarps
,
WarpGemm
>
;
WarpGemm
>
;
return
Block
GemmASmemBSmemCRegV1
<
Problem
,
BlockGemmPolicy
>
{};
return
Block
UniversalGemmAsBsCr
<
Problem
,
BlockGemmPolicy
>
{};
}
}
};
};
...
...
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
View file @
241baec1
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -19,11 +20,34 @@ struct TileGemmTraits
...
@@ -19,11 +20,34 @@ struct TileGemmTraits
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadK
=
kPadK_
;
static
constexpr
bool
kPadK
=
kPadK_
;
// TODO this can't be hardcoded here! Should be in policy!
static
constexpr
int
_VectorSize
=
16
;
static
constexpr
int
_VectorSize
=
16
;
using
ALayout
=
ALayout_
;
using
ALayout
=
ALayout_
;
using
BLayout
=
BLayout_
;
using
BLayout
=
BLayout_
;
using
CLayout
=
CLayout_
;
using
CLayout
=
CLayout_
;
static
constexpr
bool
TransposeC
=
false
;
};
/// Compile-time traits bundle that re-exports its template arguments.
///
/// @tparam kPadM_      Exposed as `kPadM`.
/// @tparam kPadN_      Exposed as `kPadN`.
/// @tparam kPadK_      Exposed as `kPadK`.
/// @tparam ALayout_    Exposed as the `ALayout` alias.
/// @tparam BLayout_    Exposed as the `BLayout` alias.
/// @tparam CLayout_    Exposed as the `CLayout` alias.
/// @tparam TransposeC_ Exposed as `TransposeC`; defaults to false.
template <bool kPadM_,
          bool kPadN_,
          bool kPadK_,
          typename ALayout_,
          typename BLayout_,
          typename CLayout_,
          bool TransposeC_ = false>
struct TileGemmUniversalTraits
{
    // Layout tags of the three GEMM operands.
    using ALayout = ALayout_;
    using BLayout = BLayout_;
    using CLayout = CLayout_;

    // Per-dimension padding switches.
    static constexpr bool kPadM{kPadM_};
    static constexpr bool kPadN{kPadN_};
    static constexpr bool kPadK{kPadK_};

    // Forwarded C-transposition switch.
    static constexpr bool TransposeC{TransposeC_};
};
};
}
// namespace ck_tile
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment