Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
dd21c599
"...composable_kernel_rocm.git" did not exist on "5f2c89e8b43d670e3405a4f17ff475d25960f9b3"
Commit
dd21c599
authored
Feb 11, 2025
by
Jakub Piasecki
Browse files
tmp save
parent
f23a2e2a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
124 additions
and
42 deletions
+124
-42
example/ck_tile/03_gemm/gemm.hpp
example/ck_tile/03_gemm/gemm.hpp
+43
-6
example/ck_tile/03_gemm/instances/gemm_universal_comp_instance_common.hpp
...03_gemm/instances/gemm_universal_comp_instance_common.hpp
+39
-18
example/ck_tile/03_gemm/instances/gemm_universal_mem_instance_common.hpp
.../03_gemm/instances/gemm_universal_mem_instance_common.hpp
+42
-18
No files found.
example/ck_tile/03_gemm/gemm.hpp
View file @
dd21c599
...
...
@@ -23,6 +23,33 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
// ToDo: Add more bias config to support different categories of GEMM.
};
template
<
>
struct
GemmBasicTypeConfig
<
ck_tile
::
bf16_t
>
{
using
ADataType
=
ck_tile
::
bf16_t
;
using
BDataType
=
ck_tile
::
bf16_t
;
using
AccDataType
=
float
;
using
CDataType
=
ck_tile
::
bf16_t
;
};
template
<
>
struct
GemmBasicTypeConfig
<
ck_tile
::
fp8_t
>
{
using
ADataType
=
ck_tile
::
fp8_t
;
using
BDataType
=
ck_tile
::
fp8_t
;
using
AccDataType
=
float
;
using
CDataType
=
ck_tile
::
half_t
;
};
template
<
>
struct
GemmBasicTypeConfig
<
ck_tile
::
bf8_t
>
{
using
ADataType
=
ck_tile
::
bf8_t
;
using
BDataType
=
ck_tile
::
bf8_t
;
using
AccDataType
=
float
;
using
CDataType
=
ck_tile
::
half_t
;
};
template
<
typename
T
>
struct
DataTypeTraits
;
...
...
@@ -44,13 +71,23 @@ struct DataTypeTraits<ck_tile::half_t>
static
constexpr
const
char
*
name
=
"fp16"
;
};
using
Types
=
GemmBasicTypeConfig
<
ck_tile
::
half_t
>
;
template
<
>
struct
DataTypeTraits
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
// Specific type aliases for easy access
using
ADataType
=
Types
::
ADataType
;
using
BDataType
=
Types
::
BDataType
;
using
AccDataType
=
Types
::
AccDataType
;
using
CDataType
=
Types
::
CDataType
;
template
<
>
struct
DataTypeTraits
<
ck_tile
::
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
template
<
>
struct
DataTypeTraits
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
/** \brief Struct used for specifying desired gemm details*/
struct
gemm_traits
...
...
example/ck_tile/03_gemm/instances/gemm_universal_comp_instance_common.hpp
View file @
dd21c599
...
...
@@ -10,34 +10,54 @@ using S = ck_tile::stream_config;
template
<
typename
Traits_
>
float
gemm_
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
constexpr
bool
TransposeC
=
false
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
ck_tile
::
index_t
TileParitionerGroupNum
=
8
;
constexpr
ck_tile
::
index_t
TileParitionerM01
=
4
;
using
GemmShape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
Traits_
::
M_Tile
,
Traits_
::
N_Tile
,
Traits_
::
K_Tile
>
,
ck_tile
::
sequence
<
Traits_
::
M_Warp
,
Traits_
::
N_Warp
,
Traits_
::
K_Warp
>
,
ck_tile
::
sequence
<
Traits_
::
M_Warp_Tile
,
Traits_
::
N_Warp_Tile
,
Traits_
::
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTile2DPartitioner
<
GemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmSpatiallyLocalTilePartitioner
<
GemmShape
,
TileParitionerGroupNum
,
TileParitionerM01
>
;
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
typename
Traits_
::
AccDataType
,
typename
Traits_
::
CDataType
,
Traits_
::
kPadM
,
Traits_
::
kPadN
>>
;
constexpr
bool
TransposeC
=
false
;
using
GemmUniversalTraits
=
ck_tile
::
TileGemmUniversalTraits
<
Traits_
::
kPadM
,
Traits_
::
kPadN
,
Traits_
::
kPadK
,
Traits_
::
ALayout
,
Traits_
::
BLayout
,
Traits_
::
CLayout
,
TransposeC
>
;
using
GemmTraits
=
ck_tile
::
TileGemmTraits
<
Traits_
::
kPadM
,
using
GemmUniversalTraits
=
ck_tile
::
TileGemmUniversalTraits
<
Traits_
::
kPadM
,
Traits_
::
kPadN
,
Traits_
::
kPadK
,
typename
Traits_
::
ALayout
,
typename
Traits_
::
BLayout
,
typename
Traits_
::
CLayout
,
TransposeC
>
;
using
GemmTraits
=
ck_tile
::
TileGemmTraits
<
Traits_
::
kPadM
,
Traits_
::
kPadN
,
Traits_
::
kPadK
,
typename
Traits_
::
ALayout
,
typename
Traits_
::
BLayout
,
typename
Traits_
::
CLayout
>
;
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrCompV3
<
ck_tile
::
GemmPipelineProblem
<
typename
Traits_
::
ADataType
,
typename
Traits_
::
BDataType
,
typename
Traits_
::
AccDataType
,
GemmShape
,
GemmTraits
>>
;
constexpr
int
kBlockPerCu
=
1
;
using
GemmPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
typename
Traits_
::
ADataType
,
typename
Traits_
::
BDataType
,
typename
Traits_
::
AccDataType
,
GemmShape
,
GemmTraits
>
;
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrCompV3
<
GemmPipelineProblem
>
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
typename
Traits_
::
AccDataType
,
typename
Traits_
::
CDataType
,
typename
Traits_
::
CLayout
,
GemmPipelineProblem
::
kBlockSize
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
,
Traits_
::
M_Warp
,
Traits_
::
N_Warp
,
Traits_
::
M_Warp_Tile
,
Traits_
::
N_Warp_Tile
,
Traits_
::
K_Warp_Tile
,
TransposeC
>>
;
const
ck_tile
::
index_t
k_grain
=
args
.
k_batch
*
Traits_
::
K_Tile
;
const
ck_tile
::
index_t
K_split
=
(
args
.
K
+
k_grain
-
1
)
/
k_grain
*
Traits_
::
K_Tile
;
...
...
@@ -59,7 +79,8 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
GemmUniversalTraits
,
ck_tile
::
GemmPipelineScheduler
::
Intrawave
,
has_hot_loop_v
,
tail_number_v
>
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>
;
tail_number_v
>
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
...
...
example/ck_tile/03_gemm/instances/gemm_universal_mem_instance_common.hpp
View file @
dd21c599
...
...
@@ -10,31 +10,54 @@ using S = ck_tile::stream_config;
template
<
typename
Traits_
>
float
gemm_
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
constexpr
bool
TransposeC
=
false
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
ck_tile
::
index_t
TileParitionerGroupNum
=
8
;
constexpr
ck_tile
::
index_t
TileParitionerM01
=
4
;
using
GemmShape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
Traits_
::
M_Tile
,
Traits_
::
N_Tile
,
Traits_
::
K_Tile
>
,
ck_tile
::
sequence
<
Traits_
::
M_Warp
,
Traits_
::
N_Warp
,
Traits_
::
K_Warp
>
,
ck_tile
::
sequence
<
Traits_
::
M_Warp_Tile
,
Traits_
::
N_Warp_Tile
,
Traits_
::
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTile2DPartitioner
<
GemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmSpatiallyLocalTilePartitioner
<
GemmShape
,
TileParitionerGroupNum
,
TileParitionerM01
>
;
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
typename
Traits_
::
AccDataType
,
typename
Traits_
::
CDataType
,
Traits_
::
kPadM
,
Traits_
::
kPadN
>>
;
using
GemmTraits
=
ck_tile
::
TileGemmTraits
<
Traits_
::
kPadM
,
using
GemmUniversalTraits
=
ck_tile
::
TileGemmUniversalTraits
<
Traits_
::
kPadM
,
Traits_
::
kPadN
,
Traits_
::
kPadK
,
typename
Traits_
::
ALayout
,
typename
Traits_
::
BLayout
,
typename
Traits_
::
CLayout
,
TransposeC
>
;
using
GemmTraits
=
ck_tile
::
TileGemmTraits
<
Traits_
::
kPadM
,
Traits_
::
kPadN
,
Traits_
::
kPadK
,
typename
Traits_
::
ALayout
,
typename
Traits_
::
BLayout
,
typename
Traits_
::
CLayout
>
;
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
ck_tile
::
GemmPipelineProblem
<
typename
Traits_
::
ADataType
,
typename
Traits_
::
BDataType
,
typename
Traits_
::
AccDataType
,
GemmShape
,
GemmTraits
>>
;
typename
Traits_
::
CLayout
>
;
constexpr
int
kBlockPerCu
=
1
;
using
GemmPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
typename
Traits_
::
ADataType
,
typename
Traits_
::
BDataType
,
typename
Traits_
::
AccDataType
,
GemmShape
,
GemmTraits
>
;
using
BaseGemmPipeline
=
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
GemmPipelineProblem
>
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
typename
Traits_
::
AccDataType
,
typename
Traits_
::
CDataType
,
typename
Traits_
::
CLayout
,
GemmPipelineProblem
::
kBlockSize
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
,
Traits_
::
M_Warp
,
Traits_
::
N_Warp
,
Traits_
::
M_Warp_Tile
,
Traits_
::
N_Warp_Tile
,
Traits_
::
K_Warp_Tile
,
TransposeC
>>
;
const
ck_tile
::
index_t
k_grain
=
args
.
k_batch
*
Traits_
::
K_Tile
;
const
ck_tile
::
index_t
K_split
=
(
args
.
K
+
k_grain
-
1
)
/
k_grain
*
Traits_
::
K_Tile
;
...
...
@@ -53,10 +76,11 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
typename
Traits_
::
BDataType
,
typename
Traits_
::
AccDataType
,
GemmShape
,
GemmTraits
,
ck_tile
::
GemmPipelineScheduler
::
Int
e
rwave
,
Gemm
Universal
Traits
,
ck_tile
::
GemmPipelineScheduler
::
Intr
a
wave
,
has_hot_loop_v
,
tail_number_v
>>
;
tail_number_v
>
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment