Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f2c1fa7f
Commit
f2c1fa7f
authored
Feb 13, 2025
by
ThomasNing
Browse files
Merge branch 'develop' of
https://github.com/ROCm/composable_kernel
into develop
parents
4658f2f6
0e5e29c4
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
232 additions
and
154 deletions
+232
-154
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+1
-1
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
+4
-0
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
...gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+143
-132
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
+3
-0
test/ck_tile/gemm/test_gemm_pipeline.cpp
test/ck_tile/gemm/test_gemm_pipeline.cpp
+10
-5
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
+71
-16
No files found.
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
f2c1fa7f
...
@@ -338,7 +338,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -338,7 +338,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
{
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
static_assert
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
Row
Major
>
);
static_assert
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
Column
Major
>
);
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
View file @
f2c1fa7f
...
@@ -36,6 +36,8 @@ struct GemmPipelineProblemBase
...
@@ -36,6 +36,8 @@ struct GemmPipelineProblemBase
static
constexpr
bool
kPadN
=
Traits
::
kPadN
;
static
constexpr
bool
kPadN
=
Traits
::
kPadN
;
static
constexpr
bool
kPadK
=
Traits
::
kPadK
;
static
constexpr
bool
kPadK
=
Traits
::
kPadK
;
static
constexpr
bool
DoubleSmemBuffer
=
Traits
::
DoubleSmemBuffer
;
static
constexpr
auto
Scheduler
=
GemmPipelineScheduler
::
Default
;
static
constexpr
auto
Scheduler
=
GemmPipelineScheduler
::
Default
;
static
constexpr
index_t
VectorLoadSize
=
Traits
::
_VectorSize
;
static
constexpr
index_t
VectorLoadSize
=
Traits
::
_VectorSize
;
...
@@ -173,6 +175,8 @@ struct UniversalGemmPipelineProblem
...
@@ -173,6 +175,8 @@ struct UniversalGemmPipelineProblem
static
constexpr
bool
kPadN
=
Traits
::
kPadN
;
static
constexpr
bool
kPadN
=
Traits
::
kPadN
;
static
constexpr
bool
kPadK
=
Traits
::
kPadK
;
static
constexpr
bool
kPadK
=
Traits
::
kPadK
;
static
constexpr
bool
DoubleSmemBuffer
=
Traits
::
DoubleSmemBuffer
;
static
constexpr
auto
Scheduler
=
Scheduler_
;
static
constexpr
auto
Scheduler
=
Scheduler_
;
static
constexpr
auto
HasHotLoop
=
HasHotLoop_
;
static
constexpr
auto
HasHotLoop
=
HasHotLoop_
;
static
constexpr
auto
TailNum
=
TailNum_
;
static
constexpr
auto
TailNum
=
TailNum_
;
...
...
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
View file @
f2c1fa7f
...
@@ -9,8 +9,8 @@
...
@@ -9,8 +9,8 @@
namespace
ck_tile
{
namespace
ck_tile
{
// UniversalGemm Policy
template
<
typename
Derived
>
struct
UniversalGemm
PipelineAgBgCr
Policy
struct
UniversalGemm
Base
Policy
{
{
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
...
@@ -113,7 +113,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -113,7 +113,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorSizeC
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorSizeC
()
{
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetBlockGemm
<
Problem
>
())
>
;
using
BlockGemm
=
remove_cvref_t
<
decltype
(
Derived
::
template
GetBlockGemm
<
Problem
>())
>
;
using
WG
=
typename
BlockGemm
::
WarpGemm
;
using
WG
=
typename
BlockGemm
::
WarpGemm
;
constexpr
bool
TransposeC
=
Problem
::
TransposeC
;
constexpr
bool
TransposeC
=
Problem
::
TransposeC
;
...
@@ -166,10 +166,116 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -166,10 +166,116 @@ struct UniversalGemmPipelineAgBgCrPolicy
}
}
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
Problem
::
TransposeC
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
{
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
VecLoadSize
=
GetVectorSizeA
<
Problem
>
();
// Tile: MPerBlock X KPerBlock
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
BlockSize
,
MPerBlock
,
KPerBlock
,
VecLoadSize
,
ATileAccessPattern
>
;
return
TileEncodingPattern
::
Make2DStaticTileDistribution
();
}
// Tile: KPerBlock X MPerBlock
else
{
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
BlockSize
,
KPerBlock
,
MPerBlock
,
VecLoadSize
,
ATileAccessPattern
>
;
return
TileEncodingPattern
::
Make2DStaticTileDistribution
();
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
{
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
VecLoadSize
=
GetVectorSizeB
<
Problem
>
();
// Tile: KPerBlock X NPerBlock
if
constexpr
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
BlockSize
,
KPerBlock
,
NPerBlock
,
VecLoadSize
,
BTileAccessPattern
>
;
return
TileEncodingPattern
::
Make2DStaticTileDistribution
();
}
// Tile: NPerBlock X KPerBlock
else
{
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
BlockSize
,
NPerBlock
,
KPerBlock
,
VecLoadSize
,
BTileAccessPattern
>
;
return
TileEncodingPattern
::
Make2DStaticTileDistribution
();
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledARegTileDistribution
()
{
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
static_assert
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
);
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
VecLoadSize
=
GetVectorSizeA
<
Problem
>
();
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
BlockSize
,
KPerBlock
,
MPerBlock
,
VecLoadSize
,
ATileAccessPattern
>
;
return
TileEncodingPattern
::
MakeShuffled2DStaticTileDistribution
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledBRegTileDistribution
()
{
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
static_assert
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
VecLoadSize
=
GetVectorSizeB
<
Problem
>
();
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
BlockSize
,
KPerBlock
,
NPerBlock
,
VecLoadSize
,
BTileAccessPattern
>
;
return
TileEncodingPattern
::
MakeShuffled2DStaticTileDistribution
();
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackA
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackA
()
{
{
using
BlockGemm
=
decltype
(
GetBlockGemm
<
Problem
>
());
using
BlockGemm
=
remove_cvref_t
<
decltype
(
Derived
::
template
GetBlockGemm
<
Problem
>())
>
;
constexpr
index_t
KPack
=
BlockGemm
::
Traits
::
KPack
;
constexpr
index_t
KPack
=
BlockGemm
::
Traits
::
KPack
;
return
KPack
;
return
KPack
;
}
}
...
@@ -177,11 +283,43 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -177,11 +283,43 @@ struct UniversalGemmPipelineAgBgCrPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackB
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackB
()
{
{
using
BlockGemm
=
decltype
(
GetBlockGemm
<
Problem
>
());
using
BlockGemm
=
remove_cvref_t
<
decltype
(
Derived
::
template
GetBlockGemm
<
Problem
>())
>
;
constexpr
index_t
KPack
=
BlockGemm
::
Traits
::
KPack
;
constexpr
index_t
KPack
=
BlockGemm
::
Traits
::
KPack
;
return
KPack
;
return
KPack
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeA
()
{
constexpr
auto
a_lds_desc
=
Derived
::
template
MakeALdsBlockDescriptor
<
Problem
>();
constexpr
index_t
smem_size_a
=
integer_least_multiple
(
sizeof
(
typename
Problem
::
ADataType
)
*
a_lds_desc
.
get_element_space_size
(),
16
);
return
smem_size_a
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeB
()
{
constexpr
auto
b_lds_desc
=
Derived
::
template
MakeBLdsBlockDescriptor
<
Problem
>();
constexpr
index_t
smem_size_b
=
integer_least_multiple
(
sizeof
(
typename
Problem
::
BDataType
)
*
b_lds_desc
.
get_element_space_size
(),
16
);
return
smem_size_b
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
constexpr
index_t
smem_size_a
=
GetSmemSizeA
<
Problem
>
();
constexpr
index_t
smem_size_b
=
GetSmemSizeB
<
Problem
>
();
return
smem_size_a
+
smem_size_b
;
}
};
// UniversalGemm Policy
struct
UniversalGemmPipelineAgBgCrPolicy
:
public
UniversalGemmBasePolicy
<
UniversalGemmPipelineAgBgCrPolicy
>
{
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
{
{
...
@@ -421,133 +559,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -421,133 +559,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
#endif
#endif
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeA
()
{
constexpr
index_t
smem_size_a
=
sizeof
(
typename
Problem
::
ADataType
)
*
MakeALdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_a
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeB
()
{
constexpr
index_t
smem_size_b
=
sizeof
(
typename
Problem
::
BDataType
)
*
MakeBLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_b
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
constexpr
index_t
smem_size_a
=
GetSmemSizeA
<
Problem
>
();
constexpr
index_t
smem_size_b
=
GetSmemSizeB
<
Problem
>
();
index_t
smem_size
=
0
;
smem_size
+=
smem_size_a
+
smem_size_b
;
return
smem_size
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
{
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
VecLoadSize
=
GetVectorSizeA
<
Problem
>
();
// Tile: MPerBlock X KPerBlock
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
BlockSize
,
MPerBlock
,
KPerBlock
,
VecLoadSize
,
ATileAccessPattern
>
;
return
TileEncodingPattern
::
Make2DStaticTileDistribution
();
}
// Tile: KPerBlock X MPerBlock
else
{
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
BlockSize
,
KPerBlock
,
MPerBlock
,
VecLoadSize
,
ATileAccessPattern
>
;
return
TileEncodingPattern
::
Make2DStaticTileDistribution
();
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
{
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
VecLoadSize
=
GetVectorSizeB
<
Problem
>
();
// Tile: KPerBlock X NPerBlock
if
constexpr
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
BlockSize
,
KPerBlock
,
NPerBlock
,
VecLoadSize
,
BTileAccessPattern
>
;
return
TileEncodingPattern
::
Make2DStaticTileDistribution
();
}
// Tile: NPerBlock X KPerBlock
else
{
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
BlockSize
,
NPerBlock
,
KPerBlock
,
VecLoadSize
,
BTileAccessPattern
>
;
return
TileEncodingPattern
::
Make2DStaticTileDistribution
();
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledARegTileDistribution
()
{
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
static_assert
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
);
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
VecLoadSize
=
GetVectorSizeA
<
Problem
>
();
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
BlockSize
,
KPerBlock
,
MPerBlock
,
VecLoadSize
,
ATileAccessPattern
>
;
return
TileEncodingPattern
::
MakeShuffled2DStaticTileDistribution
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledBRegTileDistribution
()
{
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
static_assert
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
VecLoadSize
=
GetVectorSizeB
<
Problem
>
();
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
BlockSize
,
KPerBlock
,
NPerBlock
,
VecLoadSize
,
BTileAccessPattern
>
;
return
TileEncodingPattern
::
MakeShuffled2DStaticTileDistribution
();
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
{
{
...
...
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
View file @
f2c1fa7f
...
@@ -32,6 +32,7 @@ struct TileGemmTraits
...
@@ -32,6 +32,7 @@ struct TileGemmTraits
template
<
bool
kPadM_
,
template
<
bool
kPadM_
,
bool
kPadN_
,
bool
kPadN_
,
bool
kPadK_
,
bool
kPadK_
,
bool
DoubleSmemBuffer_
,
typename
ALayout_
,
typename
ALayout_
,
typename
BLayout_
,
typename
BLayout_
,
typename
CLayout_
,
typename
CLayout_
,
...
@@ -42,6 +43,8 @@ struct TileGemmUniversalTraits
...
@@ -42,6 +43,8 @@ struct TileGemmUniversalTraits
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadK
=
kPadK_
;
static
constexpr
bool
kPadK
=
kPadK_
;
static
constexpr
bool
DoubleSmemBuffer
=
DoubleSmemBuffer_
;
using
ALayout
=
ALayout_
;
using
ALayout
=
ALayout_
;
using
BLayout
=
BLayout_
;
using
BLayout
=
BLayout_
;
using
CLayout
=
CLayout_
;
using
CLayout
=
CLayout_
;
...
...
test/ck_tile/gemm/test_gemm_pipeline.cpp
View file @
f2c1fa7f
...
@@ -17,22 +17,27 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
...
@@ -17,22 +17,27 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
using
Interwave
=
ck_tile
::
integral_constant
<
ck_tile
::
GemmPipelineScheduler
,
using
Interwave
=
ck_tile
::
integral_constant
<
ck_tile
::
GemmPipelineScheduler
,
ck_tile
::
GemmPipelineScheduler
::
Interwave
>
;
ck_tile
::
GemmPipelineScheduler
::
Interwave
>
;
using
Mem
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Mem
>
;
using
Mem
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Mem
>
;
using
Comp
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Comp
>
;
using
CompV3
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
CompV3
>
;
using
CompV4
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
CompV4
>
;
// clang-format off
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
using
KernelTypes
=
::
testing
::
Types
<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
CompV3
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
CompV4
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
CompV3
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
CompV4
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
CompV3
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
CompV4
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Mem
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
CompV3
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
CompV4
>
,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Interwave
,
Mem
>
>
;
>
;
// clang-format on
// clang-format on
...
...
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
View file @
f2c1fa7f
...
@@ -14,7 +14,32 @@
...
@@ -14,7 +14,32 @@
enum
struct
GemmPipelineType
enum
struct
GemmPipelineType
{
{
Mem
,
Mem
,
Comp
CompV3
,
CompV4
};
template
<
GemmPipelineType
PT
,
typename
Problem
>
struct
GemmPipelineTypeSelector
;
template
<
typename
Problem
>
struct
GemmPipelineTypeSelector
<
GemmPipelineType
::
Mem
,
Problem
>
{
using
base_pipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
Problem
>
;
using
pipeline
=
ck_tile
::
GemmPipelineAgBgCrMem
<
Problem
>
;
};
template
<
typename
Problem
>
struct
GemmPipelineTypeSelector
<
GemmPipelineType
::
CompV3
,
Problem
>
{
using
base_pipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrCompV3
<
Problem
>
;
using
pipeline
=
ck_tile
::
GemmPipelineAgBgCrCompV3
<
Problem
>
;
};
template
<
typename
Problem
>
struct
GemmPipelineTypeSelector
<
GemmPipelineType
::
CompV4
,
Problem
>
{
using
base_pipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrCompV4
<
Problem
>
;
using
pipeline
=
ck_tile
::
GemmPipelineAgBgCrCompV4
<
Problem
>
;
};
};
template
<
typename
Tuple
>
template
<
typename
Tuple
>
...
@@ -36,8 +61,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
...
@@ -36,8 +61,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
void
invoke_gemm
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
void
invoke_gemm
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
{
// TODO: This should be parameterized in tests
// TODO: This should be parameterized in tests
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
128
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
...
@@ -52,6 +77,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
...
@@ -52,6 +77,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr
bool
kPadN
=
PadN
;
constexpr
bool
kPadN
=
PadN
;
constexpr
bool
kPadK
=
PadK
;
constexpr
bool
kPadK
=
PadK
;
constexpr
bool
DoubleSmemBuffer
=
(
PipelineType
==
GemmPipelineType
::
CompV4
)
?
true
:
false
;
// TODO: For now - but this should also be a test parameter
// TODO: For now - but this should also be a test parameter
constexpr
bool
TransposeC
=
false
;
constexpr
bool
TransposeC
=
false
;
...
@@ -69,16 +96,20 @@ class TestCkTileGemmPipeline : public ::testing::Test
...
@@ -69,16 +96,20 @@ class TestCkTileGemmPipeline : public ::testing::Test
GemmSpatiallyLocalTilePartitioner
<
GemmShape
,
TileParitionerGroupNum
,
TileParitionerM01
>
;
GemmSpatiallyLocalTilePartitioner
<
GemmShape
,
TileParitionerGroupNum
,
TileParitionerM01
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
GemmUniversalTraits
=
ck_tile
::
using
GemmUniversalTraits
=
ck_tile
::
TileGemmUniversalTraits
<
kPadM
,
TileGemmUniversalTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
,
TransposeC
>
;
kPadN
,
kPadK
,
DoubleSmemBuffer
,
ALayout
,
BLayout
,
CLayout
,
TransposeC
>
;
using
GemmPipelineProblem
=
using
GemmPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>
;
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>
;
using
BaseGemmPipeline
=
using
BaseGemmPipeline
=
std
::
conditional_t
<
PipelineType
==
GemmPipelineType
::
Mem
,
typename
GemmPipelineTypeSelector
<
PipelineType
,
GemmPipelineProblem
>::
base_pipeline
;
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
GemmPipelineProblem
>
,
ck_tile
::
BaseGemmPipelineAgBgCrCompV3
<
GemmPipelineProblem
>>
;
const
ck_tile
::
index_t
k_grain
=
args
.
k_batch
*
K_Tile
;
const
ck_tile
::
index_t
k_grain
=
args
.
k_batch
*
K_Tile
;
const
ck_tile
::
index_t
K_split
=
(
args
.
K
+
k_grain
-
1
)
/
k_grain
*
K_Tile
;
const
ck_tile
::
index_t
K_split
=
(
args
.
K
+
k_grain
-
1
)
/
k_grain
*
K_Tile
;
...
@@ -99,12 +130,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
...
@@ -99,12 +130,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
has_hot_loop_v
,
has_hot_loop_v
,
tail_number_v
>
;
tail_number_v
>
;
using
GemmPipeline
=
std
::
conditional_t
<
using
GemmPipeline
=
PipelineType
==
GemmPipelineType
::
Mem
,
typename
GemmPipelineTypeSelector
<
PipelineType
,
UniversalGemmProblem
>::
pipeline
;
ck_tile
::
GemmPipelineAgBgCrMem
<
UniversalGemmProblem
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>
,
ck_tile
::
GemmPipelineAgBgCrCompV3
<
UniversalGemmProblem
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>>
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
...
@@ -145,7 +172,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
...
@@ -145,7 +172,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
if
(
has_hot_loop
)
if
(
has_hot_loop
)
{
{
if
constexpr
(
PipelineType
==
GemmPipelineType
::
Comp
)
if
constexpr
(
PipelineType
==
GemmPipelineType
::
Comp
V3
)
{
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Full
)
if
(
tail_num
==
ck_tile
::
TailNumber
::
Full
)
{
{
...
@@ -235,6 +262,22 @@ class TestCkTileGemmPipeline : public ::testing::Test
...
@@ -235,6 +262,22 @@ class TestCkTileGemmPipeline : public ::testing::Test
}
}
}
}
}
}
if
constexpr
(
PipelineType
==
GemmPipelineType
::
CompV4
)
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Three
)
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Three
>
{});
}
else
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Two
>
{});
}
}
}
}
else
else
{
{
...
@@ -258,7 +301,19 @@ class TestCkTileGemmPipeline : public ::testing::Test
...
@@ -258,7 +301,19 @@ class TestCkTileGemmPipeline : public ::testing::Test
public:
public:
std
::
vector
<
int
>
k_batches_
;
std
::
vector
<
int
>
k_batches_
;
void
SetUp
()
override
{
k_batches_
=
{
1
,
2
};
}
void
SetUp
()
override
{
if
constexpr
(
PipelineType
==
GemmPipelineType
::
CompV4
)
{
// Only do k_batch = 1 when pipeline is CompV4
k_batches_
=
{
1
};
}
else
{
// Otherwise, use k_batch = 1 and 2
k_batches_
=
{
1
,
2
};
}
}
template
<
bool
PadM
=
true
,
bool
PadN
=
true
,
bool
PadK
=
true
>
template
<
bool
PadM
=
true
,
bool
PadN
=
true
,
bool
PadK
=
true
>
void
Run
(
const
int
M
,
void
Run
(
const
int
M
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment