Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b3054fea
Commit
b3054fea
authored
Jan 22, 2025
by
Adam Osewski
Browse files
Merge branch 'develop' into aosewski/ck_tile_gemm_policy
parents
7cbc1492
052a7265
Changes
54
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
260 additions
and
233 deletions
+260
-233
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n3072_instance.cpp
...thquant/instances/moe_smoothquant_fp16_n3072_instance.cpp
+8
-4
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n4096_instance.cpp
...thquant/instances/moe_smoothquant_fp16_n4096_instance.cpp
+8
-4
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n4096_tp_instance.cpp
...uant/instances/moe_smoothquant_fp16_n4096_tp_instance.cpp
+8
-4
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n512_instance.cpp
...othquant/instances/moe_smoothquant_fp16_n512_instance.cpp
+9
-4
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n64_n128_instance.cpp
...uant/instances/moe_smoothquant_fp16_n64_n128_instance.cpp
+7
-3
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n768_instance.cpp
...othquant/instances/moe_smoothquant_fp16_n768_instance.cpp
+7
-3
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fwd_api.cpp
.../14_moe_smoothquant/instances/moe_smoothquant_fwd_api.cpp
+55
-45
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_instance_common.hpp
...smoothquant/instances/moe_smoothquant_instance_common.hpp
+11
-8
example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp
example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp
+22
-11
example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp
example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp
+10
-20
example/ck_tile/14_moe_smoothquant/script/smoke_test.sh
example/ck_tile/14_moe_smoothquant/script/smoke_test.sh
+27
-25
example/ck_tile/16_batched_gemm/batched_gemm.cpp
example/ck_tile/16_batched_gemm/batched_gemm.cpp
+4
-4
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
+1
-2
example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
+4
-4
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
+10
-10
example/ck_tile/17_grouped_gemm/utils.hpp
example/ck_tile/17_grouped_gemm/utils.hpp
+0
-38
include/ck_tile/core.hpp
include/ck_tile/core.hpp
+0
-1
include/ck_tile/core/arch/arch.hpp
include/ck_tile/core/arch/arch.hpp
+51
-6
include/ck_tile/core/utility/amd_address_space.hpp
include/ck_tile/core/utility/amd_address_space.hpp
+0
-37
include/ck_tile/core/utility/type_traits.hpp
include/ck_tile/core/utility/type_traits.hpp
+18
-0
No files found.
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n3072_instance.cpp
View file @
b3054fea
...
@@ -6,9 +6,13 @@
...
@@ -6,9 +6,13 @@
// clang-format off
// clang-format off
// rm rn tm tn vn pd 2p
// rm rn tm tn vn pd 2p
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
1
,
128
,
8
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
3
,
1
,
128
,
8
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
1
,
256
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
3
,
1
,
256
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
6
,
1
,
256
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
6
,
1
,
256
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
1
,
1024
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
3
,
1
,
1024
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
1
,
128
,
8
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
1
,
256
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
6
,
1
,
256
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
1
,
1024
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
// clang-format on
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n4096_instance.cpp
View file @
b3054fea
...
@@ -6,9 +6,13 @@
...
@@ -6,9 +6,13 @@
// clang-format off
// clang-format off
// rm rn tm tn vn pd 2p
// rm rn tm tn vn pd 2p
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
256
,
8
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
2
,
1
,
256
,
8
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
256
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
1
,
256
,
8
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
// clang-format on
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n4096_tp_instance.cpp
View file @
b3054fea
...
@@ -6,9 +6,13 @@
...
@@ -6,9 +6,13 @@
// clang-format off
// clang-format off
// rm rn tm tn vn pd 2p
// rm rn tm tn vn pd 2p
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
256
,
8
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
2
,
1
,
256
,
8
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
256
,
4
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
256
,
4
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
1
,
256
,
8
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
256
,
4
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
1
,
1024
,
2
,
true
,
true
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
1
,
1024
,
1
,
true
,
true
>
>
(
const
S
&
,
A
);
// clang-format on
// clang-format on
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n512_instance.cpp
View file @
b3054fea
...
@@ -6,8 +6,13 @@
...
@@ -6,8 +6,13 @@
// clang-format off
// clang-format off
// rm rn tm tn vn pd 2p
// rm rn tm tn vn pd 2p
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
8
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
1
,
4
,
64
,
8
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
4
,
64
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
2
,
4
,
64
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
4
,
4
,
64
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
4
,
4
,
64
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
8
,
4
,
64
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
8
,
4
,
64
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
1
,
4
,
64
,
8
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
4
,
64
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
4
,
4
,
64
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
8
,
4
,
64
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
// clang-format on
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n64_n128_instance.cpp
View file @
b3054fea
...
@@ -6,7 +6,11 @@
...
@@ -6,7 +6,11 @@
// clang-format off
// clang-format off
// rm rn tm tn vn pd 2p
// rm rn tm tn vn pd 2p
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
1
,
4
,
64
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
1
,
4
,
64
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
1
,
4
,
64
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
2
,
4
,
64
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
2
,
4
,
64
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
1
,
4
,
64
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
1
,
4
,
64
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
2
,
4
,
64
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
// clang-format on
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fp16_n768_instance.cpp
View file @
b3054fea
...
@@ -6,7 +6,11 @@
...
@@ -6,7 +6,11 @@
// clang-format off
// clang-format off
// rm rn tm tn vn pd 2p
// rm rn tm tn vn pd 2p
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
3
,
4
,
64
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
3
,
4
,
64
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
6
,
4
,
64
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
6
,
4
,
64
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
1
,
12
,
4
,
64
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
,
1
,
12
,
4
,
64
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
3
,
4
,
64
,
4
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
6
,
4
,
64
,
2
,
true
,
false
>
>
(
const
S
&
,
A
);
template
float
moe_smoothquant_
<
trait_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
,
1
,
12
,
4
,
64
,
1
,
true
,
false
>
>
(
const
S
&
,
A
);
// clang-format on
// clang-format on
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_fwd_api.cpp
View file @
b3054fea
...
@@ -4,7 +4,8 @@
...
@@ -4,7 +4,8 @@
#include <ck_tile/core.hpp>
#include <ck_tile/core.hpp>
#include "moe_smoothquant.hpp"
#include "moe_smoothquant.hpp"
template
<
typename
DataType_
,
template
<
typename
InType
,
typename
OutType
,
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_N_
,
// each thread repeat along N
ck_tile
::
index_t
Repeat_N_
,
// each thread repeat along N
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
...
@@ -12,7 +13,8 @@ template <typename DataType_,
...
@@ -12,7 +13,8 @@ template <typename DataType_,
ck_tile
::
index_t
Vector_N_
,
// vector size along N
ck_tile
::
index_t
Vector_N_
,
// vector size along N
bool
kPadN_
,
bool
kPadN_
,
bool
kTwoPass_
>
bool
kTwoPass_
>
using
trait_
=
moe_smoothquant_traits_
<
DataType_
,
using
trait_
=
moe_smoothquant_traits_
<
InType
,
OutType
,
Repeat_M_
,
Repeat_M_
,
Repeat_N_
,
Repeat_N_
,
ThreadPerBlock_M_
,
ThreadPerBlock_M_
,
...
@@ -21,108 +23,108 @@ using trait_ = moe_smoothquant_traits_<DataType_,
...
@@ -21,108 +23,108 @@ using trait_ = moe_smoothquant_traits_<DataType_,
kPadN_
,
kPadN_
,
kTwoPass_
>
;
kTwoPass_
>
;
template
<
typename
data
_type
>
template
<
typename
in_type
,
typename
out
_type
>
float
moe_smoothquant_dispatch
(
moe_smoothquant_traits
/*t*/
,
float
moe_smoothquant_dispatch
(
moe_smoothquant_traits
/*t*/
,
moe_smoothquant_args
a
,
moe_smoothquant_args
a
,
const
ck_tile
::
stream_config
&
s
)
const
ck_tile
::
stream_config
&
s
)
{
{
float
r
=
-
1
;
float
r
=
-
1
;
// clang-format off
// clang-format off
// rm rn tm tn vn pd 2p
//
rm rn tm tn vn pd 2p
if
(
a
.
hidden_size
<=
64
)
{
if
(
a
.
hidden_size
<=
64
)
{
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
1
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
1
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
);
}
}
else
if
(
a
.
hidden_size
<=
128
)
{
else
if
(
a
.
hidden_size
<=
128
)
{
if
(
a
.
hidden_size
%
2
==
0
)
if
(
a
.
hidden_size
%
2
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
1
,
4
,
64
,
2
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
1
,
4
,
64
,
2
,
true
,
false
>>
(
s
,
a
);
else
else
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
2
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
2
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
);
}
}
else
if
(
a
.
hidden_size
<=
256
)
{
else
if
(
a
.
hidden_size
<=
256
)
{
if
(
a
.
hidden_size
%
4
==
0
)
if
(
a
.
hidden_size
%
4
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
1
,
4
,
64
,
4
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
1
,
4
,
64
,
4
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
hidden_size
%
2
==
0
)
else
if
(
a
.
hidden_size
%
2
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
2
,
4
,
64
,
2
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
2
,
4
,
64
,
2
,
true
,
false
>>
(
s
,
a
);
else
else
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
4
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
4
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
);
}
}
else
if
(
a
.
hidden_size
<=
512
)
{
else
if
(
a
.
hidden_size
<=
512
)
{
if
(
a
.
hidden_size
%
8
==
0
)
if
(
a
.
hidden_size
%
8
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
1
,
4
,
64
,
8
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
1
,
4
,
64
,
8
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
hidden_size
%
4
==
0
)
else
if
(
a
.
hidden_size
%
4
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
2
,
4
,
64
,
4
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
2
,
4
,
64
,
4
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
hidden_size
%
2
==
0
)
else
if
(
a
.
hidden_size
%
2
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
4
,
4
,
64
,
2
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
4
,
4
,
64
,
2
,
true
,
false
>>
(
s
,
a
);
else
else
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
8
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
8
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
);
}
}
else
if
(
a
.
hidden_size
<=
768
)
{
else
if
(
a
.
hidden_size
<=
768
)
{
if
(
a
.
hidden_size
%
4
==
0
)
if
(
a
.
hidden_size
%
4
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
3
,
4
,
64
,
4
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
3
,
4
,
64
,
4
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
hidden_size
%
2
==
0
)
else
if
(
a
.
hidden_size
%
2
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
6
,
4
,
64
,
2
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
6
,
4
,
64
,
2
,
true
,
false
>>
(
s
,
a
);
else
else
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
12
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
12
,
4
,
64
,
1
,
true
,
false
>>
(
s
,
a
);
}
}
else
if
(
a
.
hidden_size
<=
1024
)
{
else
if
(
a
.
hidden_size
<=
1024
)
{
if
(
a
.
hidden_size
%
8
==
0
)
if
(
a
.
hidden_size
%
8
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
1
,
2
,
128
,
8
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
1
,
2
,
128
,
8
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
hidden_size
%
4
==
0
)
else
if
(
a
.
hidden_size
%
4
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
2
,
2
,
128
,
4
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
2
,
2
,
128
,
4
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
hidden_size
%
2
==
0
)
else
if
(
a
.
hidden_size
%
2
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
4
,
2
,
128
,
2
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
4
,
2
,
128
,
2
,
true
,
false
>>
(
s
,
a
);
else
else
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
4
,
1
,
256
,
1
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
4
,
1
,
256
,
1
,
true
,
false
>>
(
s
,
a
);
}
}
else
if
(
a
.
hidden_size
<=
1536
)
{
else
if
(
a
.
hidden_size
<=
1536
)
{
if
(
a
.
hidden_size
%
8
==
0
)
if
(
a
.
hidden_size
%
8
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
3
,
4
,
64
,
8
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
3
,
4
,
64
,
8
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
hidden_size
%
4
==
0
)
else
if
(
a
.
hidden_size
%
4
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
3
,
2
,
128
,
4
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
3
,
2
,
128
,
4
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
hidden_size
%
2
==
0
)
else
if
(
a
.
hidden_size
%
2
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
3
,
1
,
256
,
2
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
3
,
1
,
256
,
2
,
true
,
false
>>
(
s
,
a
);
else
else
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
6
,
1
,
256
,
1
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
6
,
1
,
256
,
1
,
true
,
false
>>
(
s
,
a
);
}
}
else
if
(
a
.
hidden_size
<=
2048
)
{
else
if
(
a
.
hidden_size
<=
2048
)
{
if
(
a
.
hidden_size
%
8
==
0
)
if
(
a
.
hidden_size
%
8
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
1
,
1
,
256
,
8
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
1
,
1
,
256
,
8
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
hidden_size
%
4
==
0
)
else
if
(
a
.
hidden_size
%
4
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
2
,
1
,
256
,
4
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
2
,
1
,
256
,
4
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
hidden_size
%
2
==
0
)
else
if
(
a
.
hidden_size
%
2
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
4
,
1
,
256
,
2
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
4
,
1
,
256
,
2
,
true
,
false
>>
(
s
,
a
);
else
else
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
8
,
1
,
256
,
1
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
8
,
1
,
256
,
1
,
true
,
false
>>
(
s
,
a
);
}
}
else
if
(
a
.
hidden_size
<=
3072
)
{
else
if
(
a
.
hidden_size
<=
3072
)
{
if
(
a
.
hidden_size
%
8
==
0
)
if
(
a
.
hidden_size
%
8
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
3
,
1
,
128
,
8
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
3
,
1
,
128
,
8
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
hidden_size
%
4
==
0
)
else
if
(
a
.
hidden_size
%
4
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
3
,
1
,
256
,
4
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
3
,
1
,
256
,
4
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
hidden_size
%
2
==
0
)
else
if
(
a
.
hidden_size
%
2
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
6
,
1
,
256
,
2
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
6
,
1
,
256
,
2
,
true
,
false
>>
(
s
,
a
);
else
else
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
3
,
1
,
1024
,
1
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
3
,
1
,
1024
,
1
,
true
,
false
>>
(
s
,
a
);
}
}
else
if
(
a
.
hidden_size
<=
4096
)
{
else
if
(
a
.
hidden_size
<=
4096
)
{
if
(
a
.
hidden_size
%
8
==
0
)
if
(
a
.
hidden_size
%
8
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
2
,
1
,
256
,
8
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
2
,
1
,
256
,
8
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
hidden_size
%
4
==
0
)
else
if
(
a
.
hidden_size
%
4
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
4
,
1
,
256
,
4
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
4
,
1
,
256
,
4
,
true
,
false
>>
(
s
,
a
);
else
if
(
a
.
hidden_size
%
2
==
0
)
else
if
(
a
.
hidden_size
%
2
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
2
,
1
,
1024
,
2
,
true
,
false
>>
(
s
,
a
);
else
else
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
4
,
1
,
1024
,
1
,
true
,
false
>>
(
s
,
a
);
}
}
else
if
(
a
.
hidden_size
>
4096
)
{
else
if
(
a
.
hidden_size
>
4096
)
{
if
(
a
.
hidden_size
%
8
==
0
)
if
(
a
.
hidden_size
%
8
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
2
,
1
,
256
,
8
,
true
,
true
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
2
,
1
,
256
,
8
,
true
,
true
>>
(
s
,
a
);
else
if
(
a
.
hidden_size
%
4
==
0
)
else
if
(
a
.
hidden_size
%
4
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
4
,
1
,
256
,
4
,
true
,
true
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
4
,
1
,
256
,
4
,
true
,
true
>>
(
s
,
a
);
else
if
(
a
.
hidden_size
%
2
==
0
)
else
if
(
a
.
hidden_size
%
2
==
0
)
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
2
,
1
,
1024
,
2
,
true
,
true
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
2
,
1
,
1024
,
2
,
true
,
true
>>
(
s
,
a
);
else
else
r
=
moe_smoothquant_
<
trait_
<
data
_type
,
1
,
4
,
1
,
1024
,
1
,
true
,
true
>>
(
s
,
a
);
r
=
moe_smoothquant_
<
trait_
<
in_type
,
out
_type
,
1
,
4
,
1
,
1024
,
1
,
true
,
true
>>
(
s
,
a
);
}
}
return
r
;
return
r
;
// clang-format on
// clang-format on
...
@@ -132,13 +134,21 @@ float moe_smoothquant(moe_smoothquant_traits t,
...
@@ -132,13 +134,21 @@ float moe_smoothquant(moe_smoothquant_traits t,
moe_smoothquant_args
a
,
moe_smoothquant_args
a
,
const
ck_tile
::
stream_config
&
s
)
const
ck_tile
::
stream_config
&
s
)
{
{
if
(
t
.
data
_type
.
compare
(
"fp16"
)
==
0
)
if
(
t
.
in
_type
.
compare
(
"fp16"
)
==
0
&&
t
.
out_type
==
"int8"
)
{
{
return
moe_smoothquant_dispatch
<
ck_tile
::
fp16_t
>
(
t
,
a
,
s
);
return
moe_smoothquant_dispatch
<
ck_tile
::
fp16_t
,
ck_tile
::
int8_t
>
(
t
,
a
,
s
);
}
}
else
if
(
t
.
data
_type
.
compare
(
"
b
f16"
)
==
0
)
else
if
(
t
.
in
_type
.
compare
(
"f
p
16"
)
==
0
&&
t
.
out_type
==
"fp8"
)
{
{
return
moe_smoothquant_dispatch
<
ck_tile
::
bf16_t
>
(
t
,
a
,
s
);
return
moe_smoothquant_dispatch
<
ck_tile
::
fp16_t
,
ck_tile
::
fp8_t
>
(
t
,
a
,
s
);
}
else
if
(
t
.
in_type
.
compare
(
"bf16"
)
==
0
&&
t
.
out_type
==
"int8"
)
{
return
moe_smoothquant_dispatch
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
>
(
t
,
a
,
s
);
}
else
if
(
t
.
in_type
.
compare
(
"bf16"
)
==
0
&&
t
.
out_type
==
"fp8"
)
{
return
moe_smoothquant_dispatch
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
>
(
t
,
a
,
s
);
}
}
else
else
throw
std
::
runtime_error
(
"Without supported instances!"
);
throw
std
::
runtime_error
(
"Without supported instances!"
);
...
...
example/ck_tile/14_moe_smoothquant/instances/moe_smoothquant_instance_common.hpp
View file @
b3054fea
...
@@ -11,7 +11,8 @@
...
@@ -11,7 +11,8 @@
using
S
=
ck_tile
::
stream_config
;
using
S
=
ck_tile
::
stream_config
;
using
A
=
moe_smoothquant_args
;
using
A
=
moe_smoothquant_args
;
template
<
typename
DataType_
,
template
<
typename
InputType_
,
typename
OutputType_
,
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_N_
,
// each thread repeat along N
ck_tile
::
index_t
Repeat_N_
,
// each thread repeat along N
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
...
@@ -19,7 +20,8 @@ template <typename DataType_,
...
@@ -19,7 +20,8 @@ template <typename DataType_,
ck_tile
::
index_t
Vector_N_
,
// vector size along N
ck_tile
::
index_t
Vector_N_
,
// vector size along N
bool
kPadN_
,
bool
kPadN_
,
bool
kTwoPass_
>
bool
kTwoPass_
>
using
trait_
=
moe_smoothquant_traits_
<
DataType_
,
using
trait_
=
moe_smoothquant_traits_
<
InputType_
,
OutputType_
,
Repeat_M_
,
Repeat_M_
,
Repeat_N_
,
Repeat_N_
,
ThreadPerBlock_M_
,
ThreadPerBlock_M_
,
...
@@ -31,14 +33,15 @@ using trait_ = moe_smoothquant_traits_<DataType_,
...
@@ -31,14 +33,15 @@ using trait_ = moe_smoothquant_traits_<DataType_,
template
<
typename
Traits_
>
template
<
typename
Traits_
>
float
moe_smoothquant_
(
const
S
&
s
,
A
a
)
float
moe_smoothquant_
(
const
S
&
s
,
A
a
)
{
{
using
DataType
=
typename
Traits_
::
DataType
;
using
InputType
=
typename
Traits_
::
InputType
;
using
OutputType
=
typename
Traits_
::
OutputType
;
using
PipelineProblem
=
ck_tile
::
SmoothquantPipelineProblem
<
using
PipelineProblem
=
ck_tile
::
SmoothquantPipelineProblem
<
typename
MoeSmoothquantTypeConfig
<
Data
Type
>::
XDataType
,
typename
MoeSmoothquantTypeConfig
<
InputType
,
Output
Type
>::
XDataType
,
typename
MoeSmoothquantTypeConfig
<
Data
Type
>::
SmoothScaleDataType
,
typename
MoeSmoothquantTypeConfig
<
InputType
,
Output
Type
>::
SmoothScaleDataType
,
typename
MoeSmoothquantTypeConfig
<
Data
Type
>::
ComputeDataType
,
typename
MoeSmoothquantTypeConfig
<
InputType
,
Output
Type
>::
ComputeDataType
,
typename
MoeSmoothquantTypeConfig
<
Data
Type
>::
YScaleDataType
,
typename
MoeSmoothquantTypeConfig
<
InputType
,
Output
Type
>::
YScaleDataType
,
typename
MoeSmoothquantTypeConfig
<
Data
Type
>::
QYDataType
,
typename
MoeSmoothquantTypeConfig
<
InputType
,
Output
Type
>::
QYDataType
,
typename
Traits_
::
Shape
,
typename
Traits_
::
Shape
,
Traits_
::
kPadN
,
Traits_
::
kPadN
,
Traits_
::
kTwoPass
>
;
Traits_
::
kTwoPass
>
;
...
...
example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp
View file @
b3054fea
...
@@ -63,7 +63,8 @@ auto create_args(int argc, char* argv[])
...
@@ -63,7 +63,8 @@ auto create_args(int argc, char* argv[])
.
insert
(
"stride"
,
"-1"
,
"stride per row, if -1 then equal to hidden_size"
)
.
insert
(
"stride"
,
"-1"
,
"stride per row, if -1 then equal to hidden_size"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"v"
,
"1"
,
"cpu validation or not"
)
.
insert
(
"kname"
,
"1"
,
"print kernel name or not"
)
.
insert
(
"kname"
,
"1"
,
"print kernel name or not"
)
.
insert
(
"prec"
,
"fp16"
,
"precision"
)
.
insert
(
"prec_i"
,
"fp16"
,
"input precision, fp16/bf16"
)
.
insert
(
"prec_o"
,
"int8"
,
"precision, int8/fp8"
)
.
insert
(
"warmup"
,
"5"
,
"cold iter"
)
.
insert
(
"warmup"
,
"5"
,
"cold iter"
)
.
insert
(
"repeat"
,
"20"
,
"hot iter"
);
.
insert
(
"repeat"
,
"20"
,
"hot iter"
);
...
@@ -71,7 +72,7 @@ auto create_args(int argc, char* argv[])
...
@@ -71,7 +72,7 @@ auto create_args(int argc, char* argv[])
return
std
::
make_tuple
(
result
,
arg_parser
);
return
std
::
make_tuple
(
result
,
arg_parser
);
}
}
template
<
typename
Data
Type
>
template
<
typename
InputType
,
typename
Output
Type
>
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
{
{
ck_tile
::
index_t
tokens
=
arg_parser
.
get_int
(
"t"
);
ck_tile
::
index_t
tokens
=
arg_parser
.
get_int
(
"t"
);
...
@@ -81,7 +82,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -81,7 +82,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
stride
=
hidden_size
;
stride
=
hidden_size
;
ck_tile
::
index_t
experts
=
arg_parser
.
get_int
(
"e"
);
ck_tile
::
index_t
experts
=
arg_parser
.
get_int
(
"e"
);
ck_tile
::
index_t
topk
=
arg_parser
.
get_int
(
"k"
);
ck_tile
::
index_t
topk
=
arg_parser
.
get_int
(
"k"
);
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
std
::
string
prec_o
=
arg_parser
.
get_str
(
"prec_o"
);
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
...
@@ -89,7 +91,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -89,7 +91,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
assert
(
stride
>=
hidden_size
);
assert
(
stride
>=
hidden_size
);
using
TypeConfig
=
MoeSmoothquantTypeConfig
<
Data
Type
>
;
using
TypeConfig
=
MoeSmoothquantTypeConfig
<
InputType
,
Output
Type
>
;
using
XDataType
=
typename
TypeConfig
::
XDataType
;
using
XDataType
=
typename
TypeConfig
::
XDataType
;
using
SmoothScaleDataType
=
typename
TypeConfig
::
SmoothScaleDataType
;
using
SmoothScaleDataType
=
typename
TypeConfig
::
SmoothScaleDataType
;
...
@@ -122,11 +124,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -122,11 +124,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
smscale_buf
.
ToDevice
(
smscale_host
.
data
());
smscale_buf
.
ToDevice
(
smscale_host
.
data
());
topk_ids_buf
.
ToDevice
(
topk_ids_host
.
data
());
topk_ids_buf
.
ToDevice
(
topk_ids_host
.
data
());
std
::
cout
<<
"["
<<
data_type
<<
"]"
std
::
cout
<<
"["
<<
prec_i
<<
"-"
<<
prec_o
<<
"]"
<<
" tokens:"
<<
tokens
<<
", hidden_size:"
<<
hidden_size
<<
", stride:"
<<
stride
<<
" tokens:"
<<
tokens
<<
", hidden_size:"
<<
hidden_size
<<
", stride:"
<<
stride
<<
", experts:"
<<
experts
<<
", topk:"
<<
topk
<<
std
::
flush
;
<<
", experts:"
<<
experts
<<
", topk:"
<<
topk
<<
std
::
flush
;
moe_smoothquant_traits
traits
{
data_type
};
moe_smoothquant_traits
traits
{
prec_i
,
prec_o
};
moe_smoothquant_args
args
{
x_buf
.
GetDeviceBuffer
(),
moe_smoothquant_args
args
{
x_buf
.
GetDeviceBuffer
(),
smscale_buf
.
GetDeviceBuffer
(),
smscale_buf
.
GetDeviceBuffer
(),
...
@@ -251,14 +253,23 @@ int main(int argc, char* argv[])
...
@@ -251,14 +253,23 @@ int main(int argc, char* argv[])
if
(
!
result
)
if
(
!
result
)
return
-
1
;
return
-
1
;
const
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
const
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
if
(
data_type
==
"fp16"
)
const
std
::
string
prec_o
=
arg_parser
.
get_str
(
"prec_o"
);
if
(
prec_i
==
"fp16"
&&
prec_o
==
"int8"
)
{
return
run
<
ck_tile
::
half_t
,
ck_tile
::
int8_t
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
prec_i
==
"fp16"
&&
prec_o
==
"fp8"
)
{
return
run
<
ck_tile
::
half_t
,
ck_tile
::
fp8_t
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
prec_i
==
"bf16"
&&
prec_o
==
"int8"
)
{
{
return
run
<
ck_tile
::
half
_t
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
int8
_t
>
(
arg_parser
)
?
0
:
-
2
;
}
}
else
if
(
data_type
==
"bf16
"
)
else
if
(
prec_i
==
"bf16"
&&
prec_o
==
"fp8
"
)
{
{
return
run
<
ck_tile
::
bf16_t
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
>
(
arg_parser
)
?
0
:
-
2
;
}
}
return
-
3
;
return
-
3
;
...
...
example/ck_tile/14_moe_smoothquant/moe_smoothquant.hpp
View file @
b3054fea
...
@@ -8,26 +8,13 @@
...
@@ -8,26 +8,13 @@
#include "ck_tile/ops/smoothquant.hpp"
#include "ck_tile/ops/smoothquant.hpp"
#include <string>
#include <string>
template
<
typename
DataType
>
template
<
typename
InputType
,
typename
OutputType
>
struct
MoeSmoothquantTypeConfig
;
struct
MoeSmoothquantTypeConfig
template
<
>
struct
MoeSmoothquantTypeConfig
<
ck_tile
::
half_t
>
{
using
XDataType
=
ck_tile
::
half_t
;
using
SmoothScaleDataType
=
float
;
using
YScaleDataType
=
float
;
using
QYDataType
=
ck_tile
::
int8_t
;
using
ComputeDataType
=
float
;
};
template
<
>
struct
MoeSmoothquantTypeConfig
<
ck_tile
::
bf16_t
>
{
{
using
XDataType
=
ck_tile
::
bf16_t
;
using
XDataType
=
InputType
;
using
SmoothScaleDataType
=
float
;
using
SmoothScaleDataType
=
float
;
using
YScaleDataType
=
float
;
using
YScaleDataType
=
float
;
using
QYDataType
=
ck_tile
::
int8_t
;
using
QYDataType
=
OutputType
;
using
ComputeDataType
=
float
;
using
ComputeDataType
=
float
;
};
};
...
@@ -37,7 +24,8 @@ struct moe_smoothquant_args : public ck_tile::MoeSmoothquantHostArgs
...
@@ -37,7 +24,8 @@ struct moe_smoothquant_args : public ck_tile::MoeSmoothquantHostArgs
};
};
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template
<
typename
DataType_
,
template
<
typename
InputType_
,
typename
OutputType_
,
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_M_
,
// each thread repeat along M
ck_tile
::
index_t
Repeat_N_
,
// each thread repeat along N
ck_tile
::
index_t
Repeat_N_
,
// each thread repeat along N
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
ck_tile
::
index_t
ThreadPerBlock_M_
,
// num threads along M
...
@@ -47,7 +35,8 @@ template <typename DataType_,
...
@@ -47,7 +35,8 @@ template <typename DataType_,
bool
kTwoPass_
>
bool
kTwoPass_
>
struct
moe_smoothquant_traits_
struct
moe_smoothquant_traits_
{
{
using
DataType
=
ck_tile
::
remove_cvref_t
<
DataType_
>
;
using
InputType
=
ck_tile
::
remove_cvref_t
<
InputType_
>
;
using
OutputType
=
ck_tile
::
remove_cvref_t
<
OutputType_
>
;
static
constexpr
bool
is_warp_per_row
=
ThreadPerBlock_N_
<=
warpSize
;
static
constexpr
bool
is_warp_per_row
=
ThreadPerBlock_N_
<=
warpSize
;
static_assert
((
ThreadPerBlock_M_
*
ThreadPerBlock_N_
)
%
warpSize
==
0
);
static_assert
((
ThreadPerBlock_M_
*
ThreadPerBlock_N_
)
%
warpSize
==
0
);
...
@@ -108,7 +97,8 @@ float moe_smoothquant_(const ck_tile::stream_config& s, moe_smoothquant_args a);
...
@@ -108,7 +97,8 @@ float moe_smoothquant_(const ck_tile::stream_config& s, moe_smoothquant_args a);
// This is the public API, will be generated by script
// This is the public API, will be generated by script
struct
moe_smoothquant_traits
struct
moe_smoothquant_traits
{
{
std
::
string
data_type
;
std
::
string
in_type
;
// input type
std
::
string
out_type
;
// output type
};
};
float
moe_smoothquant
(
moe_smoothquant_traits
,
moe_smoothquant_args
,
const
ck_tile
::
stream_config
&
);
float
moe_smoothquant
(
moe_smoothquant_traits
,
moe_smoothquant_args
,
const
ck_tile
::
stream_config
&
);
example/ck_tile/14_moe_smoothquant/script/smoke_test.sh
View file @
b3054fea
...
@@ -2,29 +2,31 @@
...
@@ -2,29 +2,31 @@
EXE
=
build/bin/tile_example_moe_smoothquant
EXE
=
build/bin/tile_example_moe_smoothquant
for
pr_i
in
"fp16"
"bf16"
;
do
for
pr_i
in
"fp16"
"bf16"
;
do
$EXE
-prec
=
$pr_i
-t
=
99
-h
=
13
for
pr_o
in
"int8"
"fp8"
;
do
$EXE
-prec
=
$pr_i
-t
=
17
-h
=
16
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
99
-h
=
13
$EXE
-prec
=
$pr_i
-t
=
1
-h
=
100
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
17
-h
=
16
$EXE
-prec
=
$pr_i
-t
=
4
-h
=
128
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
1
-h
=
100
$EXE
-prec
=
$pr_i
-t
=
80
-h
=
127
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
4
-h
=
128
$EXE
-prec
=
$pr_i
-t
=
22
-h
=
255
-stride
=
256
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
80
-h
=
127
$EXE
-prec
=
$pr_i
-t
=
7
-h
=
599
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
22
-h
=
255
-stride
=
256
$EXE
-prec
=
$pr_i
-t
=
19
-h
=
512
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
7
-h
=
599
$EXE
-prec
=
$pr_i
-t
=
33
-h
=
313
-stride
=
1000
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
19
-h
=
512
$EXE
-prec
=
$pr_i
-t
=
11
-h
=
510
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
33
-h
=
313
-stride
=
1000
$EXE
-prec
=
$pr_i
-t
=
171
-h
=
676
-stride
=
818
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
11
-h
=
510
$EXE
-prec
=
$pr_i
-t
=
91
-h
=
636
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
171
-h
=
676
-stride
=
818
$EXE
-prec
=
$pr_i
-t
=
12
-h
=
768
-stride
=
800
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
91
-h
=
636
$EXE
-prec
=
$pr_i
-t
=
100
-h
=
766
-stride
=
812
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
12
-h
=
768
-stride
=
800
$EXE
-prec
=
$pr_i
-t
=
31
-h
=
1024
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
100
-h
=
766
-stride
=
812
$EXE
-prec
=
$pr_i
-t
=
64
-h
=
1000
-stride
=
1004
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
31
-h
=
1024
$EXE
-prec
=
$pr_i
-t
=
8
-h
=
1501
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
64
-h
=
1000
-stride
=
1004
$EXE
-prec
=
$pr_i
-t
=
3
-h
=
1826
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
8
-h
=
1501
$EXE
-prec
=
$pr_i
-t
=
5
-h
=
2040
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
3
-h
=
1826
$EXE
-prec
=
$pr_i
-t
=
7
-h
=
2734
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
5
-h
=
2040
$EXE
-prec
=
$pr_i
-t
=
1
-h
=
3182
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
7
-h
=
2734
$EXE
-prec
=
$pr_i
-t
=
9
-h
=
4096
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
1
-h
=
3182
$EXE
-prec
=
$pr_i
-t
=
3
-h
=
8192
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
9
-h
=
4096
$EXE
-prec
=
$pr_i
-t
=
1
-h
=
10547
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
3
-h
=
8192
$EXE
-prec
=
$pr_i
-t
=
3
-h
=
17134
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
1
-h
=
10547
$EXE
-prec_i
=
$pr_i
-prec_o
=
$pr_o
-t
=
3
-h
=
17134
done
done
done
example/ck_tile/16_batched_gemm/batched_gemm.cpp
View file @
b3054fea
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
...
@@ -51,7 +51,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
...
@@ -51,7 +51,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTilePartitioner
<
CodegenGemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTile
2D
Partitioner
<
CodegenGemmShape
>
;
using
GemmEpilogue
=
std
::
conditional_t
<
using
GemmEpilogue
=
std
::
conditional_t
<
CShuffleEpilogue
,
CShuffleEpilogue
,
...
@@ -63,8 +63,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
...
@@ -63,8 +63,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
kOutputRank
,
kOutputRank
,
1
,
1
,
0
,
0
,
TilePartitioner
::
k
M
,
TilePartitioner
::
M
PerBlock
,
TilePartitioner
::
k
N
>>
,
TilePartitioner
::
N
PerBlock
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>>
;
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>>
;
...
...
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
View file @
b3054fea
...
@@ -15,7 +15,6 @@
...
@@ -15,7 +15,6 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host.hpp"
#include "grouped_gemm.hpp"
#include "grouped_gemm.hpp"
#include "utils.hpp"
namespace
{
namespace
{
...
@@ -102,7 +101,7 @@ using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
...
@@ -102,7 +101,7 @@ using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
GemmEpilogue
<
CLayout
>>
;
GemmEpilogue
<
CLayout
>>
;
};
// namespace
};
// namespace
std
::
size_t
G
et
W
orkspace
S
ize
(
const
std
::
vector
<
grouped_gemm_kargs
>&
gemm_descs
)
std
::
size_t
g
et
_w
orkspace
_s
ize
(
const
std
::
vector
<
grouped_gemm_kargs
>&
gemm_descs
)
{
{
return
::
Kernel
<
std
::
nullptr_t
,
std
::
nullptr_t
,
std
::
nullptr_t
>::
GetWorkSpaceSize
(
gemm_descs
);
return
::
Kernel
<
std
::
nullptr_t
,
std
::
nullptr_t
,
std
::
nullptr_t
>::
GetWorkSpaceSize
(
gemm_descs
);
}
}
...
...
example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
View file @
b3054fea
...
@@ -52,8 +52,8 @@ auto create_args(int argc, char* argv[])
...
@@ -52,8 +52,8 @@ auto create_args(int argc, char* argv[])
return
std
::
make_tuple
(
result
,
arg_parser
);
return
std
::
make_tuple
(
result
,
arg_parser
);
}
}
std
::
size_t
G
et
W
orkspace
S
ize
(
const
std
::
vector
<
grouped_gemm_kargs
>&
gemm_descs
);
std
::
size_t
g
et
_w
orkspace
_s
ize
(
const
std
::
vector
<
grouped_gemm_kargs
>&
gemm_descs
);
float
grouped_gemm
_calc
(
const
std
::
vector
<
grouped_gemm_kargs
>&
gemm_descs
,
float
grouped_gemm
(
const
std
::
vector
<
grouped_gemm_kargs
>&
gemm_descs
,
const
ck_tile
::
stream_config
&
s
,
const
ck_tile
::
stream_config
&
s
,
void
*
p_workspace_
);
void
*
p_workspace_
);
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
View file @
b3054fea
...
@@ -31,7 +31,7 @@ float invoke_gemm(int n_warmup,
...
@@ -31,7 +31,7 @@ float invoke_gemm(int n_warmup,
{
{
ck_tile
::
DeviceMem
gemm_workspace
;
ck_tile
::
DeviceMem
gemm_workspace
;
gemm_workspace
.
Realloc
(
G
et
W
orkspace
S
ize
(
args
));
gemm_workspace
.
Realloc
(
g
et
_w
orkspace
_s
ize
(
args
));
float
ave_time
=
grouped_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
float
ave_time
=
grouped_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
args
,
args
,
...
@@ -128,16 +128,16 @@ int run_grouped_gemm_example_with_layouts(int argc,
...
@@ -128,16 +128,16 @@ int run_grouped_gemm_example_with_layouts(int argc,
const
ck_tile
::
index_t
N
=
Ns
[
i
];
const
ck_tile
::
index_t
N
=
Ns
[
i
];
const
ck_tile
::
index_t
K
=
Ks
[
i
];
const
ck_tile
::
index_t
K
=
Ks
[
i
];
stride_As
[
i
]
=
f_
get_default_stride
(
M
,
N
,
stride_As
[
i
],
a_layout
);
stride_As
[
i
]
=
ck_tile
::
get_default_stride
(
M
,
N
,
stride_As
[
i
],
a_layout
);
stride_Bs
[
i
]
=
f_
get_default_stride
(
K
,
N
,
stride_Bs
[
i
],
b_layout
);
stride_Bs
[
i
]
=
ck_tile
::
get_default_stride
(
K
,
N
,
stride_Bs
[
i
],
b_layout
);
stride_Cs
[
i
]
=
f_
get_default_stride
(
M
,
N
,
stride_Cs
[
i
],
CLayout
{});
stride_Cs
[
i
]
=
ck_tile
::
get_default_stride
(
M
,
N
,
stride_Cs
[
i
],
CLayout
{});
a_m_k_tensors
.
push_back
(
a_m_k_tensors
.
push_back
(
ck_tile
::
HostTensor
<
ADataType
>
(
ck_tile
::
HostTensor
<
ADataType
>
(
f_
host_tensor_descriptor
(
M
,
K
,
stride_As
[
i
],
a_layout
)));
ck_tile
::
host_tensor_descriptor
(
M
,
K
,
stride_As
[
i
],
a_layout
)));
b_k_n_tensors
.
push_back
(
b_k_n_tensors
.
push_back
(
ck_tile
::
HostTensor
<
BDataType
>
(
ck_tile
::
HostTensor
<
BDataType
>
(
f_
host_tensor_descriptor
(
K
,
N
,
stride_Bs
[
i
],
b_layout
)));
ck_tile
::
host_tensor_descriptor
(
K
,
N
,
stride_Bs
[
i
],
b_layout
)));
c_m_n_tensors
.
push_back
(
ck_tile
::
HostTensor
<
CDataType
>
(
c_m_n_tensors
.
push_back
(
ck_tile
::
HostTensor
<
CDataType
>
(
f_
host_tensor_descriptor
(
M
,
N
,
stride_Cs
[
i
],
CLayout
{})));
ck_tile
::
host_tensor_descriptor
(
M
,
N
,
stride_Cs
[
i
],
CLayout
{})));
std
::
cout
<<
"gemm["
<<
i
<<
"]"
std
::
cout
<<
"gemm["
<<
i
<<
"]"
<<
" a_m_k: "
<<
a_m_k_tensors
[
i
]
.
mDesc
<<
" b_k_n: "
<<
b_k_n_tensors
[
i
]
.
mDesc
<<
" a_m_k: "
<<
a_m_k_tensors
[
i
]
.
mDesc
<<
" b_k_n: "
<<
b_k_n_tensors
[
i
]
.
mDesc
...
@@ -178,7 +178,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
...
@@ -178,7 +178,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
{
{
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_host_ref
(
ck_tile
::
HostTensor
<
CDataType
>
c_m_n_host_ref
(
f_
host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
stride_Cs
[
i
],
CLayout
{}));
ck_tile
::
host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
stride_Cs
[
i
],
CLayout
{}));
c_m_n_host_ref
.
SetZero
();
c_m_n_host_ref
.
SetZero
();
ck_tile
::
reference_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
ck_tile
::
reference_gemm
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
>
(
a_m_k_tensors
[
i
],
b_k_n_tensors
[
i
],
c_m_n_host_ref
);
a_m_k_tensors
[
i
],
b_k_n_tensors
[
i
],
c_m_n_host_ref
);
...
...
example/ck_tile/17_grouped_gemm/utils.hpp
deleted
100644 → 0
View file @
7cbc1492
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template
<
typename
TLayout
>
constexpr
auto
f_host_tensor_descriptor
(
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
TLayout
layout
)
{
using
namespace
ck_tile
::
literals
;
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
ck_tile
::
HostTensorDescriptor
({
row
,
col
},
{
stride
,
1
_uz
});
}
else
{
return
ck_tile
::
HostTensorDescriptor
({
row
,
col
},
{
1
_uz
,
stride
});
}
}
template
<
typename
TLayout
>
constexpr
auto
f_get_default_stride
(
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
TLayout
layout
)
{
if
(
stride
==
0
)
{
if
constexpr
(
std
::
is_same_v
<
decltype
(
layout
),
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
col
;
}
else
{
return
row
;
}
}
else
return
stride
;
}
include/ck_tile/core.hpp
View file @
b3054fea
...
@@ -56,7 +56,6 @@
...
@@ -56,7 +56,6 @@
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/transpose_tile.hpp"
#include "ck_tile/core/tensor/transpose_tile.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/amd_address_space.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
...
...
include/ck_tile/core/arch/arch.hpp
View file @
b3054fea
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -12,18 +12,37 @@
...
@@ -12,18 +12,37 @@
namespace
ck_tile
{
namespace
ck_tile
{
enum
struct
address_space_enum
template
<
typename
,
bool
>
struct
safe_underlying_type
;
template
<
typename
T
>
struct
safe_underlying_type
<
T
,
true
>
{
using
type
=
std
::
underlying_type_t
<
T
>
;
};
template
<
typename
T
>
struct
safe_underlying_type
<
T
,
false
>
{
using
type
=
void
;
};
template
<
typename
T
>
using
safe_underlying_type_t
=
typename
safe_underlying_type
<
T
,
std
::
is_enum
<
T
>::
value
>::
type
;
enum
struct
address_space_enum
:
std
::
uint16_t
{
{
generic
,
generic
=
0
,
global
,
global
,
lds
,
lds
,
sgpr
,
sgpr
,
vgpr
,
constant
,
vgpr
};
};
enum
struct
memory_operation_enum
enum
struct
memory_operation_enum
:
std
::
uint16_t
{
{
set
,
set
=
0
,
atomic_add
,
atomic_add
,
atomic_max
,
atomic_max
,
add
add
...
@@ -109,4 +128,30 @@ CK_TILE_DEVICE void s_nop(index_t cnt = 0)
...
@@ -109,4 +128,30 @@ CK_TILE_DEVICE void s_nop(index_t cnt = 0)
#endif
#endif
}
}
#define CK_CONSTANT_ADDRESS_SPACE \
__attribute__((address_space( \
static_cast<safe_underlying_type_t<address_space_enum>>(address_space_enum::constant))))
template
<
typename
T
>
__device__
T
*
cast_pointer_to_generic_address_space
(
T
CK_CONSTANT_ADDRESS_SPACE
*
p
)
{
// cast a pointer in "Constant" address space (4) to "Generic" address space (0)
// only c-style pointer cast seems be able to be compiled
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return
(
T
*
)(
p
);
// NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
template
<
typename
T
>
__host__
__device__
T
CK_CONSTANT_ADDRESS_SPACE
*
cast_pointer_to_constant_address_space
(
T
*
p
)
{
// cast a pointer in "Generic" address space (0) to "Constant" address space (4)
// only c-style pointer cast seems be able to be compiled;
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return
(
T
CK_CONSTANT_ADDRESS_SPACE
*
)
p
;
// NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/core/utility/amd_address_space.hpp
deleted
100644 → 0
View file @
7cbc1492
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space
namespace
ck_tile
{
#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4)))
template
<
typename
T
>
__device__
T
*
cast_pointer_to_generic_address_space
(
T
CK_CONSTANT_ADDRESS_SPACE
*
p
)
{
// cast a pointer in "Constant" address space (4) to "Generic" address space (0)
// only c-style pointer cast seems be able to be compiled
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return
(
T
*
)
p
;
// NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
template
<
typename
T
>
__host__
__device__
T
CK_CONSTANT_ADDRESS_SPACE
*
cast_pointer_to_constant_address_space
(
T
*
p
)
{
// cast a pointer in "Generic" address space (0) to "Constant" address space (4)
// only c-style pointer cast seems be able to be compiled
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return
(
T
CK_CONSTANT_ADDRESS_SPACE
*
)
p
;
// NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
}
// namespace ck_tile
include/ck_tile/core/utility/type_traits.hpp
View file @
b3054fea
...
@@ -109,4 +109,22 @@ CK_TILE_HOST_DEVICE PY c_style_pointer_cast(PX p_x)
...
@@ -109,4 +109,22 @@ CK_TILE_HOST_DEVICE PY c_style_pointer_cast(PX p_x)
#pragma clang diagnostic pop
#pragma clang diagnostic pop
}
}
template
<
typename
CompareTo
,
typename
...
Rest
>
struct
is_any_of
:
std
::
false_type
{
};
template
<
typename
CompareTo
,
typename
FirstType
>
struct
is_any_of
<
CompareTo
,
FirstType
>
:
std
::
is_same
<
CompareTo
,
FirstType
>
{
};
template
<
typename
CompareTo
,
typename
FirstType
,
typename
...
Rest
>
struct
is_any_of
<
CompareTo
,
FirstType
,
Rest
...
>
:
std
::
integral_constant
<
bool
,
std
::
is_same
<
CompareTo
,
FirstType
>::
value
||
is_any_of
<
CompareTo
,
Rest
...
>::
value
>
{
};
}
// namespace ck_tile
}
// namespace ck_tile
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment