Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
da8e50dd
Commit
da8e50dd
authored
Feb 11, 2025
by
Jakub Piasecki
Browse files
tmp save
parent
dd21c599
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
171 additions
and
8 deletions
+171
-8
example/ck_tile/03_gemm/instances/gemm_api.cpp
example/ck_tile/03_gemm/instances/gemm_api.cpp
+81
-4
example/ck_tile/03_gemm/instances/gemm_universal_comp_fp8_fp8_fp8_km_kn_mn.cpp
...mm/instances/gemm_universal_comp_fp8_fp8_fp8_km_kn_mn.cpp
+11
-0
example/ck_tile/03_gemm/instances/gemm_universal_comp_fp8_fp8_fp8_km_nk_mn.cpp
...mm/instances/gemm_universal_comp_fp8_fp8_fp8_km_nk_mn.cpp
+11
-0
example/ck_tile/03_gemm/instances/gemm_universal_comp_fp8_fp8_fp8_mk_kn_mn.cpp
...mm/instances/gemm_universal_comp_fp8_fp8_fp8_mk_kn_mn.cpp
+10
-0
example/ck_tile/03_gemm/instances/gemm_universal_comp_fp8_fp8_fp8_mk_nk_mn.cpp
...mm/instances/gemm_universal_comp_fp8_fp8_fp8_mk_nk_mn.cpp
+11
-0
example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp
.../instances/gemm_universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp
+1
-1
example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp
.../instances/gemm_universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp
+1
-1
example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_kn_mn.cpp
...emm/instances/gemm_universal_mem_f16_f16_f16_mk_kn_mn.cpp
+1
-1
example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_nk_mn.cpp
...emm/instances/gemm_universal_mem_f16_f16_f16_mk_nk_mn.cpp
+1
-1
example/ck_tile/03_gemm/instances/gemm_universal_mem_fp8_fp8_fp8_km_kn_mn.cpp
...emm/instances/gemm_universal_mem_fp8_fp8_fp8_km_kn_mn.cpp
+11
-0
example/ck_tile/03_gemm/instances/gemm_universal_mem_fp8_fp8_fp8_km_nk_mn.cpp
...emm/instances/gemm_universal_mem_fp8_fp8_fp8_km_nk_mn.cpp
+11
-0
example/ck_tile/03_gemm/instances/gemm_universal_mem_fp8_fp8_fp8_mk_kn_mn.cpp
...emm/instances/gemm_universal_mem_fp8_fp8_fp8_mk_kn_mn.cpp
+10
-0
example/ck_tile/03_gemm/instances/gemm_universal_mem_fp8_fp8_fp8_mk_nk_mn.cpp
...emm/instances/gemm_universal_mem_fp8_fp8_fp8_mk_nk_mn.cpp
+11
-0
No files found.
example/ck_tile/03_gemm/instances/gemm_api.cpp
View file @
da8e50dd
...
@@ -9,6 +9,8 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
...
@@ -9,6 +9,8 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using
FP32
=
float
;
using
FP32
=
float
;
using
FP16
=
ck_tile
::
half_t
;
using
FP16
=
ck_tile
::
half_t
;
using
BF16
=
ck_tile
::
bf16_t
;
using
BF16
=
ck_tile
::
bf16_t
;
using
FP8
=
ck_tile
::
fp8_t
;
using
BF8
=
ck_tile
::
bf8_t
;
float
gemm
(
const
gemm_traits
&
t
,
const
ck_tile
::
GemmHostArgs
&
a
,
const
ck_tile
::
stream_config
&
s
)
float
gemm
(
const
gemm_traits
&
t
,
const
ck_tile
::
GemmHostArgs
&
a
,
const
ck_tile
::
stream_config
&
s
)
{
{
...
@@ -27,7 +29,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
...
@@ -27,7 +29,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
{
{
// clang-format off
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return
gemm_
<
gemm_traits_
<
FP16
,
FP16
,
FP32
,
FP16
,
Row
,
Row
,
Row
,
128
,
32
,
64
,
4
,
1
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>>
(
a
,
s
);
return
gemm_
<
gemm_traits_
<
FP16
,
FP16
,
FP32
,
FP16
,
Row
,
Row
,
Row
,
128
,
128
,
32
,
2
,
2
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>>
(
a
,
s
);
// clang-format on
// clang-format on
}
}
}
}
...
@@ -44,7 +46,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
...
@@ -44,7 +46,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
{
{
// clang-format off
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return
gemm_
<
gemm_traits_
<
FP16
,
FP16
,
FP32
,
FP16
,
Row
,
Col
,
Row
,
128
,
32
,
64
,
4
,
1
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>>
(
a
,
s
);
return
gemm_
<
gemm_traits_
<
FP16
,
FP16
,
FP32
,
FP16
,
Row
,
Col
,
Row
,
128
,
128
,
32
,
2
,
2
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>>
(
a
,
s
);
// clang-format on
// clang-format on
}
}
}
}
...
@@ -102,7 +104,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
...
@@ -102,7 +104,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
{
{
// clang-format off
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return
gemm_
<
gemm_traits_
<
BF16
,
BF16
,
FP32
,
BF16
,
Row
,
Row
,
Row
,
128
,
32
,
64
,
4
,
1
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>>
(
a
,
s
);
return
gemm_
<
gemm_traits_
<
BF16
,
BF16
,
FP32
,
BF16
,
Row
,
Row
,
Row
,
128
,
128
,
32
,
2
,
2
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>>
(
a
,
s
);
// clang-format on
// clang-format on
}
}
}
}
...
@@ -119,7 +121,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
...
@@ -119,7 +121,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
{
{
// clang-format off
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return
gemm_
<
gemm_traits_
<
BF16
,
BF16
,
FP32
,
BF16
,
Row
,
Col
,
Row
,
128
,
32
,
64
,
4
,
1
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>>
(
a
,
s
);
return
gemm_
<
gemm_traits_
<
BF16
,
BF16
,
FP32
,
BF16
,
Row
,
Col
,
Row
,
128
,
128
,
32
,
2
,
2
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>>
(
a
,
s
);
// clang-format on
// clang-format on
}
}
}
}
...
@@ -162,6 +164,81 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
...
@@ -162,6 +164,81 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
throw
std
::
runtime_error
(
"Wrong! ColumnMajor layout not supported for C Matrix!
\n
"
);
throw
std
::
runtime_error
(
"Wrong! ColumnMajor layout not supported for C Matrix!
\n
"
);
}
}
}
}
else
if
(
t
.
data_type
.
compare
(
"fp8"
)
==
0
)
{
if
(
t
.
is_a_rowmajor
&&
t
.
is_b_rowmajor
&&
t
.
is_c_rowmajor
)
{
if
(
a
.
M
>
512
)
{
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return
gemm_
<
gemm_traits_
<
FP8
,
FP8
,
FP32
,
FP8
,
Row
,
Row
,
Row
,
256
,
256
,
64
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>>
(
a
,
s
);
// clang-format on
}
else
{
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return
gemm_
<
gemm_traits_
<
FP8
,
FP8
,
FP32
,
FP8
,
Row
,
Row
,
Row
,
128
,
128
,
64
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>>
(
a
,
s
);
// clang-format on
}
}
else
if
(
t
.
is_a_rowmajor
&&
!
t
.
is_b_rowmajor
&&
t
.
is_c_rowmajor
)
{
if
(
a
.
M
>
512
)
{
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return
gemm_
<
gemm_traits_
<
FP8
,
FP8
,
FP32
,
FP8
,
Row
,
Col
,
Row
,
256
,
256
,
64
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>>
(
a
,
s
);
// clang-format on
}
else
{
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return
gemm_
<
gemm_traits_
<
FP8
,
FP8
,
FP32
,
FP8
,
Row
,
Col
,
Row
,
128
,
128
,
64
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>>
(
a
,
s
);
// clang-format on
}
}
else
if
(
!
t
.
is_a_rowmajor
&&
t
.
is_b_rowmajor
&&
t
.
is_c_rowmajor
)
{
if
(
a
.
M
>
512
)
{
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return
gemm_
<
gemm_traits_
<
FP8
,
FP8
,
FP32
,
FP8
,
Col
,
Row
,
Row
,
256
,
256
,
64
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>>
(
a
,
s
);
// clang-format on
}
else
{
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return
gemm_
<
gemm_traits_
<
FP8
,
FP8
,
FP32
,
FP8
,
Col
,
Row
,
Row
,
128
,
128
,
64
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>>
(
a
,
s
);
// clang-format on
}
}
else
if
(
!
t
.
is_a_rowmajor
&&
!
t
.
is_b_rowmajor
&&
t
.
is_c_rowmajor
)
{
if
(
a
.
M
>
512
)
{
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return
gemm_
<
gemm_traits_
<
FP8
,
FP8
,
FP32
,
FP8
,
Col
,
Col
,
Row
,
256
,
256
,
64
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>>
(
a
,
s
);
// clang-format on
}
else
{
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return
gemm_
<
gemm_traits_
<
FP8
,
FP8
,
FP32
,
FP8
,
Col
,
Col
,
Row
,
128
,
128
,
64
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>>
(
a
,
s
);
// clang-format on
}
}
else
{
throw
std
::
runtime_error
(
"Wrong! ColumnMajor layout not supported for C Matrix!
\n
"
);
}
}
else
else
{
{
throw
std
::
runtime_error
(
"Wrong! DataTypes not supported!
\n
"
);
throw
std
::
runtime_error
(
"Wrong! DataTypes not supported!
\n
"
);
...
...
example/ck_tile/03_gemm/instances/gemm_universal_comp_fp8_fp8_fp8_km_kn_mn.cpp
0 → 100644
View file @
da8e50dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
// clang-format off
template
float
gemm_
<
gemm_traits_
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
ck_tile
::
fp8_t
,
Col
,
Row
,
Row
,
256
,
256
,
64
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
// clang-format on
example/ck_tile/03_gemm/instances/gemm_universal_comp_fp8_fp8_fp8_km_nk_mn.cpp
0 → 100644
View file @
da8e50dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
// clang-format off
template
float
gemm_
<
gemm_traits_
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
ck_tile
::
fp8_t
,
Col
,
Col
,
Row
,
256
,
256
,
64
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
// clang-format on
example/ck_tile/03_gemm/instances/gemm_universal_comp_fp8_fp8_fp8_mk_kn_mn.cpp
0 → 100644
View file @
da8e50dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
// clang-format off
template
float
gemm_
<
gemm_traits_
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
ck_tile
::
fp8_t
,
Row
,
Row
,
Row
,
256
,
256
,
64
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
// clang-format on
example/ck_tile/03_gemm/instances/gemm_universal_comp_fp8_fp8_fp8_mk_nk_mn.cpp
0 → 100644
View file @
da8e50dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
// clang-format off
template
float
gemm_
<
gemm_traits_
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
ck_tile
::
fp8_t
,
Row
,
Col
,
Row
,
256
,
256
,
64
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
// clang-format on
example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp
View file @
da8e50dd
...
@@ -6,5 +6,5 @@
...
@@ -6,5 +6,5 @@
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
// clang-format off
// clang-format off
template
float
gemm_
<
gemm_traits_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
ck_tile
::
bf16_t
,
Row
,
Row
,
Row
,
128
,
32
,
64
,
4
,
1
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
template
float
gemm_
<
gemm_traits_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
ck_tile
::
bf16_t
,
Row
,
Row
,
Row
,
128
,
128
,
32
,
2
,
2
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
// clang-format on
// clang-format on
example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp
View file @
da8e50dd
...
@@ -7,5 +7,5 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor;
...
@@ -7,5 +7,5 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
// clang-format off
// clang-format off
template
float
gemm_
<
gemm_traits_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
ck_tile
::
bf16_t
,
Row
,
Col
,
Row
,
128
,
32
,
64
,
4
,
1
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
template
float
gemm_
<
gemm_traits_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
ck_tile
::
bf16_t
,
Row
,
Col
,
Row
,
128
,
128
,
32
,
2
,
2
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
// clang-format on
// clang-format on
example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_kn_mn.cpp
View file @
da8e50dd
...
@@ -6,5 +6,5 @@
...
@@ -6,5 +6,5 @@
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
// clang-format off
// clang-format off
template
float
gemm_
<
gemm_traits_
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
ck_tile
::
half_t
,
Row
,
Row
,
Row
,
128
,
32
,
64
,
4
,
1
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
template
float
gemm_
<
gemm_traits_
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
ck_tile
::
half_t
,
Row
,
Row
,
Row
,
128
,
128
,
32
,
2
,
2
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
// clang-format on
// clang-format on
example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_nk_mn.cpp
View file @
da8e50dd
...
@@ -7,5 +7,5 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor;
...
@@ -7,5 +7,5 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
// clang-format off
// clang-format off
template
float
gemm_
<
gemm_traits_
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
ck_tile
::
half_t
,
Row
,
Col
,
Row
,
128
,
32
,
64
,
4
,
1
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
template
float
gemm_
<
gemm_traits_
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
ck_tile
::
half_t
,
Row
,
Col
,
Row
,
128
,
128
,
32
,
2
,
2
,
1
,
32
,
32
,
8
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
// clang-format on
// clang-format on
example/ck_tile/03_gemm/instances/gemm_universal_mem_fp8_fp8_fp8_km_kn_mn.cpp
0 → 100644
View file @
da8e50dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
// clang-format off
template
float
gemm_
<
gemm_traits_
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
ck_tile
::
fp8_t
,
Col
,
Row
,
Row
,
128
,
128
,
64
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
// clang-format on
example/ck_tile/03_gemm/instances/gemm_universal_mem_fp8_fp8_fp8_km_nk_mn.cpp
0 → 100644
View file @
da8e50dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
// clang-format off
template
float
gemm_
<
gemm_traits_
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
ck_tile
::
fp8_t
,
Col
,
Col
,
Row
,
128
,
128
,
64
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
// clang-format on
example/ck_tile/03_gemm/instances/gemm_universal_mem_fp8_fp8_fp8_mk_kn_mn.cpp
0 → 100644
View file @
da8e50dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
// clang-format off
template
float
gemm_
<
gemm_traits_
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
ck_tile
::
fp8_t
,
Row
,
Row
,
Row
,
128
,
128
,
64
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
// clang-format on
example/ck_tile/03_gemm/instances/gemm_universal_mem_fp8_fp8_fp8_mk_nk_mn.cpp
0 → 100644
View file @
da8e50dd
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
// clang-format off
template
float
gemm_
<
gemm_traits_
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
ck_tile
::
fp8_t
,
Row
,
Col
,
Row
,
128
,
128
,
64
,
2
,
2
,
1
,
32
,
32
,
16
,
false
,
false
,
false
>
>
(
const
A
&
,
const
S
&
);
// clang-format on
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment