Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ec959387
Unverified
Commit
ec959387
authored
Feb 13, 2025
by
rocking
Committed by
GitHub
Feb 13, 2025
Browse files
Merge branch 'develop' into ck_tile/fmha_receipt_aiter
parents
c1e2fef7
0e5e29c4
Changes
393
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
573 additions
and
137 deletions
+573
-137
example/ck_tile/03_gemm/script/smoke_test_basic.sh
example/ck_tile/03_gemm/script/smoke_test_basic.sh
+18
-17
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
+18
-17
example/ck_tile/03_gemm/universal_gemm.cpp
example/ck_tile/03_gemm/universal_gemm.cpp
+198
-12
example/ck_tile/10_rmsnorm2d/CMakeLists.txt
example/ck_tile/10_rmsnorm2d/CMakeLists.txt
+1
-1
example/ck_tile/10_rmsnorm2d/generate.py
example/ck_tile/10_rmsnorm2d/generate.py
+5
-3
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
+19
-3
example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
+2
-2
example/ck_tile/13_moe_sorting/moe_sorting.cpp
example/ck_tile/13_moe_sorting/moe_sorting.cpp
+57
-6
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
+82
-0
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp
+2
-1
example/ck_tile/13_moe_sorting/script/smoke_test.sh
example/ck_tile/13_moe_sorting/script/smoke_test.sh
+8
-0
example/ck_tile/15_fused_moe/README.md
example/ck_tile/15_fused_moe/README.md
+1
-1
example/ck_tile/15_fused_moe/fused_moe.hpp
example/ck_tile/15_fused_moe/fused_moe.hpp
+11
-8
example/ck_tile/15_fused_moe/fused_moesorting.hpp
example/ck_tile/15_fused_moe/fused_moesorting.hpp
+2
-1
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
+2
-1
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
...e/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
+82
-0
example/ck_tile/15_fused_moe/main.cpp
example/ck_tile/15_fused_moe/main.cpp
+35
-25
example/ck_tile/16_batched_gemm/batched_gemm.cpp
example/ck_tile/16_batched_gemm/batched_gemm.cpp
+23
-32
example/ck_tile/16_batched_gemm/batched_gemm.hpp
example/ck_tile/16_batched_gemm/batched_gemm.hpp
+1
-1
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
+6
-6
No files found.
example/ck_tile/03_gemm/script/smoke_test_basic.sh
View file @
ec959387
...
@@ -7,22 +7,20 @@ export CK_REPEAT=1
...
@@ -7,22 +7,20 @@ export CK_REPEAT=1
COMMON_ARGS
=
'-v=2 -warmup=0 -repeat=1'
COMMON_ARGS
=
'-v=2 -warmup=0 -repeat=1'
run_fp16_tests
()
{
run_tests
()
{
for
batch
in
1 2
;
do
for
m
in
128 1024
;
do
for
m
in
128 1024
;
do
for
n
in
128 2048
;
do
for
n
in
128 2048
;
do
for
k
in
64 128
;
do
for
k
in
32 64
;
do
$EXE
-m
=
$m
-n
=
$n
-k
=
$k
-stride_a
=
0
-stride_b
=
0
-stride_c
=
0
-prec
=
$1
$COMMON_ARGS
$EXE
-b
=
$batch
-m
=
$m
-n
=
$n
-k
=
$k
-stride_a
=
0
-stride_b
=
0
-stride_c
=
0
-e
=
1e-5
-prec
=
fp16
$COMMON_ARGS
if
[
$?
-eq
0
]
;
then
if
[
$?
-eq
0
]
;
then
echo
"Success: Test with m=
$m
, n=
$n
, k=
$k
executed successfully."
echo
"Success: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
executed successfully."
else
else
echo
"Error: Test with m=
$m
, n=
$n
, k=
$k
failed to execute properly."
echo
"Error: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# Optionally, exit or break if you need to halt further execution
# exit 1
# exit 1
fi
fi
done
done
done
done
done
done
done
...
@@ -30,6 +28,9 @@ run_fp16_tests() {
...
@@ -30,6 +28,9 @@ run_fp16_tests() {
set
-x
set
-x
run_fp16_tests
run_tests
"fp16"
run_tests
"bf16"
run_tests
"fp8"
run_tests
"bf8"
set
+x
set
+x
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
View file @
ec959387
...
@@ -7,22 +7,20 @@ export CK_REPEAT=1
...
@@ -7,22 +7,20 @@ export CK_REPEAT=1
COMMON_ARGS
=
'-v=2 -warmup=0 -repeat=1'
COMMON_ARGS
=
'-v=2 -warmup=0 -repeat=1'
run_fp16_tests
()
{
run_tests
()
{
for
batch
in
1 2
;
do
for
m
in
512 1024
;
do
for
m
in
128 1024
;
do
for
n
in
512 2048
;
do
for
n
in
128 2048
;
do
for
k
in
512 1024
;
do
for
k
in
32 64
;
do
$EXE
-m
=
$m
-n
=
$n
-k
=
$k
-stride_a
=
0
-stride_b
=
0
-stride_c
=
0
-prec
=
$1
$COMMON_ARGS
$EXE
-b
=
$batch
-m
=
$m
-n
=
$n
-k
=
$k
-stride_a
=
0
-stride_b
=
0
-stride_c
=
0
-e
=
1e-5
-prec
=
fp16
$COMMON_ARGS
if
[
$?
-eq
0
]
;
then
if
[
$?
-eq
0
]
;
then
echo
"Success: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
executed successfully."
echo
"Success: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
executed successfully."
else
else
echo
"Error: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
failed to execute properly."
echo
"Error: Test with batch=
$batch
, m=
$m
, n=
$n
, k=
$k
failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# Optionally, exit or break if you need to halt further execution
# exit 1
# exit 1
fi
fi
done
done
done
done
done
done
done
...
@@ -30,6 +28,9 @@ run_fp16_tests() {
...
@@ -30,6 +28,9 @@ run_fp16_tests() {
set
-x
set
-x
run_fp16_tests
run_tests
"fp16"
run_tests
"bf16"
run_tests
"fp8"
run_tests
"bf8"
set
+x
set
+x
example/ck_tile/03_gemm/universal_gemm.cpp
View file @
ec959387
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
...
@@ -12,7 +12,13 @@
...
@@ -12,7 +12,13 @@
#include "ck_tile/host.hpp"
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
#include "gemm_basic.hpp"
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
gemm_calc
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
gemm_calc
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
...
@@ -29,10 +35,28 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -29,10 +35,28 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
constexpr
bool
DoubleSmemBuffer
=
false
;
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
// Compute friendly for Intrawave scheduler
// Compute friendly for Intrawave scheduler
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
constexpr
ck_tile
::
index_t
K_Tile
=
64
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
N_Warp
=
2
;
constexpr
ck_tile
::
index_t
K_Warp
=
1
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
16
;
constexpr
bool
DoubleSmemBuffer
=
false
;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
constexpr
ck_tile
::
index_t
M_Tile
=
256
;
constexpr
ck_tile
::
index_t
N_Tile
=
256
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
constexpr
ck_tile
::
index_t
M_Warp
=
2
;
...
@@ -42,13 +66,19 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -42,13 +66,19 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
16
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
16
;
constexpr
bool
DoubleSmemBuffer
=
true
;
#endif
#endif
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
bool
TransposeC
=
false
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
ck_tile
::
index_t
TileParitionerGroupNum
=
8
;
constexpr
ck_tile
::
index_t
TileParitionerM01
=
4
;
// ===============================================
// ===============================================
...
@@ -56,13 +86,18 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -56,13 +86,18 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTile2DPartitioner
<
GemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmSpatiallyLocalTilePartitioner
<
GemmShape
,
TileParitionerGroupNum
,
TileParitionerM01
>
;
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
GemmUniversalTraits
=
ck_tile
::
TileGemmUniversalTraits
<
kPadM
,
kPadN
,
kPadK
,
DoubleSmemBuffer
,
ALayout
,
BLayout
,
CLayout
,
TransposeC
>
;
using
GemmPipelineProblem
=
using
GemmPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>
;
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>
;
...
@@ -85,14 +120,27 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -85,14 +120,27 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
BDataType
,
BDataType
,
AccDataType
,
AccDataType
,
GemmShape
,
GemmShape
,
Traits
,
GemmUniversal
Traits
,
scheduler
,
scheduler
,
has_hot_loop_v
,
has_hot_loop_v
,
tail_number_v
>
;
tail_number_v
>
;
using
GemmPipeline
=
GEMM_PIPELINE
<
UniversalGemmProblem
>
;
using
GemmPipeline
=
GEMM_PIPELINE
<
UniversalGemmProblem
>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
CLayout
,
GemmPipelineProblem
::
kBlockSize
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
,
M_Warp
,
N_Warp
,
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
,
UniversalGemmProblem
::
TransposeC
>>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
k_batch
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
k_batch
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
...
@@ -117,6 +165,21 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -117,6 +165,21 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
if
(
has_hot_loop
)
if
(
has_hot_loop
)
{
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
if
(
tail_num
==
ck_tile
::
TailNumber
::
Full
)
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Full
>
{});
}
else
{
std
::
ostringstream
err
;
err
<<
"For compute pipeline tail number should always be Full, but have
\"
"
<<
tail_num
<<
"
\"
which is not supported! PrefetchStages: "
<<
BaseGemmPipeline
::
PrefetchStages
<<
"
\n
File: "
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
throw
std
::
runtime_error
(
err
.
str
());
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Tail pipeline One to Seven
// Tail pipeline One to Seven
if
(
tail_num
==
ck_tile
::
TailNumber
::
One
)
if
(
tail_num
==
ck_tile
::
TailNumber
::
One
)
{
{
...
@@ -177,6 +240,18 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -177,6 +240,18 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Seven
>
{});
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Seven
>
{});
}
}
}
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
if
(
tail_num
==
ck_tile
::
TailNumber
::
Three
)
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Three
>
{});
}
else
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Two
>
{});
}
#endif
}
}
else
else
{
{
...
@@ -201,4 +276,115 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -201,4 +276,115 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
#include "run_gemm_example.inc"
#include "run_gemm_example.inc"
int
run_gemm_example
(
int
argc
,
char
*
argv
[])
{
auto
[
result
,
arg_parser
]
=
create_args
(
argc
,
argv
);
if
(
!
result
)
return
-
1
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
if
(
a_layout
==
"R"
&&
b_layout
==
"R"
)
{
if
(
data_type
==
"fp16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
half_t
>
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
"bf16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf16_t
>
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
"fp8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
fp8_t
>
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
"bf8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf8_t
>
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
}
else
{
throw
std
::
runtime_error
(
"Unsupported data_type!"
);
}
}
else
if
(
a_layout
==
"R"
&&
b_layout
==
"C"
)
{
if
(
data_type
==
"fp16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
half_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"bf16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf16_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"fp8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
fp8_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"bf8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf8_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
{
throw
std
::
runtime_error
(
"Unsupported data_type!"
);
}
}
else
if
(
a_layout
==
"C"
&&
b_layout
==
"C"
)
{
if
(
data_type
==
"fp16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
half_t
>
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"bf16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf16_t
>
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"fp8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
fp8_t
>
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"bf8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf8_t
>
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
}
else
{
throw
std
::
runtime_error
(
"Unsupported data_type!"
);
}
}
else
if
(
a_layout
==
"C"
&&
b_layout
==
"R"
)
{
if
(
data_type
==
"fp16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
half_t
>
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
"bf16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf16_t
>
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
"fp8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
fp8_t
>
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
"bf8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf8_t
>
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
}
else
{
throw
std
::
runtime_error
(
"Unsupported data_type!"
);
}
}
else
{
throw
std
::
runtime_error
(
"Unsupported data layout configuration for A,B and C tensors!"
);
}
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_gemm_example
(
argc
,
argv
);
}
example/ck_tile/10_rmsnorm2d/CMakeLists.txt
View file @
ec959387
...
@@ -33,7 +33,7 @@ target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS})
...
@@ -33,7 +33,7 @@ target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS})
set
(
TILE_RMSNORM2D_FWD_COMPILE_OPTIONS
)
set
(
TILE_RMSNORM2D_FWD_COMPILE_OPTIONS
)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list
(
APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal
)
list
(
APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal
--offload-compress
)
target_compile_options
(
${
TILE_RMSNORM2D_FWD
}
PRIVATE
${
TILE_RMSNORM2D_FWD_COMPILE_OPTIONS
}
)
target_compile_options
(
${
TILE_RMSNORM2D_FWD
}
PRIVATE
${
TILE_RMSNORM2D_FWD_COMPILE_OPTIONS
}
)
...
...
example/ck_tile/10_rmsnorm2d/generate.py
View file @
ec959387
...
@@ -37,7 +37,8 @@ FUSED_FUSED_SWEEP_STR_MAP = [
...
@@ -37,7 +37,8 @@ FUSED_FUSED_SWEEP_STR_MAP = [
DATA_TYPE_MAP
=
{
'fp32'
:
'float'
,
DATA_TYPE_MAP
=
{
'fp32'
:
'float'
,
'fp16'
:
'ck_tile::fp16_t'
,
'fp16'
:
'ck_tile::fp16_t'
,
'bf16'
:
'ck_tile::bf16_t'
,
'bf16'
:
'ck_tile::bf16_t'
,
'int8'
:
'ck_tile::int8_t'
}
'int8'
:
'ck_tile::int8_t'
,
'fp8'
:
'ck_tile::fp8_t'
}
def
BOOL_MAP
(
b_
)
->
str
:
def
BOOL_MAP
(
b_
)
->
str
:
if
b_
:
if
b_
:
...
@@ -477,12 +478,13 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
...
@@ -477,12 +478,13 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
h_traits
=
rmsnorm_fwd_codegen
.
h_traits
h_traits
=
rmsnorm_fwd_codegen
.
h_traits
h_instance
=
rmsnorm_fwd_codegen
.
h_instance
h_instance
=
rmsnorm_fwd_codegen
.
h_instance
dynamic_quant_out_dtype
=
[
'int8'
]
dynamic_quant_out_dtype
=
[
'int8'
,
'fp8'
]
# some predefined support range
# some predefined support range
# (prec_i,prec_o) for simplicity this string will be used as key for dict
# (prec_i,prec_o) for simplicity this string will be used as key for dict
scale_list
=
[(
'fp32,fp32'
)]
scale_list
=
[(
'fp32,fp32'
)]
dtype_list
=
[(
'fp16,fp16'
),
(
'bf16,bf16'
),
dtype_list
=
[(
'fp16,fp16'
),
(
'bf16,bf16'
),
(
'fp16,int8'
),
(
'bf16,int8'
)]
# NOTE: only fused-dynamic-quant use int8 out
(
'fp16,int8'
),
(
'bf16,int8'
),
(
'fp16,fp8'
),
(
'bf16,fp8'
)]
# NOTE: only fused-dynamic-quant use int8 out
#fused_add_list = [0, 1, 2]
#fused_add_list = [0, 1, 2]
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
fused_add_list
=
[
0
,
1
]
fused_add_list
=
[
0
,
1
]
...
...
example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
View file @
ec959387
...
@@ -105,9 +105,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -105,9 +105,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
prec_sy
=
"fp32"
;
prec_sy
=
"fp32"
;
}
}
if
((
fused_quant
==
1
||
fused_quant
==
2
)
&&
prec_o
!=
"int8"
)
if
((
fused_quant
==
1
||
fused_quant
==
2
)
&&
prec_o
!=
"int8"
&&
prec_o
!=
"fp8"
)
{
{
std
::
cout
<<
"if fused_quant is 1, only support
\"
-prec_o=int8
\"
case"
<<
std
::
endl
;
std
::
cout
<<
"if fused_quant is 1 or 2, only support
\"
-prec_o=int8
\"
or
\"
-prec_o=fp8
\"
cases."
<<
std
::
endl
;
return
false
;
return
false
;
}
}
...
@@ -248,7 +250,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -248,7 +250,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
absmax
=
a
>
absmax
?
a
:
absmax
;
absmax
=
a
>
absmax
?
a
:
absmax
;
}
}
// printf("cpu:absmax:%f\n", absmax);
// printf("cpu:absmax:%f\n", absmax);
ComputeDataType
y_scale
=
absmax
/
static_cast
<
ComputeDataType
>
(
127.0
);
constexpr
ComputeDataType
kMaxY
=
std
::
is_same
<
YDataType
,
ck_tile
::
fp8_t
>::
value
?
240.0
:
std
::
is_same
<
YDataType
,
ck_tile
::
int8_t
>::
value
?
127.0
:
0.0
;
ComputeDataType
y_scale
=
absmax
/
kMaxY
;
y_scale_host_ref
(
m_
)
=
ck_tile
::
type_convert
<
YScaleDataType
>
(
y_scale
);
y_scale_host_ref
(
m_
)
=
ck_tile
::
type_convert
<
YScaleDataType
>
(
y_scale
);
for
(
int
n_
=
0
;
n_
<
N_
;
n_
++
)
for
(
int
n_
=
0
;
n_
<
N_
;
n_
++
)
{
{
...
@@ -400,6 +406,16 @@ int main(int argc, char* argv[])
...
@@ -400,6 +406,16 @@ int main(int argc, char* argv[])
{
{
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
float
,
float
,
true
>
(
arg_parser
)
?
0
:
-
2
;
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
int8_t
,
float
,
float
,
true
>
(
arg_parser
)
?
0
:
-
2
;
}
}
else
if
(
prec_i
==
"fp16"
&&
prec_o
==
"fp8"
&&
prec_sm
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_rms
)
{
return
run
<
ck_tile
::
half_t
,
ck_tile
::
fp8_t
,
float
,
float
,
false
>
(
arg_parser
)
?
0
:
-
2
;
}
else
if
(
prec_i
==
"bf16"
&&
prec_o
==
"fp8"
&&
prec_sm
==
"fp32"
&&
prec_sy
==
"fp32"
&&
!
save_rms
)
{
return
run
<
ck_tile
::
bf16_t
,
ck_tile
::
fp8_t
,
float
,
float
,
false
>
(
arg_parser
)
?
0
:
-
2
;
}
return
-
3
;
return
-
3
;
}
}
example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
View file @
ec959387
#!/bin/sh
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_rmsnorm2d_fwd
-type
f |
head
-n
1
)
"
EXE
=
"
$(
find
.
-name
tile_rmsnorm2d_fwd
-type
f |
head
-n
1
)
"
for
fquant
in
""
"-fquant=1 -prec_o=int8"
"-fquant=2 -prec_o=int8"
;
do
for
fquant
in
""
"-fquant=1 -prec_o=int8"
"-fquant=2 -prec_o=int8"
"-fquant=1 -prec_o=fp8"
"-fquant=2 -prec_o=fp8"
;
do
for
pr_i
in
"fp16"
"bf16"
;
do
for
pr_i
in
"fp16"
"bf16"
;
do
for
fadd
in
"0"
"1"
;
do
for
fadd
in
"0"
"1"
;
do
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
99
-n
=
13
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
99
-n
=
13
...
@@ -27,7 +27,7 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
...
@@ -27,7 +27,7 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
1
-n
=
3182
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
1
-n
=
3182
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
9
-n
=
4096
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
9
-n
=
4096
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
3
-n
=
8192
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
3
-n
=
8192
#
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
$EXE
-prec_i
=
$pr_i
-fadd
=
$fadd
$fquant
-m
=
1
-n
=
10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
done
done
...
...
example/ck_tile/13_moe_sorting/moe_sorting.cpp
View file @
ec959387
...
@@ -26,6 +26,10 @@ auto create_args(int argc, char* argv[])
...
@@ -26,6 +26,10 @@ auto create_args(int argc, char* argv[])
.
insert
(
"k"
,
"4"
,
"topk"
)
.
insert
(
"k"
,
"4"
,
"topk"
)
.
insert
(
"unit"
,
"32"
,
"unit_size"
)
.
insert
(
"unit"
,
"32"
,
"unit_size"
)
.
insert
(
"moe_buf_size"
,
"0"
,
"moe_buf_size"
)
.
insert
(
"moe_buf_size"
,
"0"
,
"moe_buf_size"
)
.
insert
(
"local_eid"
,
"-1"
,
"a list of experts enabled as local expert. e.g.
\"
0,1,4,5
\"\n
"
"please make sure eid is in ascending order!"
)
.
insert
(
"seed"
,
"-1"
,
"seed to be used, -1 means random every time"
)
.
insert
(
"seed"
,
"-1"
,
"seed to be used, -1 means random every time"
)
.
insert
(
"kname"
,
"0"
,
"when set to 1 it will print kernel name"
)
.
insert
(
"kname"
,
"0"
,
"when set to 1 it will print kernel name"
)
.
insert
(
"warmup"
,
"5"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"warmup"
,
"5"
,
"number of iterations before benchmark the kernel"
)
...
@@ -74,6 +78,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -74,6 +78,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int
kname
=
args
.
get_int
(
"kname"
);
int
kname
=
args
.
get_int
(
"kname"
);
int
warmup
=
args
.
get_int
(
"warmup"
);
int
warmup
=
args
.
get_int
(
"warmup"
);
int
repeat
=
args
.
get_int
(
"repeat"
);
int
repeat
=
args
.
get_int
(
"repeat"
);
int
max_output_ids
=
int
max_output_ids
=
ck_tile
::
integer_least_multiple
(
topk
*
tokens
+
num_experts
*
unit_size
-
topk
,
unit_size
);
ck_tile
::
integer_least_multiple
(
topk
*
tokens
+
num_experts
*
unit_size
-
topk
,
unit_size
);
...
@@ -90,6 +95,30 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -90,6 +95,30 @@ bool test_moe_sorting(ck_tile::ArgParser args)
return
false
;
return
false
;
}
}
bool
local_expert_masking
=
args
.
get_str
(
"local_eid"
)
!=
"-1"
;
auto
local_expert_masking_host
=
[
&
]()
{
if
(
local_expert_masking
)
{
auto
local_eid
=
args
.
get_int_vec
(
"local_eid"
);
// std::vector<int> v_ {num_experts, 0};
ck_tile
::
HostTensor
<
IndexType
>
v_
{{
num_experts
}};
v_
.
SetZero
();
for
(
auto
eid
:
local_eid
)
{
if
(
eid
>=
num_experts
)
{
throw
std
::
runtime_error
(
"local_eid larger than number of expert, please check"
);
}
v_
.
mData
[
eid
]
=
1
;
}
return
v_
;
}
else
// return std::vector<int>{};
return
ck_tile
::
HostTensor
<
IndexType
>
{{
1
}};
}();
// tokens already considered batch size
// tokens already considered batch size
ck_tile
::
HostTensor
<
IndexType
>
topk_ids_host
({
tokens
,
topk
},
{
topk
,
1
});
ck_tile
::
HostTensor
<
IndexType
>
topk_ids_host
({
tokens
,
topk
},
{
topk
,
1
});
ck_tile
::
HostTensor
<
WeightType
>
weights_host
({
tokens
,
topk
},
{
topk
,
1
});
ck_tile
::
HostTensor
<
WeightType
>
weights_host
({
tokens
,
topk
},
{
topk
,
1
});
...
@@ -111,6 +140,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -111,6 +140,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
sorted_expert_ids_host
.
get_element_space_size_in_bytes
());
sorted_expert_ids_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
sorted_id_cnt_dev
(
sorted_id_cnt_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
sorted_id_cnt_dev
(
sorted_id_cnt_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
moe_buf_dev
(
moe_buf_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
moe_buf_dev
(
moe_buf_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
local_expert_masking_dev
(
local_expert_masking_host
.
get_element_space_size_in_bytes
());
topk_ids_dev
.
ToDevice
(
topk_ids_host
.
data
());
topk_ids_dev
.
ToDevice
(
topk_ids_host
.
data
());
weights_dev
.
ToDevice
(
weights_host
.
data
());
weights_dev
.
ToDevice
(
weights_host
.
data
());
...
@@ -118,11 +149,15 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -118,11 +149,15 @@ bool test_moe_sorting(ck_tile::ArgParser args)
{
{
moe_buf_dev
.
ToDevice
(
moe_buf_host
.
data
());
moe_buf_dev
.
ToDevice
(
moe_buf_host
.
data
());
}
}
if
(
local_expert_masking
)
local_expert_masking_dev
.
ToDevice
(
local_expert_masking_host
.
data
());
moe_sorting_trait
trait
{
index_prec
,
weight_prec
};
moe_sorting_trait
trait
{
index_prec
,
weight_prec
,
local_expert_masking
};
moe_sorting_args
karg
{
topk_ids_dev
.
GetDeviceBuffer
(),
moe_sorting_args
karg
{
topk_ids_dev
.
GetDeviceBuffer
(),
weights_dev
.
GetDeviceBuffer
(),
weights_dev
.
GetDeviceBuffer
(),
local_expert_masking
?
local_expert_masking_dev
.
GetDeviceBuffer
()
:
nullptr
,
sorted_ids_dev
.
GetDeviceBuffer
(),
sorted_ids_dev
.
GetDeviceBuffer
(),
sorted_weights_dev
.
GetDeviceBuffer
(),
sorted_weights_dev
.
GetDeviceBuffer
(),
sorted_expert_ids_dev
.
GetDeviceBuffer
(),
sorted_expert_ids_dev
.
GetDeviceBuffer
(),
...
@@ -140,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -140,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args)
warmup
,
warmup
,
repeat
};
repeat
};
auto
ms
=
moe_sorting
(
trait
,
karg
,
sc
);
auto
ms
=
moe_sorting
(
trait
,
karg
,
sc
);
printf
(
"[%s|%s]tokens:%d, num_experts:%d, topk:%d,
ms:%f ,
"
,
printf
(
"[%s|%s]tokens:%d, num_experts:%d, topk:%d, "
,
index_prec
.
c_str
(),
index_prec
.
c_str
(),
weight_prec
.
c_str
(),
weight_prec
.
c_str
(),
tokens
,
tokens
,
num_experts
,
num_experts
,
topk
,
topk
);
ms
);
if
(
local_expert_masking
)
{
printf
(
"local_eid:%s, "
,
args
.
get_str
(
"local_eid"
).
c_str
());
}
if
(
ms
<
0
)
if
(
ms
<
0
)
printf
(
"not supported
\n
"
);
printf
(
"not supported
\n
"
);
else
printf
(
"ms:%f, "
,
ms
);
fflush
(
stdout
);
fflush
(
stdout
);
if
(
ms
<
0
)
if
(
ms
<
0
)
{
{
...
@@ -174,12 +216,14 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -174,12 +216,14 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int32_t
ref_total_tokens_post_pad
=
0
;
int32_t
ref_total_tokens_post_pad
=
0
;
ck_tile
::
reference_moe_sorting
<
WeightType
,
IndexType
>
(
topk_ids_host
,
ck_tile
::
reference_moe_sorting
<
WeightType
,
IndexType
>
(
topk_ids_host
,
weights_host
,
weights_host
,
local_expert_masking_host
,
sorted_ids_ref
,
sorted_ids_ref
,
sorted_weights_ref
,
sorted_weights_ref
,
sorted_expert_ids_ref
,
sorted_expert_ids_ref
,
ref_total_tokens_post_pad
,
ref_total_tokens_post_pad
,
num_experts
,
num_experts
,
unit_size
);
unit_size
,
local_expert_masking
);
rtn
&=
ck_tile
::
check_err
(
rtn
&=
ck_tile
::
check_err
(
sorted_ids_host
,
sorted_ids_ref
,
std
::
string
(
"OUT Error: Incorrect ids!"
),
1e-6
,
1e-6
);
sorted_ids_host
,
sorted_ids_ref
,
std
::
string
(
"OUT Error: Incorrect ids!"
),
1e-6
,
1e-6
);
rtn
&=
ck_tile
::
check_err
(
sorted_weights_host
,
rtn
&=
ck_tile
::
check_err
(
sorted_weights_host
,
...
@@ -199,9 +243,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
...
@@ -199,9 +243,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
moe_buf_host
,
moe_buf_ref
,
std
::
string
(
"OUT Error: Incorrect zero buf!"
),
0
,
0
);
moe_buf_host
,
moe_buf_ref
,
std
::
string
(
"OUT Error: Incorrect zero buf!"
),
0
,
0
);
}
}
rtn
&=
ref_total_tokens_post_pad
==
sorted_id_cnt_host
.
mData
[
0
];
rtn
&=
ref_total_tokens_post_pad
==
sorted_id_cnt_host
.
mData
[
0
];
printf
(
"total_tokens_post_pad:%d(%d), "
,
ref_total_tokens_post_pad
,
sorted_id_cnt_host
.
mData
[
0
]);
}
}
printf
(
"valid:%s
\n
"
,
rtn
?
"y"
:
"n"
);
printf
(
"valid:%s"
,
rtn
?
"y"
:
"n"
);
fflush
(
stdout
);
if
(
!
rtn
)
printf
(
", (%d)"
,
seed
);
printf
(
"
\n
"
);
fflush
(
stdout
);
fflush
(
stdout
);
return
rtn
;
return
rtn
;
}
}
...
...
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
View file @
ec959387
...
@@ -3,6 +3,12 @@
...
@@ -3,6 +3,12 @@
#include "moe_sorting_api.hpp"
#include "moe_sorting_api.hpp"
#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \
...
@@ -17,6 +23,67 @@
...
@@ -17,6 +23,67 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
return ave_time;
#else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \
constexpr bool local_expert_masking = local_expert_masking_; \
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
ms_weight_type, \
sub_token_tile, \
sub_token_onshot, \
local_expert_masking>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
if(row_ % 8 == 0) \
{ \
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 4 == 0) \
{ \
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 2 == 0) \
{ \
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
} \
else \
{ \
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
}
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
if(is_sub_token_onshot) \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
}
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
if(is_local_expert_masking) \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, true) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, false) \
}
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH(unroll_num_) \
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
if(a.num_experts <= 8) \
{ \
{ \
...
@@ -38,11 +105,13 @@
...
@@ -38,11 +105,13 @@
{ \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
}
#endif
float
moe_sorting
(
moe_sorting_trait
t
,
moe_sorting_args
a
,
ck_tile
::
stream_config
s
)
float
moe_sorting
(
moe_sorting_trait
t
,
moe_sorting_args
a
,
ck_tile
::
stream_config
s
)
{
{
if
(
t
.
weight_type
==
"fp32"
&&
t
.
index_type
==
"int32"
)
if
(
t
.
weight_type
==
"fp32"
&&
t
.
index_type
==
"int32"
)
{
{
#if !MOE_SORTING_USE_EX_KERNEL
if
(
a
.
num_experts
>
127
)
if
(
a
.
num_experts
>
127
)
{
{
printf
(
"lds size exceed, only support experts <127
\n
"
);
printf
(
"lds size exceed, only support experts <127
\n
"
);
...
@@ -83,6 +152,19 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
...
@@ -83,6 +152,19 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
MOE_SORTING_DISPATCH
(
4
);
MOE_SORTING_DISPATCH
(
4
);
}
}
}
}
#else
using
index_t
=
ck_tile
::
index_t
;
using
ms_weight_type
=
float
;
auto
[
r_
,
c_
]
=
ck_tile
::
moe_sorting_get_smem_row_col
(
a
.
tokens
,
a
.
num_experts
);
auto
sub_token_
=
r_
-
2
;
r_
=
(
r_
-
2
)
/
8
;
bool
is_sub_token_onshot
=
a
.
tokens
<=
sub_token_
;
bool
is_local_expert_masking
=
t
.
local_expert_masking
;
(
void
)
c_
;
MOE_SORTING_DISPATCH_EMASK_
(
r_
);
// MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif
}
}
return
-
1
;
return
-
1
;
}
}
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp
View file @
ec959387
...
@@ -10,7 +10,8 @@
...
@@ -10,7 +10,8 @@
struct
moe_sorting_trait
struct
moe_sorting_trait
{
{
std
::
string
index_type
;
std
::
string
index_type
;
std
::
string
weight_type
;
// currently always float
std
::
string
weight_type
;
// currently always float
bool
local_expert_masking
;
// if mask experts as local expert
};
};
struct
moe_sorting_args
:
public
ck_tile
::
MoeSortingHostArgs
struct
moe_sorting_args
:
public
ck_tile
::
MoeSortingHostArgs
...
...
example/ck_tile/13_moe_sorting/script/smoke_test.sh
View file @
ec959387
...
@@ -17,4 +17,12 @@ $EXE -t=71 -e=11 -k=11
...
@@ -17,4 +17,12 @@ $EXE -t=71 -e=11 -k=11
$EXE
-t
=
1
-e
=
1
-k
=
1
$EXE
-t
=
1
-e
=
1
-k
=
1
$EXE
-t
=
99
-e
=
2
-k
=
1
$EXE
-t
=
99
-e
=
2
-k
=
1
$EXE
-t
=
333
-e
=
99
-k
=
13
$EXE
-t
=
333
-e
=
99
-k
=
13
$EXE
-t
=
11
-e
=
256
-k
=
5
$EXE
-t
=
64
-e
=
455
-k
=
8
$EXE
-t
=
777
-e
=
802
-k
=
99
$EXE
-t
=
4097
-e
=
906
-k
=
51
$EXE
-t
=
128
-e
=
32
-k
=
5
-moe_buf_size
=
262144
$EXE
-t
=
128
-e
=
32
-k
=
5
-moe_buf_size
=
262144
$EXE
-t
=
13
-e
=
64
-k
=
3
-local_eid
=
4,5,6,7,8,9,10,11
$EXE
-t
=
99
-e
=
33
-k
=
9
-local_eid
=
6,10,11,15,19
$EXE
-t
=
80
-e
=
99
-k
=
10
-local_eid
=
0,8,12,33
$EXE
-t
=
11
-e
=
256
-k
=
5
-local_eid
=
99,110,129
example/ck_tile/15_fused_moe/README.md
View file @
ec959387
...
@@ -42,7 +42,7 @@ summary of the key design of this fused-moe operator:
...
@@ -42,7 +42,7 @@ summary of the key design of this fused-moe operator:
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
//
// max_num_tokens_padded : topk * input_tokens + num_experts *
(
M_a -
1
)
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a -
topk (updated
)
// * this could be larger than actual, since actual tokens are on GPU
// * this could be larger than actual, since actual tokens are on GPU
//
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
...
...
example/ck_tile/15_fused_moe/fused_moe.hpp
View file @
ec959387
...
@@ -8,14 +8,15 @@
...
@@ -8,14 +8,15 @@
struct
fused_moe_args
struct
fused_moe_args
{
{
const
void
*
a_ptr
;
// [m, k], input token
const
void
*
a_ptr
;
// [m, k], input token
const
void
*
a_scale_ptr
;
// [m, 1], token scale
const
void
*
a_scale_ptr
;
// [m, 1], token scale
const
void
*
g_ptr
;
// [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const
void
*
g_ptr
;
// [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const
void
*
d_ptr
;
// [e, n, k], pre-shuffle([e, nr, kr, w])
const
void
*
d_ptr
;
// [e, n, k], pre-shuffle([e, nr, kr, w])
const
void
*
g_scale_ptr
;
// [e, 1, n], gate(up) scale
const
void
*
g_scale_ptr
;
// [e, 1, n], gate(up) scale
const
void
*
d_scale_ptr
;
// [e, 1, k], down scale
const
void
*
d_scale_ptr
;
// [e, 1, k], down scale
const
void
*
y_smooth_scale_ptr
;
// [e, 1, n], smooth-quant-scale for 2nd gemm input
const
void
*
y_smooth_scale_ptr
;
// [e, 1, n], smooth-quant-scale for 2nd gemm input
void
*
o_ptr
;
// [m, k], output token (no need to do zeroing)
const
void
*
local_expert_mask_ptr
;
// [e], local_expert_mask_ptr for EP
void
*
o_ptr
;
// [m, k], output token (no need to do zeroing)
const
void
*
topk_ids_ptr
;
// [tokens, topk]
const
void
*
topk_ids_ptr
;
// [tokens, topk]
const
void
*
topk_weight_ptr
;
// [tokens, topk]
const
void
*
topk_weight_ptr
;
// [tokens, topk]
...
@@ -48,6 +49,8 @@ struct fused_moe_traits
...
@@ -48,6 +49,8 @@ struct fused_moe_traits
int
activation
;
// 0:gelu, 1:silu
int
activation
;
// 0:gelu, 1:silu
int
gate_only
;
// 0:g1u0, 1:g1u1
int
gate_only
;
// 0:g1u0, 1:g1u1
int
fused_quant
;
// 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
int
fused_quant
;
// 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
bool
local_expert_masking
;
// if mask experts as local expert
};
};
float
fused_moe
(
fused_moe_traits
,
fused_moe_args
,
const
ck_tile
::
stream_config
&
);
float
fused_moe
(
fused_moe_traits
,
fused_moe_args
,
const
ck_tile
::
stream_config
&
);
example/ck_tile/15_fused_moe/fused_moesorting.hpp
View file @
ec959387
...
@@ -10,7 +10,8 @@
...
@@ -10,7 +10,8 @@
struct
fused_moesorting_trait
struct
fused_moesorting_trait
{
{
std
::
string
index_type
;
std
::
string
index_type
;
std
::
string
weight_type
;
// currently always float
std
::
string
weight_type
;
// currently always float
bool
local_expert_masking
;
// if mask experts as local expert
};
};
struct
fused_moesorting_args
:
public
ck_tile
::
MoeSortingHostArgs
struct
fused_moesorting_args
:
public
ck_tile
::
MoeSortingHostArgs
...
...
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
View file @
ec959387
...
@@ -17,10 +17,11 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
...
@@ -17,10 +17,11 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
return
1
;
return
1
;
}();
}();
auto
t0
=
fused_moesorting_trait
{
"int32"
,
"fp32"
};
auto
t0
=
fused_moesorting_trait
{
"int32"
,
"fp32"
,
t
.
local_expert_masking
};
auto
a0
=
fused_moesorting_args
{
auto
a0
=
fused_moesorting_args
{
a
.
topk_ids_ptr
,
// const void* p_topk_ids;
a
.
topk_ids_ptr
,
// const void* p_topk_ids;
a
.
topk_weight_ptr
,
// const void* p_weights;
a
.
topk_weight_ptr
,
// const void* p_weights;
a
.
local_expert_mask_ptr
,
// const void* p_local_expert_mask;
a
.
sorted_token_ids_ptr
,
// void* p_sorted_token_ids;
a
.
sorted_token_ids_ptr
,
// void* p_sorted_token_ids;
a
.
sorted_weight_ptr
,
// void* p_sorted_weights;
a
.
sorted_weight_ptr
,
// void* p_sorted_weights;
a
.
sorted_expert_ids_ptr
,
// void* p_sorted_expert_ids;
a
.
sorted_expert_ids_ptr
,
// void* p_sorted_expert_ids;
...
...
example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
View file @
ec959387
...
@@ -3,6 +3,12 @@
...
@@ -3,6 +3,12 @@
#include "fused_moesorting.hpp"
#include "fused_moesorting.hpp"
#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \
...
@@ -17,6 +23,67 @@
...
@@ -17,6 +23,67 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
return ave_time;
#else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \
constexpr bool local_expert_masking = local_expert_masking_; \
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
ms_weight_type, \
sub_token_tile, \
sub_token_onshot, \
local_expert_masking>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
if(row_ % 8 == 0) \
{ \
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 4 == 0) \
{ \
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 2 == 0) \
{ \
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
} \
else \
{ \
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
}
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
if(is_sub_token_onshot) \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
}
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
if(is_local_expert_masking) \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, true) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, false) \
}
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH(unroll_num_) \
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
if(a.num_experts <= 8) \
{ \
{ \
...
@@ -38,11 +105,13 @@
...
@@ -38,11 +105,13 @@
{ \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
}
#endif
float
fused_moesorting
(
fused_moesorting_trait
t
,
fused_moesorting_args
a
,
ck_tile
::
stream_config
s
)
float
fused_moesorting
(
fused_moesorting_trait
t
,
fused_moesorting_args
a
,
ck_tile
::
stream_config
s
)
{
{
if
(
t
.
weight_type
==
"fp32"
&&
t
.
index_type
==
"int32"
)
if
(
t
.
weight_type
==
"fp32"
&&
t
.
index_type
==
"int32"
)
{
{
#if !MOE_SORTING_USE_EX_KERNEL
if
(
a
.
num_experts
>
127
)
if
(
a
.
num_experts
>
127
)
{
{
printf
(
"lds size exceed, only support experts <127
\n
"
);
printf
(
"lds size exceed, only support experts <127
\n
"
);
...
@@ -83,6 +152,19 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
...
@@ -83,6 +152,19 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
MOE_SORTING_DISPATCH
(
4
);
MOE_SORTING_DISPATCH
(
4
);
}
}
}
}
#else
using
index_t
=
ck_tile
::
index_t
;
using
ms_weight_type
=
float
;
auto
[
r_
,
c_
]
=
ck_tile
::
moe_sorting_get_smem_row_col
(
a
.
tokens
,
a
.
num_experts
);
auto
sub_token_
=
r_
-
2
;
r_
=
(
r_
-
2
)
/
8
;
bool
is_sub_token_onshot
=
a
.
tokens
<=
sub_token_
;
bool
is_local_expert_masking
=
t
.
local_expert_masking
;
(
void
)
c_
;
MOE_SORTING_DISPATCH_EMASK_
(
r_
);
// MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif
}
}
return
-
1
;
return
-
1
;
}
}
example/ck_tile/15_fused_moe/main.cpp
View file @
ec959387
...
@@ -140,28 +140,29 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -140,28 +140,29 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
index_t
activation
=
arg_parser
.
get_int
(
"act"
);
ck_tile
::
index_t
activation
=
arg_parser
.
get_int
(
"act"
);
if
(
stride
<
0
)
if
(
stride
<
0
)
stride
=
hidden_size
;
stride
=
hidden_size
;
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
std
::
string
prec_w
=
arg_parser
.
get_str
(
"prec_w"
);
std
::
string
prec_w
=
arg_parser
.
get_str
(
"prec_w"
);
std
::
string
prec_o
=
arg_parser
.
get_str
(
"prec_o"
);
std
::
string
prec_o
=
arg_parser
.
get_str
(
"prec_o"
);
std
::
string
prec_st
=
arg_parser
.
get_str
(
"prec_st"
);
std
::
string
prec_st
=
arg_parser
.
get_str
(
"prec_st"
);
std
::
string
prec_sw
=
arg_parser
.
get_str
(
"prec_sw"
);
std
::
string
prec_sw
=
arg_parser
.
get_str
(
"prec_sw"
);
std
::
string
prec_sq
=
arg_parser
.
get_str
(
"prec_sq"
);
std
::
string
prec_sq
=
arg_parser
.
get_str
(
"prec_sq"
);
std
::
string
prec_kw
=
arg_parser
.
get_str
(
"prec_kw"
);
std
::
string
prec_kw
=
arg_parser
.
get_str
(
"prec_kw"
);
prec_st
=
(
prec_st
==
"auto"
)
?
"fp32"
:
prec_st
;
prec_st
=
(
prec_st
==
"auto"
)
?
"fp32"
:
prec_st
;
prec_sw
=
(
prec_sw
==
"auto"
)
?
"fp32"
:
prec_sw
;
prec_sw
=
(
prec_sw
==
"auto"
)
?
"fp32"
:
prec_sw
;
prec_sq
=
(
prec_sq
==
"auto"
)
?
"fp32"
:
prec_sq
;
prec_sq
=
(
prec_sq
==
"auto"
)
?
"fp32"
:
prec_sq
;
prec_kw
=
(
prec_kw
==
"auto"
)
?
"fp32"
:
prec_kw
;
prec_kw
=
(
prec_kw
==
"auto"
)
?
"fp32"
:
prec_kw
;
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
int
fused_quant
=
arg_parser
.
get_int
(
"fquant"
);
int
fused_quant
=
arg_parser
.
get_int
(
"fquant"
);
int
gate_only
=
arg_parser
.
get_int
(
"gate_only"
);
int
gate_only
=
arg_parser
.
get_int
(
"gate_only"
);
int
api
=
arg_parser
.
get_int
(
"api"
);
int
api
=
arg_parser
.
get_int
(
"api"
);
int
balance
=
arg_parser
.
get_int
(
"balance"
);
int
balance
=
arg_parser
.
get_int
(
"balance"
);
int
tp
=
arg_parser
.
get_int
(
"tp"
);
int
tp
=
arg_parser
.
get_int
(
"tp"
);
int
init
=
arg_parser
.
get_int
(
"init"
);
int
init
=
arg_parser
.
get_int
(
"init"
);
uint32_t
seed
=
arg_parser
.
get_uint32
(
"seed"
);
uint32_t
seed
=
arg_parser
.
get_uint32
(
"seed"
);
bool
local_expert_masking
=
false
;
// TODO...
// w0 (Gate+Up or Gate only, N size)
// w0 (Gate+Up or Gate only, N size)
ck_tile
::
index_t
shared_intermediate_size_0
=
intermediate_size
*
(
gate_only
?
1
:
2
)
/
tp
;
ck_tile
::
index_t
shared_intermediate_size_0
=
intermediate_size
*
(
gate_only
?
1
:
2
)
/
tp
;
...
@@ -230,6 +231,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -230,6 +231,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
HostTensor
<
YSmoothScaleDataType
>
sy_host
({
shared_intermediate_size_1
});
// smooth-quant
ck_tile
::
HostTensor
<
YSmoothScaleDataType
>
sy_host
({
shared_intermediate_size_1
});
// smooth-quant
ck_tile
::
HostTensor
<
IndexDataType
>
topk_ids_host
({
tokens
,
topk
});
// to be sort
ck_tile
::
HostTensor
<
IndexDataType
>
topk_ids_host
({
tokens
,
topk
});
// to be sort
ck_tile
::
HostTensor
<
TopkWeightDataType
>
topk_weight_host
({
tokens
,
topk
});
// to be sort
ck_tile
::
HostTensor
<
TopkWeightDataType
>
topk_weight_host
({
tokens
,
topk
});
// to be sort
ck_tile
::
HostTensor
<
IndexDataType
>
local_expert_mask_host
({
experts
});
int
max_num_tokens_padded
=
topk
*
tokens
+
experts
*
block_m
-
topk
;
int
max_num_tokens_padded
=
topk
*
tokens
+
experts
*
block_m
-
topk
;
ck_tile
::
HostTensor
<
IndexDataType
>
sorted_token_ids_host
({
max_num_tokens_padded
});
ck_tile
::
HostTensor
<
IndexDataType
>
sorted_token_ids_host
({
max_num_tokens_padded
});
...
@@ -355,6 +357,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -355,6 +357,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
DeviceMem
sg_buf
(
sg_host
);
ck_tile
::
DeviceMem
sg_buf
(
sg_host
);
ck_tile
::
DeviceMem
sd_buf
(
sd_host
);
ck_tile
::
DeviceMem
sd_buf
(
sd_host
);
ck_tile
::
DeviceMem
sy_buf
(
sy_host
);
ck_tile
::
DeviceMem
sy_buf
(
sy_host
);
ck_tile
::
DeviceMem
local_expert_mask_buf
(
local_expert_mask_host
);
ck_tile
::
DeviceMem
o_buf
(
o_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
o_buf
(
o_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
topk_ids_buf
(
topk_ids_host
);
ck_tile
::
DeviceMem
topk_ids_buf
(
topk_ids_host
);
...
@@ -378,7 +381,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -378,7 +381,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
block_m
,
block_m
,
activation
,
activation
,
gate_only
,
gate_only
,
fused_quant
};
fused_quant
,
local_expert_masking
};
fused_moe_args
args
{
a_buf
.
GetDeviceBuffer
(),
fused_moe_args
args
{
a_buf
.
GetDeviceBuffer
(),
fused_quant
!=
0
?
sa_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
!=
0
?
sa_buf
.
GetDeviceBuffer
()
:
nullptr
,
...
@@ -387,6 +391,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -387,6 +391,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
fused_quant
!=
0
?
sg_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
!=
0
?
sg_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
!=
0
?
sd_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
!=
0
?
sd_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
==
1
?
sy_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
==
1
?
sy_buf
.
GetDeviceBuffer
()
:
nullptr
,
local_expert_masking
?
local_expert_mask_buf
.
GetDeviceBuffer
()
:
nullptr
,
o_buf
.
GetDeviceBuffer
(),
o_buf
.
GetDeviceBuffer
(),
topk_ids_buf
.
GetDeviceBuffer
(),
topk_ids_buf
.
GetDeviceBuffer
(),
topk_weight_buf
.
GetDeviceBuffer
(),
topk_weight_buf
.
GetDeviceBuffer
(),
...
@@ -442,12 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -442,12 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
reference_moe_sorting
<
TopkWeightDataType
,
IndexDataType
>
(
ck_tile
::
reference_moe_sorting
<
TopkWeightDataType
,
IndexDataType
>
(
topk_ids_host
,
topk_ids_host
,
topk_weight_host
,
topk_weight_host
,
local_expert_mask_host
,
sorted_token_ids_host
,
sorted_token_ids_host
,
sorted_weight_host
,
sorted_weight_host
,
sorted_expert_ids_host
,
sorted_expert_ids_host
,
num_sorted_tiles_host
.
mData
[
0
],
num_sorted_tiles_host
.
mData
[
0
],
experts
,
experts
,
block_m
);
block_m
,
local_expert_masking
);
if
(
activation
==
0
)
if
(
activation
==
0
)
{
{
CPU_FUSED_MOE
(
ck_tile
::
element_wise
::
Gelu
);
CPU_FUSED_MOE
(
ck_tile
::
element_wise
::
Gelu
);
...
@@ -472,12 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -472,12 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
reference_moe_sorting
<
TopkWeightDataType
,
IndexDataType
>
(
ck_tile
::
reference_moe_sorting
<
TopkWeightDataType
,
IndexDataType
>
(
topk_ids_host
,
topk_ids_host
,
topk_weight_host
,
topk_weight_host
,
local_expert_mask_host
,
sorted_token_ids_host
,
sorted_token_ids_host
,
sorted_weight_host
,
sorted_weight_host
,
sorted_expert_ids_host
,
sorted_expert_ids_host
,
num_sorted_tiles_host
.
mData
[
0
],
num_sorted_tiles_host
.
mData
[
0
],
experts
,
experts
,
block_m
);
block_m
,
local_expert_masking
);
// done, preparing GPU buffer
// done, preparing GPU buffer
ck_tile
::
DeviceMem
a_buf
(
a_host
);
ck_tile
::
DeviceMem
a_buf
(
a_host
);
...
...
example/ck_tile/16_batched_gemm/batched_gemm.cpp
View file @
ec959387
...
@@ -19,12 +19,9 @@ template <typename ALayout, typename BLayout, typename CLayout>
...
@@ -19,12 +19,9 @@ template <typename ALayout, typename BLayout, typename CLayout>
float
batched_gemm
(
const
ck_tile
::
BatchedGemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
batched_gemm
(
const
ck_tile
::
BatchedGemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kTilePermute
=
false
;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
int
kBlockPerCu
=
1
;
...
@@ -41,40 +38,31 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
...
@@ -41,40 +38,31 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr
bool
CShuffleEpilogue
=
std
::
is_same_v
<
CLayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
using
CodegenGemmShape
=
using
CodegenGemmShape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTile2DPartitioner
<
CodegenGemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTile1DPartitioner
<
CodegenGemmShape
>
;
using
GemmEpilogue
=
std
::
conditional_t
<
CShuffleEpilogue
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
,
kTilePermute
,
kOutputRank
,
1
,
0
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>>
;
using
CodegenGemmTraits
=
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
CodegenPipelineProblem
=
ck_tile
::
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPolicy
=
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
CodegenGemmPipeline
=
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
,
CodegenGemmPolicy
>
;
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
CLayout
,
CodegenPipelineProblem
::
kBlockSize
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
,
M_Warp
,
N_Warp
,
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
,
CodegenPipelineProblem
::
TransposeC
>>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
using
Kernel
=
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
...
@@ -91,8 +79,11 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
...
@@ -91,8 +79,11 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
if
(
s
.
log_level_
>
0
)
if
(
s
.
log_level_
>
0
)
{
{
std
::
cout
<<
"Launching kernel with args:"
std
::
cout
<<
"Launching kernel with args: "
<<
Kernel
::
GetName
()
<<
'\n'
<<
" grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
"shape: "
<<
CodegenGemmShape
::
GetName
()
<<
'\n'
<<
"problem: "
<<
CodegenPipelineProblem
::
GetName
()
<<
'\n'
<<
"pipeline: "
<<
CodegenGemmPipeline
::
GetName
()
<<
'\n'
<<
"grid: {"
<<
grids
.
x
<<
", "
<<
grids
.
y
<<
", "
<<
grids
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
", blocks: {"
<<
blocks
.
x
<<
", "
<<
blocks
.
y
<<
", "
<<
blocks
.
z
<<
"}"
<<
std
::
endl
;
<<
std
::
endl
;
}
}
...
...
example/ck_tile/16_batched_gemm/batched_gemm.hpp
View file @
ec959387
...
@@ -39,7 +39,7 @@ auto create_args(int argc, char* argv[])
...
@@ -39,7 +39,7 @@ auto create_args(int argc, char* argv[])
.
insert
(
"stride_b"
,
"0"
,
"Tensor B stride"
)
.
insert
(
"stride_b"
,
"0"
,
"Tensor B stride"
)
.
insert
(
"stride_c"
,
"0"
,
"Tensor C stride"
)
.
insert
(
"stride_c"
,
"0"
,
"Tensor C stride"
)
.
insert
(
"a_layout"
,
"R"
,
"A tensor data layout - Row by default"
)
.
insert
(
"a_layout"
,
"R"
,
"A tensor data layout - Row by default"
)
.
insert
(
"b_layout"
,
"
R
"
,
"B tensor data layout - Row by default"
)
.
insert
(
"b_layout"
,
"
C
"
,
"B tensor data layout - Row by default"
)
.
insert
(
"c_layout"
,
"R"
,
"C tensor data layout - Row by default"
)
.
insert
(
"c_layout"
,
"R"
,
"C tensor data layout - Row by default"
)
.
insert
(
"batch_stride_a"
,
"32768"
,
"Batch A stride"
)
.
insert
(
"batch_stride_a"
,
"32768"
,
"Batch A stride"
)
.
insert
(
"batch_stride_b"
,
"16384"
,
"Batch B stride"
)
.
insert
(
"batch_stride_b"
,
"16384"
,
"Batch B stride"
)
...
...
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
View file @
ec959387
...
@@ -212,7 +212,7 @@ int run_batched_gemm_example_with_layouts(int argc,
...
@@ -212,7 +212,7 @@ int run_batched_gemm_example_with_layouts(int argc,
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
" Absolute error threshold: "
<<
rtol_atol
.
at
(
ck_tile
::
number
<
1
>
{})
<<
std
::
endl
;
<<
std
::
endl
;
std
::
cout
<<
"The CPU veification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
std
::
cout
<<
"The CPU ve
r
ification result is:"
<<
(
pass
?
"correct"
:
"fail"
)
<<
std
::
endl
;
}
}
else
if
(
arg_parser
.
get_int
(
"v"
)
==
2
)
else
if
(
arg_parser
.
get_int
(
"v"
)
==
2
)
{
{
...
@@ -301,11 +301,11 @@ int run_batched_gemm_example(int argc, char* argv[])
...
@@ -301,11 +301,11 @@ int run_batched_gemm_example(int argc, char* argv[])
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
if
(
a_layout
==
"R"
&&
b_layout
==
"R"
)
//
if(a_layout == "R" && b_layout == "R")
{
//
{
return
run_batched_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
//
return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
}
//
}
else
if
(
a_layout
==
"R"
&&
b_layout
==
"C"
)
if
(
a_layout
==
"R"
&&
b_layout
==
"C"
)
{
{
return
run_batched_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
return
run_batched_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
}
...
...
Prev
1
2
3
4
5
6
7
8
9
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment