Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
70fa98ad
Commit
70fa98ad
authored
Nov 05, 2024
by
carlushuang
Browse files
update code
parent
7c81aee8
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
575 additions
and
200 deletions
+575
-200
example/ck_tile/15_fused_moe/CMakeLists.txt
example/ck_tile/15_fused_moe/CMakeLists.txt
+15
-0
example/ck_tile/15_fused_moe/fused_moegemm.hpp
example/ck_tile/15_fused_moe/fused_moegemm.hpp
+27
-25
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
+35
-0
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp
...ile/15_fused_moe/instances/fused_moegemm_api_internal.hpp
+46
-0
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp
..._tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp
+50
-0
example/ck_tile/15_fused_moe/main.cpp
example/ck_tile/15_fused_moe/main.cpp
+222
-114
include/ck_tile/host.hpp
include/ck_tile/host.hpp
+1
-0
include/ck_tile/host/device_memory.hpp
include/ck_tile/host/device_memory.hpp
+30
-0
include/ck_tile/host/reference/reference_moe_sorting.hpp
include/ck_tile/host/reference/reference_moe_sorting.hpp
+78
-0
include/ck_tile/host/reference/reference_permute.hpp
include/ck_tile/host/reference/reference_permute.hpp
+3
-4
include/ck_tile/ops/fused_moe.hpp
include/ck_tile/ops/fused_moe.hpp
+7
-7
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
...ude/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
+32
-29
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp
.../ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp
+10
-9
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
...sed_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
+6
-3
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp
+2
-2
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
...e/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
+11
-7
No files found.
example/ck_tile/15_fused_moe/CMakeLists.txt
0 → 100644
View file @
70fa98ad
set
(
TILE_EXAPMLE_FUSED_MOE
"tile_example_fused_moe"
)
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message
(
"adding
${
TILE_EXAPMLE_FUSED_MOE
}
"
)
file
(
GLOB INSTANCE_SRCS instances/*.cpp
)
add_executable
(
${
TILE_EXAPMLE_FUSED_MOE
}
EXCLUDE_FROM_ALL main.cpp
)
target_include_directories
(
${
TILE_EXAPMLE_FUSED_MOE
}
PRIVATE
${
CMAKE_CURRENT_LIST_DIR
}
)
target_sources
(
${
TILE_EXAPMLE_FUSED_MOE
}
PRIVATE
${
INSTANCE_SRCS
}
)
set
(
TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS
)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list
(
APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal
)
target_compile_options
(
${
TILE_EXAPMLE_FUSED_MOE
}
PRIVATE
${
TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS
}
)
example/ck_tile/15_fused_moe/fused_moegemm.hpp
View file @
70fa98ad
...
@@ -16,33 +16,33 @@ struct FusedMoeGemmTypeConfig;
...
@@ -16,33 +16,33 @@ struct FusedMoeGemmTypeConfig;
template
<
typename
ST
,
typename
SW
,
typename
SQ
,
typename
KW
>
template
<
typename
ST
,
typename
SW
,
typename
SQ
,
typename
KW
>
struct
FusedMoeGemmTypeConfig
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ST
,
SW
,
SQ
,
KW
>
;
struct
FusedMoeGemmTypeConfig
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ST
,
SW
,
SQ
,
KW
>
;
{
{
using
ADataType
=
ck_tile
::
bf16_t
;
using
ADataType
=
ck_tile
::
bf16_t
;
using
GDataType
=
ck_tile
::
bf16_t
;
using
GDataType
=
ck_tile
::
bf16_t
;
using
DDataType
=
ck_tile
::
bf16_t
;
using
DDataType
=
ck_tile
::
bf16_t
;
using
AccDataType
=
float
;
using
AccDataType
=
float
;
using
ODataType
=
ck_tile
::
bf16_t
;
using
ODataType
=
ck_tile
::
bf16_t
;
using
AScaleDataType
=
ck_tile
::
remove_cvref_t
<
ST
>
;
using
AScaleDataType
=
ck_tile
::
remove_cvref_t
<
ST
>
;
using
W0
ScaleDataType
=
ck_tile
::
remove_cvref_t
<
SW
>
;
using
G
ScaleDataType
=
ck_tile
::
remove_cvref_t
<
SW
>
;
using
W1
ScaleDataType
=
ck_tile
::
remove_cvref_t
<
SW
>
;
using
D
ScaleDataType
=
ck_tile
::
remove_cvref_t
<
SW
>
;
using
YSmoothScaleDataType
=
ck_tile
::
remove_cvref_t
<
SQ
>
;
using
YSmoothScaleDataType
=
ck_tile
::
remove_cvref_t
<
SQ
>
;
using
TopkWeightDataType
=
ck_tile
::
remove_cvref_t
<
KW
>
;
using
TopkWeightDataType
=
ck_tile
::
remove_cvref_t
<
KW
>
;
using
IndexDataType
=
ck_tile
::
index_t
;
using
IndexDataType
=
ck_tile
::
index_t
;
};
};
template
<
typename
ST
,
typename
SW
,
typename
SQ
,
typename
KW
>
template
<
typename
ST
,
typename
SW
,
typename
SQ
,
typename
KW
>
struct
FusedMoeGemmTypeConfig
<
ck_tile
::
int8_t
,
ck_tile
::
int8_t
,
ck_tile
::
bf16_t
,
ST
,
SW
,
SQ
,
KW
>
;
struct
FusedMoeGemmTypeConfig
<
ck_tile
::
int8_t
,
ck_tile
::
int8_t
,
ck_tile
::
bf16_t
,
ST
,
SW
,
SQ
,
KW
>
;
{
{
using
ADataType
=
ck_tile
::
int8_t
;
using
ADataType
=
ck_tile
::
int8_t
;
using
GDataType
=
ck_tile
::
int8_t
;
using
GDataType
=
ck_tile
::
int8_t
;
using
DDataType
=
ck_tile
::
int8_t
;
using
DDataType
=
ck_tile
::
int8_t
;
using
AccDataType
=
int32_t
;
using
AccDataType
=
int32_t
;
using
ODataType
=
ck_tile
::
bf16_t
;
using
ODataType
=
ck_tile
::
bf16_t
;
using
AScaleDataType
=
ck_tile
::
remove_cvref_t
<
ST
>
;
using
AScaleDataType
=
ck_tile
::
remove_cvref_t
<
ST
>
;
using
W0
ScaleDataType
=
ck_tile
::
remove_cvref_t
<
SW
>
;
using
G
ScaleDataType
=
ck_tile
::
remove_cvref_t
<
SW
>
;
using
W1
ScaleDataType
=
ck_tile
::
remove_cvref_t
<
SW
>
;
using
D
ScaleDataType
=
ck_tile
::
remove_cvref_t
<
SW
>
;
using
YSmoothScaleDataType
=
ck_tile
::
remove_cvref_t
<
SQ
>
;
using
YSmoothScaleDataType
=
ck_tile
::
remove_cvref_t
<
SQ
>
;
using
TopkWeightDataType
=
ck_tile
::
remove_cvref_t
<
KW
>
;
using
TopkWeightDataType
=
ck_tile
::
remove_cvref_t
<
KW
>
;
using
IndexDataType
=
ck_tile
::
index_t
;
using
IndexDataType
=
ck_tile
::
index_t
;
};
};
// runtime args
// runtime args
...
@@ -53,14 +53,16 @@ struct fused_moegemm_args : public ck_tile::Layernorm2dFwdHostArgs
...
@@ -53,14 +53,16 @@ struct fused_moegemm_args : public ck_tile::Layernorm2dFwdHostArgs
// This is the public API, will be generated by script
// This is the public API, will be generated by script
struct
fused_moegemm_traits
struct
fused_moegemm_traits
{
{
std
::
string
prec_i
;
// input precision
std
::
string
prec_i
;
// input precision
std
::
string
prec_w
;
// weight precision
std
::
string
prec_w
;
// weight precision
std
::
string
prec_o
;
// output precision
std
::
string
prec_o
;
// output precision
std
::
string
prec_st
;
// token scale data type
std
::
string
prec_st
;
// token scale data type
std
::
string
prec_sw
;
// weight scale data type
std
::
string
prec_sw
;
// weight scale data type
std
::
string
prec_sq
;
// smooth quant scale
std
::
string
prec_sq
;
// smooth quant scale
std
::
string
prec_kw
;
// topk-weight data type
std
::
string
prec_kw
;
// topk-weight data type
int
fused_quant
;
// 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
int
block_m
;
int
gate_only
;
int
fused_quant
;
// 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
};
float
fused_moegemm
(
fused_moegemm_traits
,
fused_moegemm_args
,
const
ck_tile
::
stream_config
&
);
float
fused_moegemm
(
fused_moegemm_traits
,
fused_moegemm_args
,
const
ck_tile
::
stream_config
&
);
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
0 → 100644
View file @
70fa98ad
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
// Note: this internal API only declare, not define here, otherwise will block `make -j`
template
<
typename
Traits_
>
float
fused_moegemm_
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
float
fused_moegemm
(
fused_moegemm_traits
t
,
fused_moegemm_args
a
,
const
ck_tile
::
stream_config
&
s
)
{
template
<
ck_tile
::
index_t
...
Is
>
using
S
=
ck_tile
::
sequence
<
Is
...
>
;
float
r
=
-
1
;
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
block_m
==
32
&&
gate_only
==
1
)
{
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
4
,
1
,
1
>
,
S
<
32
,
32
,
16
>
,
1
,
0
>
;
fused_moegemm_
<
t_
>
(
s
,
a
);
}
return
r
;
}
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp
0 → 100644
View file @
70fa98ad
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "fused_moegemm_api_traits.hpp"
#include "ck_tile/ops/fused_moe.hpp"
template
<
typename
Ts_
>
float
fused_moegemm_
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
)
{
using
f_traits
=
ck_tile
::
FusedMoeGemmTraits
<
Ts_
::
GateOnly
,
Ts_
::
FusedQuant
==
1
,
1
/*atomic*/
>
;
using
f_shape
=
ck_tile
::
FusedMoeGemmShape
<
typename
Ts_
::
BlockTile_0
,
typename
Ts_
::
WarpPerBlock_0
,
typename
Ts
::
WarpTile_0
,
typename
Ts_
::
BlockTile_1
,
typename
Ts_
::
WarpPerBlock_0
,
typename
Ts
::
WarpTile_0
>
;
using
f_problem
=
ck_tile
::
FusedMoeGemmPipelineProblem
<
typename
Ts_
::
ADataType
,
typename
Ts_
::
GDataType
,
typename
Ts_
::
DDataType
,
typename
Ts_
::
AccDataType
,
typename
Ts_
::
ODataType
,
typename
Ts_
::
AScaleDataType
,
typename
Ts_
::
GScaleDataType
,
typename
Ts_
::
DScaleDataType
,
typename
Ts_
::
YSmoothScaleDataType
,
typename
Ts_
::
TopkWeightDataType
,
typename
Ts_
::
IndexDataType
,
ck_tile
::
Gelu
,
// TODO: hardcoded
f_shape
,
f_traits
>
using
f_pipeline
=
ck_tile
::
FusedMoeGemmPipeline_Flatmm
<
f_problem
>
;
using
f_partitioner
=
ck_tile
::
FusedMoeGemmTilePartitioner_Linear
<
f_shape
>
;
using
f_kernel
=
ck_tile
::
FusedMoeGemmKernel
<
f_partitioner
,
f_pipeline
,
void
>
;
const
dim3
grids
=
f_kernel
::
GridSize
(
a
);
constexpr
dim3
blocks
=
f_kernel
::
BlockSize
();
constexpr
ck_tile
::
index_t
kBlockPerCu
=
1
;
auto
kargs
=
f_kernel
::
MakeKargs
(
a
);
if
(
s
.
log_level_
>
0
)
std
::
cout
<<
", "
<<
f_kernel
::
GetName
()
<<
std
::
flush
;
return
ck_tile
::
launch_kernel
(
s
,
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
f_kernel
{},
grids
,
blocks
,
0
,
kargs
));
}
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp
0 → 100644
View file @
70fa98ad
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template
<
typename
I
,
typename
W
,
typename
O
,
typename
ST
,
typename
SW
,
typename
SQ
,
typename
KW
,
typename
BlockTIle_
,
// seq<b_token, b_interm, b_hidden, b_down>
typename
WarpPerBlock_
,
typename
WarpTile_
,
// seq<*,*,*>, used to select mfma
ck_tile
::
index_t
GateOnly_
=
0
,
ck_tile
::
index_t
FusedQuant_
=
0
>
struct
fmoe_
// traits, ugly name, only used for internal
{
using
TypeConfig
=
FusedMoeGemmTypeConfig
<
I
,
W
,
O
,
ST
,
SW
,
SQ
,
KW
>
;
using
ADataType
=
remove_cvref_t
<
typename
TypeConfig
::
ADataType
>
;
using
GDataType
=
remove_cvref_t
<
typename
TypeConfig
::
GDataType
>
;
using
DDataType
=
remove_cvref_t
<
typename
TypeConfig
::
DDataType
>
;
using
AccDataType
=
remove_cvref_t
<
typename
TypeConfig
::
AccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
TypeConfig
::
ODataType
>
;
using
AScaleDataType
=
remove_cvref_t
<
typename
TypeConfig
::
AScaleDataType
>
;
using
GScaleDataType
=
remove_cvref_t
<
typename
TypeConfig
::
GScaleDataType
>
;
using
DScaleDataType
=
remove_cvref_t
<
typename
TypeConfig
::
DScaleDataType
>
;
using
YSmoothScaleDataType
=
remove_cvref_t
<
typename
TypeConfig
::
YSmoothScaleDataType
>
;
using
TopkWeightDataType
=
remove_cvref_t
<
typename
TypeConfig
::
TopkWeightDataType
>
;
using
IndexDataType
=
remove_cvref_t
<
typename
TypeConfig
::
IndexDataType
>
;
static
constexpr
index_t
BT_
=
BlockTIle_
::
at
(
number
<
0
>
{});
// block token
static
constexpr
index_t
BI_
=
BlockTIle_
::
at
(
number
<
1
>
{});
// block intermediate
static
constexpr
index_t
BH_
=
BlockTIle_
::
at
(
number
<
2
>
{});
// block hidden
static
constexpr
index_t
BD_
=
BlockTIle_
::
at
(
number
<
3
>
{});
// block down
using
BlockTile_0
=
ck_tile
::
sequence
<
BT_
,
BI_
,
BH_
>
;
using
WarpPerBlock_0
=
remove_cvref_t
<
WarpPerBlock_
>
;
using
WarpTile_0
=
remove_cvref_t
<
WarpTile_
>
;
using
BlockTile_1
=
ck_tile
::
sequence
<
BT_
,
BD_
,
BI_
/
(
GateOnly_
?
1
:
2
)
>
;
using
WarpPerBlock_1
=
remove_cvref_t
<
WarpPerBlock_
>
;
using
WarpTile_1
=
remove_cvref_t
<
WarpTile_
>
;
static
constexpr
ck_tile
::
index_t
GateOnly
=
GateOnly_
;
static
constexpr
ck_tile
::
index_t
FusedQuant
=
FusedQuant_
;
};
example/ck_tile/15_fused_moe/main.cpp
View file @
70fa98ad
#include "ck_tile/host.hpp"
#include "ck_tile/host.hpp"
#include "
layernorm2d_fwd
.hpp"
#include "
fused_moegemm
.hpp"
#include <algorithm>
#include <algorithm>
#include <cstring>
#include <cstring>
#include <unordered_set>
#include <vector>
#include <set>
// different threshold for different dtype
// different threshold for different dtype
template
<
typename
DataType
>
template
<
typename
DataType
>
...
@@ -20,18 +23,64 @@ auto get_elimit<ck_tile::bf16_t>()
...
@@ -20,18 +23,64 @@ auto get_elimit<ck_tile::bf16_t>()
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
return
ck_tile
::
make_tuple
(
rtol
,
atol
);
}
}
// mfma_type, 0:32x32, 1:16x16
// mfma_type, 0:32x32, 1:16x16
template
<
typename
H
>
// TODO: padding?
auto
shuffle_moe_weight
(
const
H
&
t
,
std
::
string
mfma_dtype
,
int
mfma_type
=
0
)
template
<
typename
T
>
auto
shuffle_moe_weight
(
const
ck_tile
::
HostTensor
<
T
>&
t
,
std
::
string
mfma_dtype
,
int
mfma_type
=
0
)
{
{
static_assert
(
t
.
get_lengths
().
size
()
==
3
);
static_assert
(
t
.
get_lengths
().
size
()
==
3
);
int
b_
=
t
.
get_lengths
()[
0
];
int
b_
=
t
.
get_lengths
()[
0
];
int
n_
=
t
.
get_lengths
()[
1
];
int
n_
=
t
.
get_lengths
()[
1
];
int
k_
=
t
.
get_lengths
()[
2
];
int
k_
=
t
.
get_lengths
()[
2
];
if
((
mfma_dtype
==
"bf16"
||
mfma_dtype
==
"fp16"
)
&&
mfma_type
==
0
)
{
if
((
mfma_dtype
==
"bf16"
||
mfma_dtype
==
"fp16"
)
&&
mfma_type
==
0
)
std
::
vector
<
ck_tile
::
index_t
>
new_lens
{
b_
,
n_
/
32
,
32
,
k_
/
16
,
2
,
8
};
{
ck_tile
::
HostTensor
<
T
>
t_view
({
b_
,
n_
/
32
,
32
,
k_
/
16
,
2
,
8
});
std
::
copy
(
t
.
begin
(),
t
.
end
(),
t_view
.
begin
());
return
ck_tile
::
reference_permute
(
t_view
,
{
0
,
1
,
3
,
4
,
2
,
5
});
}
else
if
((
mfma_dtype
==
"bf16"
||
mfma_dtype
==
"fp16"
)
&&
mfma_type
==
1
)
{
ck_tile
::
HostTensor
<
T
>
t_view
({
b_
,
n_
/
16
,
16
,
k_
/
32
,
4
,
8
});
std
::
copy
(
t
.
begin
(),
t
.
end
(),
t_view
.
begin
());
return
ck_tile
::
reference_permute
(
t_view
,
{
0
,
1
,
3
,
4
,
2
,
5
});
}
else
if
((
mfma_dtype
==
"int8"
||
mfma_dtype
==
"fp8"
)
&&
mfma_type
==
0
)
{
ck_tile
::
HostTensor
<
T
>
t_view
({
b_
,
n_
/
32
,
32
,
k_
/
32
,
2
,
16
});
std
::
copy
(
t
.
begin
(),
t
.
end
(),
t_view
.
begin
());
return
ck_tile
::
reference_permute
(
t_view
,
{
0
,
1
,
3
,
4
,
2
,
5
});
}
else
if
((
mfma_dtype
==
"int8"
||
mfma_dtype
==
"fp8"
)
&&
mfma_type
==
1
)
{
ck_tile
::
HostTensor
<
T
>
t_view
({
b_
,
n_
/
16
,
16
,
k_
/
64
,
4
,
16
});
std
::
copy
(
t
.
begin
(),
t
.
end
(),
t_view
.
begin
());
return
ck_tile
::
reference_permute
(
t_view
,
{
0
,
1
,
3
,
4
,
2
,
5
});
}
return
t
;
}
}
template
<
typename
IndexType
>
void
topid_unique_gen
(
std
::
vector
<
IndexType
>&
host_tensor
,
int
tokens
,
int
topk
,
int
num_expert
,
int
seed
)
{
size_t
total_size
=
topk
*
tokens
;
std
::
srand
(
seed
);
std
::
set
<
IndexType
>
unique_set
;
IndexType
current_v
;
for
(
size_t
i
=
0
;
i
<
total_size
;
i
++
)
{
if
(
i
%
topk
==
0
)
{
unique_set
.
clear
();
}
current_v
=
std
::
rand
()
%
num_expert
;
while
(
unique_set
.
find
(
current_v
)
!=
unique_set
.
end
())
{
current_v
=
std
::
rand
()
%
num_expert
;
}
unique_set
.
insert
(
current_v
);
host_tensor
[
i
]
=
current_v
;
}
}
}
auto
create_args
(
int
argc
,
char
*
argv
[])
auto
create_args
(
int
argc
,
char
*
argv
[])
...
@@ -55,8 +104,11 @@ auto create_args(int argc, char* argv[])
...
@@ -55,8 +104,11 @@ auto create_args(int argc, char* argv[])
.
insert
(
"prec_sq"
,
"auto"
,
"(dynamic) smooth quant data type. auto will set to fp32"
)
.
insert
(
"prec_sq"
,
"auto"
,
"(dynamic) smooth quant data type. auto will set to fp32"
)
.
insert
(
"prec_kw"
,
"auto"
,
"topk-weight data type. auto will set to fp32"
)
.
insert
(
"prec_kw"
,
"auto"
,
"topk-weight data type. auto will set to fp32"
)
.
insert
(
"fquant"
,
"0"
,
"fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant"
)
.
insert
(
"fquant"
,
"0"
,
"fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant"
)
.
insert
(
"gonly"
,
"0"
,
"w0(gate/up) style, 0:gate+up will double interm size, 1:only gate"
)
.
insert
(
.
insert
(
"balance"
,
"1"
,
"if set to 1, will try balance the expert in topk-ids(convenient for testing)"
)
"gate_only"
,
"0"
,
"w0(gate/up) style, 0:gate+up will double interm size, 1:only gate"
)
.
insert
(
"balance"
,
"1"
,
"if set to 1, will try balance the expert in topk-ids(convenient for testing)"
)
.
insert
(
"warmup"
,
"5"
,
"cold iter"
)
.
insert
(
"warmup"
,
"5"
,
"cold iter"
)
.
insert
(
"repeat"
,
"20"
,
"hot iter"
);
.
insert
(
"repeat"
,
"20"
,
"hot iter"
);
...
@@ -64,133 +116,178 @@ auto create_args(int argc, char* argv[])
...
@@ -64,133 +116,178 @@ auto create_args(int argc, char* argv[])
return
std
::
make_tuple
(
result
,
arg_parser
);
return
std
::
make_tuple
(
result
,
arg_parser
);
}
}
// I:input-type, W:weight-type, O:output-type, ST:toke-scale-tpye, SW:weight-scale-type, SQ:smooth-quant-type, KW:topk-weight-type
// I:input-type, W:weight-type, O:output-type, ST:toke-scale-tpye, SW:weight-scale-type,
// SQ:smooth-quant-type, KW:topk-weight-type
template
<
typename
I
,
typename
W
,
typename
O
,
typename
ST
,
typename
SW
,
typename
SQ
,
typename
KW
>
template
<
typename
I
,
typename
W
,
typename
O
,
typename
ST
,
typename
SW
,
typename
SQ
,
typename
KW
>
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
{
{
ck_tile
::
index_t
tokens
=
arg_parser
.
get_int
(
"t"
);
ck_tile
::
index_t
tokens
=
arg_parser
.
get_int
(
"t"
);
ck_tile
::
index_t
experts
=
arg_parser
.
get_int
(
"e"
);
ck_tile
::
index_t
experts
=
arg_parser
.
get_int
(
"e"
);
ck_tile
::
index_t
topk
=
arg_parser
.
get_int
(
"k"
);
ck_tile
::
index_t
topk
=
arg_parser
.
get_int
(
"k"
);
ck_tile
::
index_t
hidden_size
=
arg_parser
.
get_int
(
"h"
);
ck_tile
::
index_t
hidden_size
=
arg_parser
.
get_int
(
"h"
);
ck_tile
::
index_t
intermediate_size
=
arg_parser
.
get_int
(
"i"
);
ck_tile
::
index_t
intermediate_size
=
arg_parser
.
get_int
(
"i"
);
ck_tile
::
index_t
stride
=
arg_parser
.
get_int
(
"stride"
);
ck_tile
::
index_t
stride
=
arg_parser
.
get_int
(
"stride"
);
ck_tile
::
index_t
block_m
=
arg_parser
.
get_int
(
"bm"
);
ck_tile
::
index_t
block_m
=
arg_parser
.
get_int
(
"bm"
);
if
(
stride
<
0
)
if
(
stride
<
0
)
stride
=
hidden_size
;
stride
=
hidden_size
;
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
std
::
string
prec_i
=
arg_parser
.
get_str
(
"prec_i"
);
std
::
string
prec_w
=
arg_parser
.
get_str
(
"prec_w"
);
std
::
string
prec_w
=
arg_parser
.
get_str
(
"prec_w"
);
std
::
string
prec_o
=
arg_parser
.
get_str
(
"prec_o"
);
std
::
string
prec_o
=
arg_parser
.
get_str
(
"prec_o"
);
std
::
string
prec_st
=
arg_parser
.
get_str
(
"prec_st"
);
std
::
string
prec_st
=
arg_parser
.
get_str
(
"prec_st"
);
std
::
string
prec_sw
=
arg_parser
.
get_str
(
"prec_sw"
);
std
::
string
prec_sw
=
arg_parser
.
get_str
(
"prec_sw"
);
std
::
string
prec_sq
=
arg_parser
.
get_str
(
"prec_sq"
);
std
::
string
prec_sq
=
arg_parser
.
get_str
(
"prec_sq"
);
std
::
string
prec_kw
=
arg_parser
.
get_str
(
"prec_kw"
);
std
::
string
prec_kw
=
arg_parser
.
get_str
(
"prec_kw"
);
prec_st
=
(
prec_st
==
"auto"
)
?
"fp32"
:
prec_st
;
prec_st
=
(
prec_st
==
"auto"
)
?
"fp32"
:
prec_st
;
prec_sw
=
(
prec_sw
==
"auto"
)
?
"fp32"
:
prec_sw
;
prec_sw
=
(
prec_sw
==
"auto"
)
?
"fp32"
:
prec_sw
;
prec_sq
=
(
prec_sq
==
"auto"
)
?
"fp32"
:
prec_sq
;
prec_sq
=
(
prec_sq
==
"auto"
)
?
"fp32"
:
prec_sq
;
prec_kw
=
(
prec_kw
==
"auto"
)
?
"fp32"
:
prec_kw
;
prec_kw
=
(
prec_kw
==
"auto"
)
?
"fp32"
:
prec_kw
;
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
kname
=
arg_parser
.
get_int
(
"kname"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
do_validation
=
arg_parser
.
get_int
(
"v"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
int
fused_quant
=
arg_parser
.
get_int
(
"fquant"
);
int
fused_quant
=
arg_parser
.
get_int
(
"fquant"
);
int
g
only
=
arg_parser
.
get_int
(
"gonly"
);
int
g
ate_only
=
arg_parser
.
get_int
(
"g
ate_
only"
);
int
balance
=
arg_parser
.
get_int
(
"balance"
);
int
balance
=
arg_parser
.
get_int
(
"balance"
);
int
tp
=
arg_parser
.
get_int
(
"tp"
);
int
tp
=
arg_parser
.
get_int
(
"tp"
);
ck_tile
::
index_t
shared_intermediate_size
=
intermediate_size
*
(
gonly
?
1
:
2
)
/
tp
;
ck_tile
::
index_t
shared_intermediate_size
=
intermediate_size
*
(
gate_only
?
1
:
2
)
/
tp
;
using
TypeConfig
=
FusedMoeGemmTypeConfig
<
I
,
W
,
O
,
ST
,
SW
,
SQ
,
KW
>
;
using
TypeConfig
=
FusedMoeGemmTypeConfig
<
I
,
W
,
O
,
ST
,
SW
,
SQ
,
KW
>
;
using
ADataType
=
typename
TypeConfig
::
ADataType
;
using
ADataType
=
typename
TypeConfig
::
ADataType
;
using
GDataType
=
typename
TypeConfig
::
GDataType
;
using
GDataType
=
typename
TypeConfig
::
GDataType
;
using
DDataType
=
typename
TypeConfig
::
DDataType
;
using
DDataType
=
typename
TypeConfig
::
DDataType
;
using
AccDataType
=
typename
TypeConfig
::
AccDataType
;
using
AccDataType
=
typename
TypeConfig
::
AccDataType
;
using
ODataType
=
typename
TypeConfig
::
ODataType
;
using
ODataType
=
typename
TypeConfig
::
ODataType
;
using
AScaleDataType
=
typename
TypeConfig
::
AScaleDataType
;
using
AScaleDataType
=
typename
TypeConfig
::
AScaleDataType
;
using
W0
ScaleDataType
=
typename
TypeConfig
::
W0
ScaleDataType
;
using
G
ScaleDataType
=
typename
TypeConfig
::
G
ScaleDataType
;
using
W1
ScaleDataType
=
typename
TypeConfig
::
W1
ScaleDataType
;
using
D
ScaleDataType
=
typename
TypeConfig
::
D
ScaleDataType
;
using
YSmoothScaleDataType
=
typename
TypeConfig
::
YSmoothScaleDataType
;
using
YSmoothScaleDataType
=
typename
TypeConfig
::
YSmoothScaleDataType
;
using
TopkWeightDataType
=
typename
TypeConfig
::
TopkWeightDataType
;
using
TopkWeightDataType
=
typename
TypeConfig
::
TopkWeightDataType
;
using
IndexDataType
=
typename
TypeConfig
::
IndexDataType
;
using
IndexDataType
=
typename
TypeConfig
::
IndexDataType
;
// host verify
// host verify
ck_tile
::
HostTensor
<
ADataType
>
a_host
({
tokens
,
hidden_size
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
ADataType
>
a_host
({
tokens
,
hidden_size
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
ADataType
>
g_host
({
e
,
shared_intermediate_size
,
hidden_size
});
ck_tile
::
HostTensor
<
GDataType
>
g_host
({
e
,
shared_intermediate_size
,
hidden_size
});
ck_tile
::
HostTensor
<
ADataType
>
d_host
({
e
,
intermediate_size
,
hidden_size
});
ck_tile
::
HostTensor
<
DDataType
>
d_host
({
e
,
intermediate_size
,
hidden_size
});
ck_tile
::
HostTensor
<
ODataType
>
o_host
({
tokens
,
hidden_size
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
AScaleDataType
>
sa_host
({
tokens
});
ck_tile
::
HostTensor
<
XResidualDataType
>
x_residual_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
GScaleDataType
>
sg_host
({
shared_intermediate_size
});
ck_tile
::
HostTensor
<
YResidualDataType
>
y_residual_host
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
DScaleDataType
>
sd_host
({
intermediate_size
});
ck_tile
::
HostTensor
<
YSmoothScaleDataType
>
sy_host
({
intermediate_size
});
// smooth-quant
ck_tile
::
HostTensor
<
YDataType
>
y_host_ref
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
IndexDataType
>
topk_ids_host
({
tokens
,
topk
});
// to be sort
ck_tile
::
HostTensor
<
YDataType
>
y_host_dev
({
m
,
n
},
{
stride
,
1
});
ck_tile
::
HostTensor
<
TopkWeightDataType
>
topk_weight_host
({
tokens
,
topk
});
// to be sort
ck_tile
::
HostTensor
<
MeanDataType
>
mean_host_ref
({
m
});
int
max_num_tokens_padded
=
topk
*
tokens
+
experts
*
(
block_m
-
1
);
ck_tile
::
HostTensor
<
InvStdDataType
>
invStd_host_ref
({
m
});
ck_tile
::
HostTensor
<
IndexDataType
>
sorted_token_ids_host
({
max_num_tokens_padded
});
ck_tile
::
HostTensor
<
YScaleDataType
>
y_scale_host_ref
({
m
});
ck_tile
::
HostTensor
<
TopkWeightDataType
>
sorted_weight_host
({
max_num_tokens_padded
});
ck_tile
::
HostTensor
<
YScaleDataType
>
y_scale_host_dev
({
m
});
ck_tile
::
HostTensor
<
IndexDataType
>
sorted_expert_ids_host
(
{(
max_num_tokens_padded
+
block_m
-
1
)
/
block_m
});
ck_tile
::
HostTensor
<
XScaleDataType
>
x_scale_host
({
n
});
ck_tile
::
HostTensor
<
IndexDataType
>
num_sorted_tiles_host
({
1
});
ck_tile
::
HostTensor
<
XScaleDataType
>
x_scale_host_dev
({
n
});
// permute weight
ck_tile
::
HostTensor
<
GDataType
>
g_perm_host
=
shuffle_moe_weight
(
g_host
,
prec_w
);
ck_tile
::
HostTensor
<
DDataType
>
d_perm_host
=
shuffle_moe_weight
(
d_host
,
prec_w
);
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
.5
f
,
.5
f
}(
a_host
);
ck_tile
::
FillUniformDistribution
<
ADataType
>
{
-
.5
f
,
.5
f
}(
a_host
);
ck_tile
::
FillUniformDistribution
<
XResidualDataType
>
{
-
.5
f
,
.5
f
}(
x_residual_host
);
ck_tile
::
FillUniformDistribution
<
GDataType
>
{
-
.5
f
,
.5
f
}(
g_perm_host
);
ck_tile
::
FillUniformDistribution
<
XScaleDataType
>
{
-
1.
f
,
1.
f
}(
x_scale_host
);
ck_tile
::
FillUniformDistribution
<
DDataType
>
{
-
.5
f
,
.5
f
}(
d_perm_host
);
ck_tile
::
FillUniformDistribution
<
GammaDataType
>
{
-
.5
f
,
.5
f
}(
gamma_host
);
ck_tile
::
FillUniformDistribution
<
AScaleDataType
>
{
-
.5
f
,
.5
f
}(
sa_host
);
ck_tile
::
FillUniformDistribution
<
BetaDataType
>
{
-
.5
f
,
.5
f
}(
beta_host
);
ck_tile
::
FillUniformDistribution
<
GScaleDataType
>
{
-
.5
f
,
.5
f
}(
sg_host
);
ck_tile
::
FillUniformDistribution
<
DScaleDataType
>
{
-
.5
f
,
.5
f
}(
sd_host
);
ck_tile
::
DeviceMem
x_buf
(
a_host
.
get_element_space_size_in_bytes
());
ck_tile
::
FillUniformDistribution
<
YSmoothScaleDataType
>
{
-
.5
f
,
.5
f
}(
sy_host
);
ck_tile
::
DeviceMem
gamma_buf
(
gamma_host
.
get_element_space_size_in_bytes
());
ck_tile
::
FillUniformDistribution
<
TopkWeightDataType
>
{
-
.5
f
,
.5
f
}(
topk_weight_host
);
ck_tile
::
DeviceMem
beta_buf
(
beta_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
y_buf
(
y_host_dev
.
get_element_space_size_in_bytes
());
// do moe sorting
ck_tile
::
DeviceMem
y_scale_buf
(
y_scale_host_dev
.
get_element_space_size_in_bytes
());
if
(
balance
)
ck_tile
::
DeviceMem
x_scale_buf
(
x_scale_host_dev
.
get_element_space_size_in_bytes
());
{
int
e_cnt
=
0
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
topk_ids_host
.
mData
.
size
());
i
++
)
ck_tile
::
DeviceMem
x_residual_buf
(
x_residual_host
.
get_element_space_size_in_bytes
());
{
ck_tile
::
DeviceMem
y_residual_buf
(
y_residual_host
.
get_element_space_size_in_bytes
());
topk_ids_host
.
mData
[
i
]
=
e_cnt
;
e_cnt
++
;
x_buf
.
ToDevice
(
a_host
.
data
());
if
(
e_cnt
>=
experts
)
gamma_buf
.
ToDevice
(
gamma_host
.
data
());
e_cnt
=
0
;
beta_buf
.
ToDevice
(
beta_host
.
data
());
}
x_residual_buf
.
ToDevice
(
x_residual_host
.
data
());
}
x_scale_buf
.
ToDevice
(
x_scale_host
.
data
());
else
{
topid_unique_gen
<
IndexType
>
(
topk_ids_host
.
mData
,
tokens
,
topk
,
experts
,
11913
);
}
ck_tile
::
reference_moe_sorting
<
TopkWeightDataType
,
IndexDataType
>
(
topk_ids_host
,
topk_weight_host
,
sorted_token_ids_host
,
sorted_weight_host
,
sorted_expert_ids_host
,
num_sorted_tiles_host
.
mData
[
0
],
experts
,
block_m
);
// done, preparing GPU buffer
ck_tile
::
DeviceMem
a_buf
(
a_host
);
ck_tile
::
DeviceMem
g_perm_buf
(
g_perm_host
);
ck_tile
::
DeviceMem
d_perm_buf
(
d_perm_host
);
ck_tile
::
DeviceMem
sa_buf
(
sa_host
);
ck_tile
::
DeviceMem
sg_buf
(
sg_host
);
ck_tile
::
DeviceMem
sd_buf
(
sd_host
);
ck_tile
::
DeviceMem
sy_buf
(
sy_host
);
ck_tile
::
DeviceMem
o_buf
(
o_host
);
ck_tile
::
DeviceMem
sorted_token_ids_buf
(
sorted_token_ids_host
);
ck_tile
::
DeviceMem
sorted_weight_buf
(
sorted_weight_host
);
ck_tile
::
DeviceMem
sorted_expert_ids_buf
(
sorted_expert_ids_host
);
ck_tile
::
DeviceMem
num_sorted_tiles_buf
(
num_sorted_tiles_host
);
auto
prec_str
=
[
&
]()
{
auto
prec_str
=
[
&
]()
{
auto
base_str
=
prec_i
;
auto
base_str
=
prec_i
;
if
(
prec_i
!=
prec_w
)
base_str
+=
"x"
+
prec_w
;
if
(
prec_i
!=
prec_o
)
if
(
prec_i
!=
prec_o
)
base_str
+=
"="
+
prec_o
;
if
(
fused_quant
!=
0
)
{
{
base_str
+=
"|"
+
prec_o
;
base_str
+=
std
::
string
(
"("
)
+
prec_sa
+
"|"
+
prec_sg
+
"|"
+
prec_sq
+
")"
;
}
if
(
fused_quant
==
1
)
{
base_str
+=
std
::
string
(
"("
)
+
prec_sy
+
")"
;
}
}
return
base_str
;
return
base_str
;
}();
}();
std
::
cout
<<
"["
<<
prec_str
<<
"]"
std
::
cout
<<
"["
<<
prec_str
<<
"]"
<<
" m:"
<<
m
<<
", n:"
<<
n
<<
", stride:"
<<
stride
<<
std
::
flush
;
<<
" t:"
<<
tokens
<<
", e:"
<<
experts
<<
", k:"
<<
topk
<<
", st:"
<<
stride
<<
", hidden:"
<<
hidden_size
<<
", interm:"
<<
intermediate_size
<<
", tp:"
<<
tp
layernorm2d_fwd_traits
traits
{
<<
", go:"
<<
gate_only
<<
", q:"
<<
fused_quant
<<
std
::
flush
;
prec_i
,
prec_o
,
prec_sx
,
prec_sy
,
SaveMeanVar
,
fused_add
,
fused_quant
};
fused_moegemm_traits
traits
{
prec_i
,
layernorm2d_fwd_args
args
{
x_buf
.
GetDeviceBuffer
(),
prec_w
,
fused_add
!=
0
?
x_residual_buf
.
GetDeviceBuffer
()
:
nullptr
,
prec_o
,
fused_quant
==
1
?
x_scale_buf
.
GetDeviceBuffer
()
:
nullptr
,
prec_st
,
gamma_buf
.
GetDeviceBuffer
(),
prec_sw
,
beta_buf
.
GetDeviceBuffer
(),
prec_sq
,
prec_kw
,
y_buf
.
GetDeviceBuffer
(),
block_m
,
fused_add
==
1
?
y_residual_buf
.
GetDeviceBuffer
()
:
nullptr
,
gate_only
,
fused_quant
!=
0
?
y_scale_buf
.
GetDeviceBuffer
()
:
nullptr
,
fused_quant
};
nullptr
,
// p_mean, unsupported yet
nullptr
,
// p_invStd, unsupported yet
fused_moegemm_args
args
{
a_buf
.
GetDeviceBuffer
(),
fused_quant
!=
0
?
sa_buf
.
GetDeviceBuffer
()
:
nullptr
,
epsilon
,
g_buf
.
GetDeviceBuffer
(),
m
,
d_buf
.
GetDeviceBuffer
(),
n
,
fused_quant
!=
0
stride
};
?
sg_buf
.
GetDeviceBuffer
(),
fused_quant
!=
0
float
ave_time
=
layernorm2d_fwd
(
?
sd_buf
.
GetDeviceBuffer
(),
fused_quant
==
1
?
sy_buf
.
GetDeviceBuffer
(),
o_buf
.
GetDeviceBuffer
(),
sorted_token_ids_buf
.
GetDeviceBuffer
(),
sorted_weight_buf
.
GetDeviceBuffer
(),
sorted_expert_ids_buf
.
GetDeviceBuffer
(),
num_sorted_tiles_buf
.
GetDeviceBuffer
(),
hidden_size
,
intermediate_size
,
num_tokens
,
experts
,
stride
};
float
ave_time
=
fused_moegemm
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
if
(
ave_time
<
0
)
if
(
ave_time
<
0
)
...
@@ -199,22 +296,30 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -199,22 +296,30 @@ bool run(const ck_tile::ArgParser& arg_parser)
return
false
;
return
false
;
}
}
#if 0
std::size_t num_byte = sizeof(ADataType) * m * n + sizeof(GammaDataType) * n +
std::size_t num_byte = sizeof(ADataType) * m * n + sizeof(GammaDataType) * n +
sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;
sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;
float gb_per_sec = num_byte / 1.E6 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
#else
std
::
size_t
flop_gemm_0
=
2
*
tokens
*
topk
*
shared_intermediate_size
*
hidden_size
;
std
::
size_t
flop_gemm_1
=
2
*
tokens
*
topk
*
hidden_size
*
hidden_size
;
double
tflops
=
(
flop_gemm_0
+
flop_gemm_1
)
/
(
static_cast
<
double
>
(
ave_time
)
*
1e-3
)
/
1e12
;
// float gb_per_sec = num_byte / 1.E6 / ave_time;
std
::
cout
<<
", "
<<
ave_time
*
1.E3
<<
" us, "
<<
tflops
<<
" tflops"
<<
std
::
flush
;
#endif
bool
pass
=
true
;
bool
pass
=
true
;
if
(
do_validation
)
if
(
do_validation
)
{
{
#if 0
// reference
// reference
if(fused_add != 0)
if(fused_add != 0)
{
{
// fused pre_add/pre_add_store
// fused pre_add/pre_add_store
// TODO we accumulate directly to a_host for simplcity here...
// TODO we accumulate directly to a_host for simplcity here...
std::transform(a_host.mData.cbegin(),
std::transform(a_host.mData.cbegin(),
a_host.mData.cend(),
a_host.mData.cend(),
x_residual_host.mData.cbegin(),
x_residual_host.mData.cbegin(),
...
@@ -353,6 +458,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -353,6 +458,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
#else
std
::
cout
<<
std
::
flush
<<
std
::
endl
;
#endif
}
}
return
pass
;
return
pass
;
...
...
include/ck_tile/host.hpp
View file @
70fa98ad
...
@@ -23,6 +23,7 @@
...
@@ -23,6 +23,7 @@
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
#include "ck_tile/host/reference/reference_permute.hpp"
#include "ck_tile/host/reference/reference_permute.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp"
...
...
include/ck_tile/host/device_memory.hpp
View file @
70fa98ad
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
#include <stdint.h>
#include <stdint.h>
#include <stdexcept>
#include <stdexcept>
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
T
>
template
<
typename
T
>
...
@@ -36,6 +37,19 @@ struct DeviceMem
...
@@ -36,6 +37,19 @@ struct DeviceMem
mpDeviceBuf
=
nullptr
;
mpDeviceBuf
=
nullptr
;
}
}
}
}
/// @brief Construct a device buffer sized after a host tensor and upload its
///        contents to the GPU.
///
/// Allocation is skipped for empty tensors (the device pointer stays null),
/// mirroring the size-based constructor of this struct. The upload is done
/// unconditionally via ToDevice(), which is a no-op for a null buffer of
/// size zero.
///
/// @tparam T element type of the source HostTensor (deduced)
/// @param  t host tensor whose element space size determines the allocation
// NOTE: fixed invalid `template <T>` — a template parameter needs `typename`.
template <typename T>
DeviceMem(const HostTensor<T>& t) : mMemSize(t.get_element_space_size_in_bytes())
{
    if(mMemSize != 0)
    {
        HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
    }
    else
    {
        mpDeviceBuf = nullptr;
    }
    ToDevice(t.data());
}
void
Realloc
(
std
::
size_t
mem_size
)
void
Realloc
(
std
::
size_t
mem_size
)
{
{
if
(
mpDeviceBuf
)
if
(
mpDeviceBuf
)
...
@@ -92,6 +106,22 @@ struct DeviceMem
...
@@ -92,6 +106,22 @@ struct DeviceMem
HIP_CHECK_ERROR
(
hipMemcpy
(
p
,
mpDeviceBuf
,
cpySize
,
hipMemcpyDeviceToHost
));
HIP_CHECK_ERROR
(
hipMemcpy
(
p
,
mpDeviceBuf
,
cpySize
,
hipMemcpyDeviceToHost
));
}
}
}
}
// construct a host tensor with type T
/// @brief Copy device data back into a freshly constructed 1-D HostTensor<T>.
///
/// The element count is rounded up so the host tensor can hold every copied
/// byte even when the copy size is not a multiple of sizeof(T).
/// TODO: host tensor could be slightly larger than the device tensor;
///       we just copy all requested bytes from the GPU buffer.
///
/// @tparam T        element type of the returned host tensor
/// @param  cpySize  number of bytes to copy from the device buffer
/// @return a host tensor holding the copied data (zero-filled tail, if any)
template <typename T>
HostTensor<T> ToHost(std::size_t cpySize)
{
    // round up: ceil(cpySize / sizeof(T)) elements are needed on the host
    std::size_t host_elements = (cpySize + sizeof(T) - 1) / sizeof(T);
    HostTensor<T> h_({host_elements});
    if(mpDeviceBuf)
    {
        HIP_CHECK_ERROR(hipMemcpy(h_.data(), mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
    }
    return h_;
}

/// @brief Convenience overload copying the whole device buffer.
// NOTE: a default argument may not name the non-static member mMemSize
// (ill-formed C++), hence this explicit overload instead of `= mMemSize`.
template <typename T>
HostTensor<T> ToHost()
{
    return ToHost<T>(mMemSize);
}
void
SetZero
()
const
void
SetZero
()
const
{
{
if
(
mpDeviceBuf
)
if
(
mpDeviceBuf
)
...
...
include/ck_tile/host/reference/reference_moe_sorting.hpp
0 → 100644
View file @
70fa98ad
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace
ck_tile
{
/// @brief Host reference for MoE token sorting: group (token, weight) pairs by
///        expert id and emit them in expert order, padded to whole units.
///
/// Tokens are bucketed per expert in units of @p unit_size. Padding slots use
/// token id == num_token and weight == 0 as sentinels. Outputs are written
/// back-to-back, expert by expert:
///   - sorted_token_ids : per-expert token ids, each expert padded to a
///                        multiple of unit_size
///   - sorted_weight    : matching topk weights (0 on padding slots)
///   - sorted_expert_ids: one expert id per emitted unit (tile)
///
/// @tparam WeightType  topk-weight element type
/// @tparam IndexType   index element type (defaults to index_t)
/// @param topk_ids          [num_token, topk] expert id per token/topk slot
/// @param weights           [num_token, topk] weight per token/topk slot
/// @param sorted_token_ids  output, must be large enough for the padded layout
/// @param sorted_weight     output, same length as sorted_token_ids
/// @param sorted_expert_ids output, one entry per unit
/// @param unit_cnt          incremented once per emitted unit (total tiles)
/// @param experts           number of experts; topk_ids values must be < experts
/// @param unit_size         tile size (M_a) each expert's tokens are padded to
template <typename WeightType, typename IndexType = index_t>
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
                                        const HostTensor<WeightType>& weights,
                                        HostTensor<IndexType>& sorted_token_ids,
                                        HostTensor<WeightType>& sorted_weight,
                                        HostTensor<IndexType>& sorted_expert_ids,
                                        index_t& unit_cnt,
                                        const index_t experts,
                                        const index_t unit_size)
{
    const index_t num_token = topk_ids.mDesc.get_lengths()[0];
    const index_t topk      = topk_ids.mDesc.get_lengths()[1];

    // Per-expert buckets, pre-filled with one padding unit each.
    // token id == num_token / weight == 0 mark padding slots.
    std::vector<std::vector<IndexType>> expert_tokens(
        experts, std::vector<IndexType>(unit_size, num_token));
    std::vector<std::vector<WeightType>> expert_token_weights(
        experts, std::vector<WeightType>(unit_size, 0));
    std::vector<IndexType> expert_slices(experts, 1);     // units allocated per expert
    std::vector<IndexType> expert_slice_idxs(experts, 0); // next free slot per expert

    for(index_t t = 0; t < num_token; t++)
    {
        for(index_t k = 0; k < topk; k++)
        {
            IndexType e  = topk_ids(t, k);
            WeightType w = weights(t, k);
            index_t idx  = expert_slice_idxs[e];
            if(idx > expert_slices[e] * unit_size - 1)
            {
                // bucket full: grow by one unit and pad the new slots
                expert_slices[e]++;
                index_t new_size = expert_slices[e] * unit_size;
                expert_tokens[e].resize(new_size);
                expert_token_weights[e].resize(new_size);
                for(index_t i = (expert_slices[e] - 1) * unit_size; i < new_size; i++)
                {
                    expert_tokens[e][i]        = num_token;
                    expert_token_weights[e][i] = 0;
                }
            }
            expert_tokens[e][idx]        = t;
            expert_token_weights[e][idx] = w;
            expert_slice_idxs[e]++;
        }
    }

    IndexType* out_tokens     = sorted_token_ids.data();
    WeightType* out_weights   = sorted_weight.data();
    IndexType* out_expert_id  = sorted_expert_ids.data();
    for(index_t e = 0; e < experts; e++)
    {
        // BUGFIX: was sizeof(index_t); the byte count must use the actual
        // element type IndexType, otherwise the copy is wrong whenever
        // IndexType != index_t.
        memcpy(out_tokens,
               expert_tokens[e].data(),
               sizeof(IndexType) * expert_slices[e] * unit_size);
        out_tokens += expert_slices[e] * unit_size;

        memcpy(out_weights,
               expert_token_weights[e].data(),
               sizeof(WeightType) * expert_slices[e] * unit_size);
        out_weights += expert_slices[e] * unit_size;

        // one expert id per emitted unit; unit_cnt tallies total tiles
        for(index_t s = 0; s < expert_slices[e]; s++)
        {
            out_expert_id[s] = e;
            unit_cnt++;
        }
        out_expert_id += expert_slices[e];
    }
    return;
}
}
// namespace ck_tile
include/ck_tile/host/reference/reference_permute.hpp
View file @
70fa98ad
...
@@ -56,11 +56,10 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
...
@@ -56,11 +56,10 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
}
}
template
<
typename
DataType
>
template
<
typename
DataType
>
CK_TILE_HOST
auto
CK_TILE_HOST
auto
reference_permute
(
const
HostTensor
<
DataType
>&
x
,
std
::
vector
<
index_t
>
perm
)
reference_permute
(
const
HostTensor
<
DataType
>&
x
,
std
::
vector
<
index_t
>
perm
)
{
{
auto
x_shape
=
x
.
get_lengths
();
auto
x_shape
=
x
.
get_lengths
();
ck_tile
::
index_t
rank
=
perm
.
size
();
ck_tile
::
index_t
rank
=
perm
.
size
();
std
::
vector
<
ck_tile
::
index_t
>
y_shape
=
[
&
]()
{
std
::
vector
<
ck_tile
::
index_t
>
y_shape
=
[
&
]()
{
std
::
vector
<
ck_tile
::
index_t
>
tmp
(
rank
,
0
);
std
::
vector
<
ck_tile
::
index_t
>
tmp
(
rank
,
0
);
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
rank
);
i
++
)
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
rank
);
i
++
)
...
...
include/ck_tile/ops/fused_moe.hpp
View file @
70fa98ad
...
@@ -3,12 +3,12 @@
...
@@ -3,12 +3,12 @@
#pragma once
#pragma once
#include "ck_tile/ops/fused_moe/kernel/fused_moe_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moe
gemm
_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moe_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moe
gemm
_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moe_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moe
gemm
_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_flatmm.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe
gemm
_pipeline_flatmm.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe
gemm
_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe
gemm
_pipeline_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe
gemm
_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
View file @
70fa98ad
...
@@ -22,17 +22,17 @@
...
@@ -22,17 +22,17 @@
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
//
// max_tokens_
post_
padded : top_k * input_tokens + num_experts * (M_a - 1)
// max_
num_
tokens_padded : top_k * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU
// * this could be larger than actual, since actual tokens are on GPU
//
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
//
// * length is max_tokens_
post_
padded, actual size is num_tokens_post_padded_ptr
// * length is max_
num_
tokens_padded, actual size is num_tokens_post_padded_ptr
//
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_tokens_
post_
padded + block_size - 1) / block_size
// * length is (max_
num_
tokens_padded + block_size - 1) / block_size
//
//
// num_tokens_post_padded_ptr : [28]
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
// num_sorted_tiles_ptr : [7]
...
@@ -43,11 +43,12 @@
...
@@ -43,11 +43,12 @@
// 3) use num_sorted_tiles_ptr, already divided by M_a
// 3) use num_sorted_tiles_ptr, already divided by M_a
//
//
// * below used for indexing
// * below used for indexing
// 1) sorted_token_ids_ptr
// 1) sorted_token_ids_ptr
[max_num_tokens_padded]
// 2) sorted_weight_ptr
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 3) sorted_expert_ids_ptr
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
//
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
//
//
// [indexing implementation-2]
// [indexing implementation-2]
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
...
@@ -92,15 +93,15 @@ struct FusedMoeGemmHostArgs
...
@@ -92,15 +93,15 @@ struct FusedMoeGemmHostArgs
const
void
*
y_smooth_scale_ptr
;
// [e, 1, n], smooth-quant-scale for 2nd gemm input
const
void
*
y_smooth_scale_ptr
;
// [e, 1, n], smooth-quant-scale for 2nd gemm input
void
*
o_ptr
;
// [m, k], output token
void
*
o_ptr
;
// [m, k], output token
const
void
*
sorted_token_ids_ptr
;
const
void
*
sorted_token_ids_ptr
;
// [max_num_tokens_padded]
const
void
*
sorted_weight_ptr
;
const
void
*
sorted_weight_ptr
;
// [max_num_tokens_padded]
const
void
*
sorted_expert_ids_ptr
;
const
void
*
sorted_expert_ids_ptr
;
// [(max_num_tokens_padded + block_size - 1) / block_size]
const
void
*
num_sorted_tiles_ptr
;
const
void
*
num_sorted_tiles_ptr
;
// [1]
index_t
hidden_size
;
// k
index_t
hidden_size
;
// k
index_t
intermediate_size
;
// n (TP slice this)
index_t
intermediate_size
;
// n (TP slice this)
index_t
num_tokens
;
// input number of tokens for current iteration
index_t
num_tokens
;
// input number of tokens for current iteration
index_t
num_experts
;
// number of groups
index_t
num_experts
;
// number of groups
// index_t top_k; // need this?
// index_t top_k; // need this?
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
...
@@ -134,10 +135,10 @@ struct FusedMoeGemmKernel
...
@@ -134,10 +135,10 @@ struct FusedMoeGemmKernel
using
Traits
=
typename
Pipeline
::
Problem
::
Traits
;
using
Traits
=
typename
Pipeline
::
Problem
::
Traits
;
static
constexpr
bool
IsGateOnly
=
Traits
::
IsGateOnly
;
static
constexpr
bool
IsGateOnly
=
Traits
::
IsGateOnly
;
static
constexpr
bool
UseSmoothQuant
=
Traits
::
UseSmoothQuant
;
static
constexpr
bool
UseSmoothQuant
=
Traits
::
UseSmoothQuant
;
static
constexpr
bool
PadHiddenSize
=
Traits
::
PadHiddenSize
;
static
constexpr
bool
PadHiddenSize
=
Traits
::
PadHiddenSize
;
static
constexpr
bool
PadIntermediateSize
=
Traits
::
PadIntermediateSize
;
static
constexpr
bool
PadIntermediateSize
=
Traits
::
PadIntermediateSize
;
// clang-format off
// clang-format off
template
<
typename
T
>
struct
t2s
;
template
<
typename
T
>
struct
t2s
;
...
@@ -173,10 +174,10 @@ struct FusedMoeGemmKernel
...
@@ -173,10 +174,10 @@ struct FusedMoeGemmKernel
const
void
*
sorted_expert_ids_ptr
;
const
void
*
sorted_expert_ids_ptr
;
const
void
*
num_sorted_tiles_ptr
;
const
void
*
num_sorted_tiles_ptr
;
index_t
hidden_size
;
// k
index_t
hidden_size
;
// k
index_t
intermediate_size
;
// n (TP slice this)
index_t
intermediate_size
;
// n (TP slice this)
index_t
num_tokens
;
// input number of tokens for current iteration
index_t
num_tokens
;
// input number of tokens for current iteration
index_t
num_experts
;
// number of groups
index_t
num_experts
;
// number of groups
// index_t top_k; // need this?
// index_t top_k; // need this?
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
...
@@ -214,7 +215,7 @@ struct FusedMoeGemmKernel
...
@@ -214,7 +215,7 @@ struct FusedMoeGemmKernel
index_t
nr_0
=
kargs
.
intermediate_size
/
Pipeline
::
Block_Nr0
;
index_t
nr_0
=
kargs
.
intermediate_size
/
Pipeline
::
Block_Nr0
;
index_t
kr_0
=
kargs
.
hidden_size
/
Pipeline
::
Block_Kr0
;
index_t
kr_0
=
kargs
.
hidden_size
/
Pipeline
::
Block_Kr0
;
index_t
nr_1
=
kargs
.
hidden_size
/
Pipeline
::
Block_Nr1
;
// should be same as kr_0
index_t
nr_1
=
kargs
.
hidden_size
/
Pipeline
::
Block_Nr1
;
// should be same as kr_0
index_t
kr_1
=
kargs
.
intermediate_size
/
Pipeline
::
Block_Kr1
;
// should be same as nr_0
index_t
kr_1
=
kargs
.
intermediate_size
/
Pipeline
::
Block_Kr1
;
// should be same as nr_0
index_t
expert_stride_0
=
kargs
.
intermediate_size
*
hidden_radio_0
*
kargs
.
hidden_size
;
index_t
expert_stride_0
=
kargs
.
intermediate_size
*
hidden_radio_0
*
kargs
.
hidden_size
;
...
@@ -280,11 +281,12 @@ struct FusedMoeGemmKernel
...
@@ -280,11 +281,12 @@ struct FusedMoeGemmKernel
make_tuple
(
kr_0
*
BlockShape
::
Block_W0
,
number
<
Pipeline
::
Block_W0
>
{},
1
),
make_tuple
(
kr_0
*
BlockShape
::
Block_W0
,
number
<
Pipeline
::
Block_W0
>
{},
1
),
number
<
Pipeline
::
kAlignmentG
>
{},
number
<
Pipeline
::
kAlignmentG
>
{},
number
<
1
>
{});
number
<
1
>
{});
const
auto
g_view_1_
=
pad_tensor_view
(
g_view_
,
const
auto
g_view_1_
=
make_tuple
(
number
<
Pipeline
::
Block_Nr0
>
{},
pad_tensor_view
(
g_view_
,
number
<
Pipeline
::
Block_Kr0
>
{},
make_tuple
(
number
<
Pipeline
::
Block_Nr0
>
{},
number
<
Pipeline
::
Block_W0
>
{}),
number
<
Pipeline
::
Block_Kr0
>
{},
sequence
<
PadIntermediateSize
,
PadHiddenSize
,
0
>
{});
number
<
Pipeline
::
Block_W0
>
{}),
sequence
<
PadIntermediateSize
,
PadHiddenSize
,
0
>
{});
const
auto
g_window_
=
make_tile_window
(
g_view_1_
,
const
auto
g_window_
=
make_tile_window
(
g_view_1_
,
make_tuple
(
number
<
BlockShape
::
Block_Nr0
>
{},
make_tuple
(
number
<
BlockShape
::
Block_Nr0
>
{},
...
@@ -308,11 +310,12 @@ struct FusedMoeGemmKernel
...
@@ -308,11 +310,12 @@ struct FusedMoeGemmKernel
make_tuple
(
kr_1
*
Pipeline
::
Block_W1
,
Pipeline
::
Block_W1
,
1
),
make_tuple
(
kr_1
*
Pipeline
::
Block_W1
,
Pipeline
::
Block_W1
,
1
),
number
<
Pipeline
::
kAlignmentD
>
{},
number
<
Pipeline
::
kAlignmentD
>
{},
number
<
1
>
{});
number
<
1
>
{});
const
auto
d_view_1_
=
pad_tensor_view
(
d_view_
,
const
auto
d_view_1_
=
make_tuple
(
number
<
Pipeline
::
kBlockNr_1
>
{},
pad_tensor_view
(
d_view_
,
number
<
Pipeline
::
kBlockKr_1
>
{},
make_tuple
(
number
<
Pipeline
::
kBlockNr_1
>
{},
number
<
Pipeline
::
Block_W1
>
{}),
number
<
Pipeline
::
kBlockKr_1
>
{},
sequence
<
PadHiddenSize
,
PadIntermediateSize
,
0
>
{});
number
<
Pipeline
::
Block_W1
>
{}),
sequence
<
PadHiddenSize
,
PadIntermediateSize
,
0
>
{});
const
auto
d_window_
=
make_tile_window
(
d_view_1_
,
const
auto
d_window_
=
make_tile_window
(
d_view_1_
,
make_tuple
(
number
<
Pipeline
::
kBlockNr_1
>
{},
make_tuple
(
number
<
Pipeline
::
kBlockNr_1
>
{},
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp
View file @
70fa98ad
...
@@ -44,10 +44,10 @@ struct FusedMoeGemmPipeline_Flatmm
...
@@ -44,10 +44,10 @@ struct FusedMoeGemmPipeline_Flatmm
using
Traits
=
typename
Pipeline
::
Problem
::
Traits
;
using
Traits
=
typename
Pipeline
::
Problem
::
Traits
;
static
constexpr
bool
IsGateOnly
=
Traits
::
IsGateOnly
;
static
constexpr
bool
IsGateOnly
=
Traits
::
IsGateOnly
;
static
constexpr
bool
UseSmoothQuant
=
Traits
::
UseSmoothQuant
;
static
constexpr
bool
UseSmoothQuant
=
Traits
::
UseSmoothQuant
;
static
constexpr
bool
PadHiddenSize
=
Traits
::
PadHiddenSize
;
static
constexpr
bool
PadHiddenSize
=
Traits
::
PadHiddenSize
;
static
constexpr
bool
PadIntermediateSize
=
Traits
::
PadIntermediateSize
;
static
constexpr
bool
PadIntermediateSize
=
Traits
::
PadIntermediateSize
;
static
constexpr
index_t
kAlignmentA
=
Policy
::
GetAlignment_A
<
Problem
>
();
static
constexpr
index_t
kAlignmentA
=
Policy
::
GetAlignment_A
<
Problem
>
();
static
constexpr
index_t
kAlignmentG
=
Policy
::
GetAlignment_G
<
Problem
>
();
static
constexpr
index_t
kAlignmentG
=
Policy
::
GetAlignment_G
<
Problem
>
();
...
@@ -133,11 +133,12 @@ struct FusedMoeGemmPipeline_Flatmm
...
@@ -133,11 +133,12 @@ struct FusedMoeGemmPipeline_Flatmm
make_tuple
(
kr_0
*
BlockShape
::
Block_W0
,
number
<
BlockShape
::
Block_W0
>
{},
1
),
make_tuple
(
kr_0
*
BlockShape
::
Block_W0
,
number
<
BlockShape
::
Block_W0
>
{},
1
),
number
<
kAlignmentG
>
{},
number
<
kAlignmentG
>
{},
number
<
1
>
{});
number
<
1
>
{});
const
auto
u_view_1_
=
pad_tensor_view
(
u_view_
,
const
auto
u_view_1_
=
make_tuple
(
number
<
BlockShape
::
Block_Nr0
>
{},
pad_tensor_view
(
u_view_
,
number
<
BlockShape
::
Block_Kr0
>
{},
make_tuple
(
number
<
BlockShape
::
Block_Nr0
>
{},
number
<
BlockShape
::
Block_W0
>
{}),
number
<
BlockShape
::
Block_Kr0
>
{},
sequence
<
PadIntermediateSize
,
PadHiddenSize
,
0
>
{});
number
<
BlockShape
::
Block_W0
>
{}),
sequence
<
PadIntermediateSize
,
PadHiddenSize
,
0
>
{});
return
u_view_1_
;
return
u_view_1_
;
}
}
}();
}();
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
View file @
70fa98ad
...
@@ -225,7 +225,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
...
@@ -225,7 +225,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetMatrixCoreSwizzledBlockTIle_0
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetMatrixCoreSwizzledBlockTIle_0
()
{
{
if
constexpr
(
Problem
::
Traits
::
PermuteEnum
==
FusedMoeGemmWeightPermuteEnum
::
b_nr_kr_waveflatten
)
if
constexpr
(
Problem
::
Traits
::
PermuteEnum
==
FusedMoeGemmWeightPermuteEnum
::
b_nr_kr_waveflatten
)
{
{
using
WarpGemm
=
GetWarpGemm0
<
Problem
>
{};
// assume warpgemm0/1 are the same
using
WarpGemm
=
GetWarpGemm0
<
Problem
>
{};
// assume warpgemm0/1 are the same
constexpr
index_t
NPerBlock
=
Problem
::
BlockShape
::
Block_N0
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockShape
::
Block_N0
;
...
@@ -703,7 +704,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
...
@@ -703,7 +704,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetMatrixCoreSwizzledBlockTIle_0
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetMatrixCoreSwizzledBlockTIle_0
()
{
{
if
constexpr
(
Problem
::
Traits
::
PermuteEnum
==
FusedMoeGemmWeightPermuteEnum
::
b_nr_kr_waveflatten
)
if
constexpr
(
Problem
::
Traits
::
PermuteEnum
==
FusedMoeGemmWeightPermuteEnum
::
b_nr_kr_waveflatten
)
{
{
using
WarpGemm
=
GetWarpGemm0
<
Problem
>
{};
// assume warpgemm0/1 are the same
using
WarpGemm
=
GetWarpGemm0
<
Problem
>
{};
// assume warpgemm0/1 are the same
constexpr
index_t
NPerBlock
=
Problem
::
BlockShape
::
Block_N0
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockShape
::
Block_N0
;
...
@@ -723,7 +725,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
...
@@ -723,7 +725,8 @@ struct FusedMoeGemmPipelineFlatmmPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetMatrixCoreSwizzledBlockTIle_1
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetMatrixCoreSwizzledBlockTIle_1
()
{
{
if
constexpr
(
Problem
::
Traits
::
PermuteEnum
==
FusedMoeGemmWeightPermuteEnum
::
b_nr_kr_waveflatten
)
if
constexpr
(
Problem
::
Traits
::
PermuteEnum
==
FusedMoeGemmWeightPermuteEnum
::
b_nr_kr_waveflatten
)
{
{
using
WarpGemm
=
GetWarpGemm1
<
Problem
>
{};
// assume warpgemm0/1 are the same
using
WarpGemm
=
GetWarpGemm1
<
Problem
>
{};
// assume warpgemm0/1 are the same
constexpr
index_t
NPerBlock
=
Problem
::
BlockShape
::
kBlockN_1
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockShape
::
kBlockN_1
;
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp
View file @
70fa98ad
...
@@ -14,8 +14,8 @@ template <typename ADataType_,
...
@@ -14,8 +14,8 @@ template <typename ADataType_,
typename
AccDataType_
,
typename
AccDataType_
,
typename
ODataType_
,
typename
ODataType_
,
typename
AScaleDataType_
,
typename
AScaleDataType_
,
typename
W0
ScaleDataType_
,
typename
G
ScaleDataType_
,
typename
W1
ScaleDataType_
,
typename
D
ScaleDataType_
,
typename
YSmoothScaleDataType_
,
typename
YSmoothScaleDataType_
,
typename
TopkWeightDataType_
,
typename
TopkWeightDataType_
,
typename
IndexDataType_
,
// data type for all indexing
typename
IndexDataType_
,
// data type for all indexing
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
View file @
70fa98ad
...
@@ -19,14 +19,18 @@ enum class FusedMoeGemmWeightPermuteEnum
...
@@ -19,14 +19,18 @@ enum class FusedMoeGemmWeightPermuteEnum
template
<
bool
IsGateOnly_
,
template
<
bool
IsGateOnly_
,
bool
UseSmoothQuant_
,
bool
UseSmoothQuant_
,
index_t
OAtomic_
,
// 0-no atomic, 1-atomic-pk-f16/bf16, 2-atomic-f32
index_t
OAtomic_
,
// 0-no atomic, 1-atomic-pk-f16/bf16, 2-atomic-f32
FusedMoeGemmWeightPermuteEnum
PermuteEnum_
=
FusedMoeGemmWeightPermuteEnum
::
b_nr_kr_waveflatten
;
FusedMoeGemmWeightPermuteEnum
PermuteEnum_
=
bool
PadHiddenSize_
=
false
,
bool
PadIntermediateSize_
=
false
>
struct
FusedMoeGemmTraits
FusedMoeGemmWeightPermuteEnum
::
b_nr_kr_waveflatten
,
bool
PadHiddenSize_
=
false
,
bool
PadIntermediateSize_
=
false
>
struct
FusedMoeGemmTraits
{
{
// Gate+Up or Gate only
// Gate+Up or Gate only
static
constexpr
bool
IsGateOnly
=
IsGateOnly_
;
static
constexpr
bool
IsGateOnly
=
IsGateOnly_
;
static
constexpr
bool
UseSmoothQuant
=
UseSmoothQuant_
;
static
constexpr
bool
UseSmoothQuant
=
UseSmoothQuant_
;
static
constexpr
index_t
OAtomic
=
OAtomic_
;
static
constexpr
index_t
OAtomic
=
OAtomic_
;
static
constexpr
bool
PadHiddenSize
=
PadHiddenSize_
;
static
constexpr
FusedMoeGemmWeightPermuteEnum
PermuteEnum
=
PermuteEnum_
;
static
constexpr
bool
PadIntermediateSize
=
PadIntermediateSize_
;
static
constexpr
bool
PadHiddenSize
=
PadHiddenSize_
;
static
constexpr
bool
PadIntermediateSize
=
PadIntermediateSize_
;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment