Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b616b254
"...composable_kernel_rocm.git" did not exist on "39d92e7dfdb2893a0e7d0521523c442ec403712c"
Commit
b616b254
authored
Dec 05, 2024
by
letaoqin
Browse files
add debugging code and format
parent
2baf9422
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
185 additions
and
148 deletions
+185
-148
example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api_traits.hpp
..._fused_moe_general/instances/fused_moegemm_api_traits.hpp
+1
-1
example/ck_tile/17_fused_moe_general/main.cpp
example/ck_tile/17_fused_moe_general/main.cpp
+146
-112
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
...ile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
+2
-2
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
...ude/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
+1
-1
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+10
-9
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
+21
-19
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
...e/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
+4
-4
No files found.
example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api_traits.hpp
View file @
b616b254
...
@@ -45,7 +45,7 @@ struct fmoe_ // traits, ugly name, only used for internal
...
@@ -45,7 +45,7 @@ struct fmoe_ // traits, ugly name, only used for internal
using
WarpTile_0
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
using
WarpTile_0
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
using
BlockTile_1
=
ck_tile
::
sequence
<
BT_
,
BD_
,
BI_
>
;
using
BlockTile_1
=
ck_tile
::
sequence
<
BT_
,
BD_
,
BI_
>
;
using
WarpPerBlock_1
=
ck_tile
::
sequence
<
1
,
1
,
4
>
;
//ck_tile::remove_cvref_t<WarpPerBlock_>;
using
WarpPerBlock_1
=
ck_tile
::
sequence
<
1
,
1
,
4
>
;
//
ck_tile::remove_cvref_t<WarpPerBlock_>;
using
WarpTile_1
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
using
WarpTile_1
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
static
constexpr
ck_tile
::
index_t
GateOnly
=
GateOnly_
;
static
constexpr
ck_tile
::
index_t
GateOnly
=
GateOnly_
;
...
...
example/ck_tile/17_fused_moe_general/main.cpp
View file @
b616b254
...
@@ -83,13 +83,43 @@ void topid_unique_gen(
...
@@ -83,13 +83,43 @@ void topid_unique_gen(
host_tensor
[
i
]
=
current_v
;
host_tensor
[
i
]
=
current_v
;
}
}
}
}
/// Debug helper: prints the leading `m` x `n` region of a host tensor to stdout.
///
/// Output starts with an empty line. Each row is then written on its own line
/// as `Line <row>`, a tab, and the row's elements, where every element is
/// converted with `ck_tile::type_convert<float>` and followed by a tab.
///
/// @tparam IndexType  template argument of the `ck_tile::HostTensor` being printed.
/// @param data  tensor to dump; it is read through its two-index `operator()`.
/// @param m     number of rows to print.
/// @param n     number of columns to print per row.
template <typename IndexType>
void output_matrix_2d(ck_tile::HostTensor<IndexType>& data, int m, int n)
{
    std::ostream& out = std::cout;

    // Leading blank line separates the dump from whatever was printed before.
    out << std::endl;

    int row = 0;
    while(row < m)
    {
        out << "Line " << row << "\t";

        int col = 0;
        while(col < n)
        {
            const auto value = ck_tile::type_convert<float>(data(row, col));
            out << value << "\t";
            ++col;
        }

        // std::endl terminates the row and flushes the stream.
        out << std::endl;
        ++row;
    }
}
/// Debug helper: prints the leading `M` x `N` x `J` region of a host tensor to stdout.
///
/// Output starts with an empty line. For every (first index, second index)
/// pair one line is written: `experts: <m> Line: <n>`, a tab, and the `J`
/// elements along the last index, where every element is converted with
/// `ck_tile::type_convert<float>` and followed by a tab.
///
/// @tparam IndexType  template argument of the `ck_tile::HostTensor` being printed.
/// @param data  tensor to dump; it is read through its three-index `operator()`.
/// @param M     extent of the first (outermost) index to print.
/// @param N     extent of the second index to print.
/// @param J     extent of the third (innermost) index to print.
template <typename IndexType>
void output_matrix_3d(ck_tile::HostTensor<IndexType>& data, int M, int N, int J)
{
    std::ostream& out = std::cout;

    // Leading blank line separates the dump from whatever was printed before.
    out << std::endl;

    int expert = 0;
    while(expert < M)
    {
        int row = 0;
        while(row < N)
        {
            out << "experts: " << expert << " Line: " << row << "\t";

            int col = 0;
            while(col < J)
            {
                const auto value = ck_tile::type_convert<float>(data(expert, row, col));
                out << value << "\t";
                ++col;
            }

            // std::endl terminates the row and flushes the stream.
            out << std::endl;
            ++row;
        }
        ++expert;
    }
}
auto
create_args
(
int
argc
,
char
*
argv
[])
auto
create_args
(
int
argc
,
char
*
argv
[])
{
{
ck_tile
::
ArgParser
arg_parser
;
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"t"
,
"128"
,
"num input tokens"
)
arg_parser
.
insert
(
"t"
,
"128"
,
"num input tokens"
)
.
insert
(
"e"
,
"32"
,
"num of experts"
)
.
insert
(
"e"
,
"32"
,
"num of experts"
)
.
insert
(
"k"
,
"
5
"
,
"topk"
)
.
insert
(
"k"
,
"
2
"
,
"topk"
)
.
insert
(
"h"
,
"8192"
,
"hidden_size of this model"
)
.
insert
(
"h"
,
"8192"
,
"hidden_size of this model"
)
.
insert
(
"i"
,
"8192"
,
"intermediate_size between 2 gemms of FFN"
)
.
insert
(
"i"
,
"8192"
,
"intermediate_size between 2 gemms of FFN"
)
.
insert
(
"stride"
,
"-1"
,
"stride per row, if -1 then equal to hidden_size"
)
.
insert
(
"stride"
,
"-1"
,
"stride per row, if -1 then equal to hidden_size"
)
...
@@ -112,7 +142,7 @@ auto create_args(int argc, char* argv[])
...
@@ -112,7 +142,7 @@ auto create_args(int argc, char* argv[])
"0"
,
"0"
,
"if set to 1, will try balance the expert in topk-ids(convenient for testing)"
)
"if set to 1, will try balance the expert in topk-ids(convenient for testing)"
)
.
insert
(
"init"
,
.
insert
(
"init"
,
"
2
"
,
"
1
"
,
"init method. 0:random stepped float(fast). 1: random uniform, 2:rand normalized"
"init method. 0:random stepped float(fast). 1: random uniform, 2:rand normalized"
"normalized(slow)"
)
"normalized(slow)"
)
.
insert
(
"seed"
,
"11939"
,
"seed used to do random"
)
.
insert
(
"seed"
,
"11939"
,
"seed used to do random"
)
...
@@ -176,9 +206,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -176,9 +206,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
return
base_str
;
return
base_str
;
}();
}();
auto
api_str
=
[
&
]()
{
auto
api_str
=
[
&
]()
{
return
std
::
string
(
"moeg"
);
}();
return
std
::
string
(
"moeg"
);
}();
auto
stride_str
=
[
&
]()
{
auto
stride_str
=
[
&
]()
{
if
(
stride
==
hidden_size
)
if
(
stride
==
hidden_size
)
...
@@ -245,7 +273,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -245,7 +273,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
FillUniformDistribution
<
GScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sg_host
);
ck_tile
::
FillUniformDistribution
<
GScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sg_host
);
ck_tile
::
FillUniformDistribution
<
DScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sd_host
);
ck_tile
::
FillUniformDistribution
<
DScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sd_host
);
ck_tile
::
FillUniformDistribution
<
YSmoothScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sy_host
);
ck_tile
::
FillUniformDistribution
<
YSmoothScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sy_host
);
ck_tile
::
FillUniformDistribution
<
TopkWeightDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
ck_tile
::
FillUniformDistribution
<
TopkWeightDataType
>
{
0.0
f
,
1.0
f
,
seed
,
true
}(
topk_weight_host
);
topk_weight_host
);
}
}
else
if
(
init
==
2
)
else
if
(
init
==
2
)
...
@@ -343,6 +371,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -343,6 +371,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
experts
,
experts
,
block_m
);
block_m
);
// output_matrix_2d(a_host, tokens, hidden_size);
std
::
cout
<<
sorted_token_ids_host
<<
std
::
endl
;
std
::
cout
<<
num_sorted_tiles_host
<<
std
::
endl
;
// output_matrix_3d(g_host, experts, shared_intermediate_size_0, hidden_size);
std
::
cout
<<
sorted_expert_ids_host
<<
std
::
endl
;
// std::cout << topk_weight_host << std::endl;
// std::cout << sorted_weight_host << std::endl;
// done, preparing GPU buffer
// done, preparing GPU buffer
ck_tile
::
DeviceMem
a_buf
(
a_host
);
ck_tile
::
DeviceMem
a_buf
(
a_host
);
ck_tile
::
DeviceMem
g_perm_buf
(
g_host
);
ck_tile
::
DeviceMem
g_perm_buf
(
g_host
);
...
@@ -441,8 +477,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -441,8 +477,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
std
::
cout
<<
std
::
flush
<<
std
::
endl
;
std
::
cout
<<
std
::
flush
<<
std
::
endl
;
return
pass
;
return
pass
;
}
}
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
...
...
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
View file @
b616b254
...
@@ -213,9 +213,9 @@ struct FusedMoeGemmGlKernel
...
@@ -213,9 +213,9 @@ struct FusedMoeGemmGlKernel
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
{
{
//constexpr index_t block_m = BlockShape::Block_M0;
//
constexpr index_t block_m = BlockShape::Block_M0;
int
max_num_tokens_padded
=
hargs
.
max_num_tokens_padded
;
int
max_num_tokens_padded
=
hargs
.
max_num_tokens_padded
;
//hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
//
hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
// printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
// printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
return
Partitioner
::
GridSize
(
max_num_tokens_padded
,
hargs
.
intermediate_size
);
return
Partitioner
::
GridSize
(
max_num_tokens_padded
,
hargs
.
intermediate_size
);
}
}
...
...
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
View file @
b616b254
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
b616b254
...
@@ -191,12 +191,13 @@ struct FusedMoeGemmPipeline_General
...
@@ -191,12 +191,13 @@ struct FusedMoeGemmPipeline_General
block_sync_lds
();
block_sync_lds
();
gemm_0
(
s_acc
,
a_lds_win
,
g_dram_block
);
gemm_0
(
s_acc
,
a_lds_win
,
g_dram_block
);
}
}
#if 1
PrintMem
(
s_acc
);
#endif
// relu
// relu
const
auto
activation
=
ck_tile
::
element_wise
::
Gelu
{};
const
auto
activation
=
ck_tile
::
element_wise
::
Gelu
{};
tile_elementwise_inout
(
activation
,
s_acc
,
s_acc
);
tile_elementwise_inout
(
activation
,
s_acc
,
s_acc
);
#if 0
PrintMem(s_acc);
#endif
// move sacc to LDS
// move sacc to LDS
auto
bridge_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
auto
bridge_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeBridgeLdsBlockDesc
<
Problem
>());
smem_0
,
Policy
::
template
MakeBridgeLdsBlockDesc
<
Problem
>());
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
View file @
b616b254
...
@@ -175,7 +175,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
...
@@ -175,7 +175,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
{
{
using
WG
=
decltype
(
GetWarpGemm0
<
Problem
>
());
using
WG
=
decltype
(
GetWarpGemm0
<
Problem
>
());
using
S_
=
typename
Problem
::
BlockShape
;
using
S_
=
typename
Problem
::
BlockShape
;
static_assert
(
S_
::
WarpPerBlock_N0
==
4
);
static_assert
(
S_
::
WarpPerBlock_N0
==
4
);
constexpr
auto
g_outer_dstr_enc
=
tile_distribution_encoding
<
constexpr
auto
g_outer_dstr_enc
=
tile_distribution_encoding
<
sequence
<
S_
::
WarpPerBlock_M0
>
,
sequence
<
S_
::
WarpPerBlock_M0
>
,
tuple
<
sequence
<
S_
::
Repeat_N0
,
S_
::
WarpPerBlock_N0
>
,
sequence
<
S_
::
Repeat_K0
>>
,
tuple
<
sequence
<
S_
::
Repeat_N0
,
S_
::
WarpPerBlock_N0
>
,
sequence
<
S_
::
Repeat_K0
>>
,
...
@@ -240,9 +240,10 @@ struct FusedMoeGemmPipelineGeneralPolicy
...
@@ -240,9 +240,10 @@ struct FusedMoeGemmPipelineGeneralPolicy
using
S_
=
remove_cvref_t
<
typename
Problem
::
BlockShape
>
;
using
S_
=
remove_cvref_t
<
typename
Problem
::
BlockShape
>
;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm1
<
Problem
>
())
>
;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm1
<
Problem
>
())
>
;
constexpr
auto
y_outer_dstr_enc
=
tile_distribution_encoding
<
constexpr
auto
y_outer_dstr_enc
=
sequence
<
1
>
,
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
S_
::
Repeat_M1
,
S_
::
WarpPerBlock_M1
>
,
sequence
<
S_
::
WarpPerBlock_K1
,
S_
::
Repeat_K1
>>
,
tuple
<
sequence
<
S_
::
Repeat_M1
,
S_
::
WarpPerBlock_M1
>
,
sequence
<
S_
::
WarpPerBlock_K1
,
S_
::
Repeat_K1
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>
,
...
@@ -260,9 +261,10 @@ struct FusedMoeGemmPipelineGeneralPolicy
...
@@ -260,9 +261,10 @@ struct FusedMoeGemmPipelineGeneralPolicy
using
S_
=
remove_cvref_t
<
typename
Problem
::
BlockShape
>
;
using
S_
=
remove_cvref_t
<
typename
Problem
::
BlockShape
>
;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm1
<
Problem
>
())
>
;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm1
<
Problem
>
())
>
;
constexpr
auto
d_outer_dstr_enc
=
tile_distribution_encoding
<
constexpr
auto
d_outer_dstr_enc
=
sequence
<
1
>
,
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
S_
::
Repeat_N1
,
S_
::
WarpPerBlock_N1
>
,
sequence
<
S_
::
WarpPerBlock_K1
,
S_
::
Repeat_K1
>>
,
tuple
<
sequence
<
S_
::
Repeat_N1
,
S_
::
WarpPerBlock_N1
>
,
sequence
<
S_
::
WarpPerBlock_K1
,
S_
::
Repeat_K1
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>
,
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
View file @
b616b254
...
@@ -52,16 +52,16 @@ struct BlockGemmARegBRegCRegV2
...
@@ -52,16 +52,16 @@ struct BlockGemmARegBRegCRegV2
// M->N Warp
// M->N Warp
// constexpr auto a_block_outer_dstr_encoding =
// constexpr auto a_block_outer_dstr_encoding =
// tile_distribution_encoding<sequence<NWarp>,
// tile_distribution_encoding<sequence<NWarp>,
// tuple<sequence<MIterPerWarp, MWarp>,
sequence<KIterPerWarp>>,
// tuple<sequence<MIterPerWarp, MWarp>,
// tuple<sequence<1, 0>>,
//
sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
// tuple<sequence<1, 0>>,
// tuple<sequence<1, 0>>,
// sequence<1, 2>,
// sequence<1, 2>,
// sequence<0, 0>>{};
// sequence<0, 0>>{};
// constexpr auto b_block_outer_dstr_encoding =
// constexpr auto b_block_outer_dstr_encoding =
// tile_distribution_encoding<sequence<MWarp>,
// tile_distribution_encoding<sequence<MWarp>,
// tuple<sequence<NIterPerWarp, NWarp>,
sequence<KIterPerWarp>>,
// tuple<sequence<NIterPerWarp, NWarp>,
// tuple<sequence<0, 1>>,
//
sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
// tuple<sequence<0, 1>>,
// tuple<sequence<0, 1>>,
// sequence<1, 2>,
// sequence<1, 2>,
// sequence<0, 0>>{};
// sequence<0, 0>>{};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment