Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
2880f7a5
Commit
2880f7a5
authored
Jan 04, 2025
by
OscarXu
Browse files
[CK_TILE] Support moe with up gemm
parent
37b35146
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
112 additions
and
53 deletions
+112
-53
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
+12
-0
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp
...ile/15_fused_moe/instances/fused_moegemm_api_internal.hpp
+21
-20
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp
..._tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp
+1
-1
example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
...ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
+3
-0
example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp
...ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp
+3
-0
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
...ude/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
+3
-2
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
...s/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
+69
-30
No files found.
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
View file @
2880f7a5
...
@@ -22,12 +22,24 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
...
@@ -22,12 +22,24 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
;
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
}
else
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
0
)
{
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
1024
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
0
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
else
if
(
t
.
prec_i
==
"fp16"
&&
t
.
prec_w
==
"fp16"
&&
t
.
prec_o
==
"fp16"
&&
t
.
prec_st
==
"fp32"
&&
else
if
(
t
.
prec_i
==
"fp16"
&&
t
.
prec_w
==
"fp16"
&&
t
.
prec_o
==
"fp16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
{
{
using
t_
=
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
;
using
t_
=
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
}
else
if
(
t
.
prec_i
==
"fp16"
&&
t
.
prec_w
==
"fp16"
&&
t
.
prec_o
==
"fp16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
0
)
{
using
t_
=
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
1024
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
0
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
// clang-format on
// clang-format on
return
r
;
return
r
;
}
}
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp
View file @
2880f7a5
...
@@ -16,26 +16,27 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
...
@@ -16,26 +16,27 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
{
{
using
f_traits
=
ck_tile
::
FusedMoeGemmTraits
<
Ts_
::
GateOnly
,
Ts_
::
FusedQuant
==
1
,
1
/*atomic*/
>
;
using
f_traits
=
ck_tile
::
FusedMoeGemmTraits
<
Ts_
::
GateOnly
,
Ts_
::
FusedQuant
==
1
,
1
/*atomic*/
>
;
using
f_shape
=
ck_tile
::
FusedMoeGemmShape
<
typename
Ts_
::
BlockTile_0
,
using
f_shape
=
ck_tile
::
FusedMoeGemmShape
<
typename
Ts_
::
BlockTile_0
,
typename
Ts_
::
WarpPerBlock_0
,
typename
Ts_
::
WarpPerBlock_0
,
typename
Ts_
::
WarpTile_0
,
typename
Ts_
::
WarpTile_0
,
typename
Ts_
::
BlockTile_1
,
typename
Ts_
::
BlockTile_1
,
typename
Ts_
::
WarpPerBlock_0
,
typename
Ts_
::
WarpPerBlock_0
,
typename
Ts_
::
WarpTile_0
>
;
typename
Ts_
::
WarpTile_0
>
;
using
f_problem
=
using
f_problem
=
ck_tile
::
FusedMoeGemmPipelineProblem
<
typename
Ts_
::
ADataType
,
ck_tile
::
FusedMoeGemmPipelineProblem
<
typename
Ts_
::
ADataType
,
typename
Ts_
::
GDataType
,
typename
Ts_
::
GDataType
,
typename
Ts_
::
DDataType
,
typename
Ts_
::
DDataType
,
typename
Ts_
::
AccDataType
,
typename
Ts_
::
AccDataType
,
typename
Ts_
::
ODataType
,
typename
Ts_
::
ODataType
,
typename
Ts_
::
AScaleDataType
,
typename
Ts_
::
AScaleDataType
,
typename
Ts_
::
GScaleDataType
,
typename
Ts_
::
GScaleDataType
,
typename
Ts_
::
DScaleDataType
,
typename
Ts_
::
DScaleDataType
,
typename
Ts_
::
YSmoothScaleDataType
,
typename
Ts_
::
YSmoothScaleDataType
,
typename
Ts_
::
TopkWeightDataType
,
typename
Ts_
::
TopkWeightDataType
,
typename
Ts_
::
IndexDataType
,
typename
Ts_
::
IndexDataType
,
// ck_tile::element_wise::FastGeluAsm, //
ck_tile
::
element_wise
::
FastGeluAsm
,
// TODO: hardcoded
// TODO: hardcoded
f_shape
,
ck_tile
::
element_wise
::
Silu
,
f_traits
>
;
f_shape
,
f_traits
>
;
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
using
f_pipeline
=
ck_tile
::
FusedMoeGemmPipeline_FlatmmUk
<
f_problem
>
;
using
f_pipeline
=
ck_tile
::
FusedMoeGemmPipeline_FlatmmUk
<
f_problem
>
;
...
...
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp
View file @
2880f7a5
...
@@ -40,7 +40,7 @@ struct fmoe_ // traits, ugly name, only used for internal
...
@@ -40,7 +40,7 @@ struct fmoe_ // traits, ugly name, only used for internal
static
constexpr
ck_tile
::
index_t
BH_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
2
>
{});
// block hidden
static
constexpr
ck_tile
::
index_t
BH_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
2
>
{});
// block hidden
static
constexpr
ck_tile
::
index_t
BD_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
3
>
{});
// block down
static
constexpr
ck_tile
::
index_t
BD_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
3
>
{});
// block down
using
BlockTile_0
=
ck_tile
::
sequence
<
BT_
,
BI_
,
BH_
>
;
using
BlockTile_0
=
ck_tile
::
sequence
<
BT_
,
BI_
/
(
GateOnly_
?
1
:
2
)
,
BH_
>
;
using
WarpPerBlock_0
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
using
WarpPerBlock_0
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
using
WarpTile_0
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
using
WarpTile_0
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
...
...
example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
View file @
2880f7a5
...
@@ -10,5 +10,8 @@
...
@@ -10,5 +10,8 @@
template
float
fused_moegemm_
<
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
1024
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
// clang-format on
// clang-format on
example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp
View file @
2880f7a5
...
@@ -10,5 +10,8 @@
...
@@ -10,5 +10,8 @@
template
float
fused_moegemm_
<
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
1024
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
// clang-format on
// clang-format on
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
View file @
2880f7a5
...
@@ -228,7 +228,8 @@ struct FusedMoeGemmKernel
...
@@ -228,7 +228,8 @@ struct FusedMoeGemmKernel
int
max_num_tokens_padded
=
int
max_num_tokens_padded
=
hargs
.
topk
*
hargs
.
num_tokens
+
hargs
.
num_experts
*
block_m
-
hargs
.
topk
;
hargs
.
topk
*
hargs
.
num_tokens
+
hargs
.
num_experts
*
block_m
-
hargs
.
topk
;
// printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
// printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
return
Partitioner
::
GridSize
(
max_num_tokens_padded
,
hargs
.
intermediate_size
);
return
Partitioner
::
GridSize
(
max_num_tokens_padded
,
hargs
.
intermediate_size
/
(
IsGateOnly
?
1
:
2
));
}
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
dim3
(
BlockSize_
);
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
dim3
(
BlockSize_
);
}
...
@@ -382,7 +383,7 @@ struct FusedMoeGemmKernel
...
@@ -382,7 +383,7 @@ struct FusedMoeGemmKernel
auto
o_window
=
[
&
]()
{
auto
o_window
=
[
&
]()
{
ODataType
*
o_ptr
=
reinterpret_cast
<
ODataType
*>
(
kargs
.
o_ptr
);
ODataType
*
o_ptr
=
reinterpret_cast
<
ODataType
*>
(
kargs
.
o_ptr
);
auto
o_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
,
auto
o_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
,
memory_operation_enum
::
atomic_add
>
(
memory_operation_enum
::
atomic_add
>
(
o_ptr
,
o_ptr
,
make_tuple
(
kargs
.
num_tokens
,
kargs
.
hidden_size
),
make_tuple
(
kargs
.
num_tokens
,
kargs
.
hidden_size
),
make_tuple
(
kargs
.
stride_token
,
1
),
make_tuple
(
kargs
.
stride_token
,
1
),
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
View file @
2880f7a5
...
@@ -73,7 +73,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -73,7 +73,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
constexpr
index_t
smem_0
=
Policy
::
template
GetUK_0
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_0
=
Policy
::
template
GetUK_0
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_1
=
Policy
::
template
GetUK_1
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_1
=
Policy
::
template
GetUK_1
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_bridge
=
constexpr
index_t
smem_bridge
=
BlockShape
::
Block_M0
*
BlockShape
::
Block_N0
*
sizeof
(
YDataType
);
BlockShape
::
Block_M0
*
BlockShape
::
Block_N0
*
(
IsGateOnly
?
1
:
2
);
return
max
(
smem_0
,
max
(
smem_1
,
smem_bridge
));
return
max
(
smem_0
,
max
(
smem_1
,
smem_bridge
));
}
}
...
@@ -165,7 +165,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -165,7 +165,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
index_t
intermediate_tile_id
)
index_t
intermediate_tile_id
)
{
{
constexpr
index_t
hidden_radio_0
=
IsGateOnly
?
1
:
2
;
constexpr
index_t
hidden_radio_0
=
IsGateOnly
?
1
:
2
;
ck_tile
::
index_t
shared_intermediate_size_0
=
kargs
.
intermediate_size
;
ck_tile
::
index_t
shared_intermediate_size_0
=
kargs
.
intermediate_size
/
hidden_radio_0
;
ck_tile
::
index_t
shared_intermediate_size_1
=
kargs
.
intermediate_size
/
hidden_radio_0
;
ck_tile
::
index_t
shared_intermediate_size_1
=
kargs
.
intermediate_size
/
hidden_radio_0
;
index_t
nr_0
=
shared_intermediate_size_0
/
BlockShape
::
Warp_N0
;
// divide N in W
index_t
nr_0
=
shared_intermediate_size_0
/
BlockShape
::
Warp_N0
;
// divide N in W
...
@@ -175,7 +175,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -175,7 +175,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
const
IndexDataType
expert_id
=
__builtin_amdgcn_readfirstlane
(
const
IndexDataType
expert_id
=
__builtin_amdgcn_readfirstlane
(
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_expert_ids_ptr
)[
sorted_tile_id
]);
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_expert_ids_ptr
)[
sorted_tile_id
]);
index_t
expert_stride_0
=
shared_intermediate_size_0
*
kargs
.
hidden_size
;
index_t
expert_stride_0
=
shared_intermediate_size_0
*
kargs
.
hidden_size
*
hidden_radio_0
;
index_t
expert_stride_1
=
shared_intermediate_size_1
*
kargs
.
hidden_size
;
index_t
expert_stride_1
=
shared_intermediate_size_1
*
kargs
.
hidden_size
;
// nr*kr*w
// nr*kr*w
...
@@ -200,10 +200,10 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -200,10 +200,10 @@ struct FusedMoeGemmPipeline_FlatmmUk
make_wave_buffer_resource
(
reinterpret_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
),
make_wave_buffer_resource
(
reinterpret_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
),
kargs
.
num_tokens
*
kargs
.
stride_token
*
sizeof
(
ADataType
));
kargs
.
num_tokens
*
kargs
.
stride_token
*
sizeof
(
ADataType
));
auto
g_win
=
[
&
]()
{
auto
g
u
_win
_gen
=
[
&
](
auto
ptr_offset
)
{
const
GDataType
*
g_ptr
=
reinterpret_cast
<
const
GDataType
*>
(
kargs
.
g_ptr
)
+
const
GDataType
*
g_ptr
=
reinterpret_cast
<
const
GDataType
*>
(
kargs
.
g_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_0
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_0
+
interm_idx_nr0
*
kr_0
*
BlockShape
::
Block_W0
;
ptr_offset
+
interm_idx_nr0
*
kr_0
*
BlockShape
::
Block_W0
;
auto
g_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
auto
g_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
g_ptr
,
g_ptr
,
make_tuple
(
nr_0
,
kr_0
,
number
<
BlockShape
::
Block_W0
>
{}),
make_tuple
(
nr_0
,
kr_0
,
number
<
BlockShape
::
Block_W0
>
{}),
...
@@ -220,7 +220,8 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -220,7 +220,8 @@ struct FusedMoeGemmPipeline_FlatmmUk
Policy
::
template
MakeGlobalTileDistribution_G
<
Problem
>(),
Policy
::
template
MakeGlobalTileDistribution_G
<
Problem
>(),
sequence
<
0
,
1
,
1
>
{});
sequence
<
0
,
1
,
1
>
{});
return
g_window_
;
return
g_window_
;
}();
};
auto
g_win
=
gu_win_gen
(
0
);
auto
g_res
=
g_win
.
get_bottom_tensor_view
().
get_buffer_view
().
cached_buf_res_
;
auto
g_res
=
g_win
.
get_bottom_tensor_view
().
get_buffer_view
().
cached_buf_res_
;
auto
g_coords
=
generate_tuple
([
&
](
auto
i
)
{
return
g_win
.
cached_coords_
[
i
].
get_offset
();
},
auto
g_coords
=
generate_tuple
([
&
](
auto
i
)
{
return
g_win
.
cached_coords_
[
i
].
get_offset
();
},
...
@@ -309,32 +310,70 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -309,32 +310,70 @@ struct FusedMoeGemmPipeline_FlatmmUk
auto
w_scale
=
GetWeightScale
(
auto
w_scale
=
GetWeightScale
(
row_coords_o
,
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
));
row_coords_o
,
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
));
auto
uk_0
=
Policy
::
template
GetUK_0
<
Problem
>();
auto
uk_0_g
=
Policy
::
template
GetUK_0
<
Problem
>();
auto
acc_0
=
uk_0
(
a_res
,
auto
acc_0
=
uk_0_g
(
a_res
,
a_coords
,
a_coords
,
g_res
,
g_res
,
g_coords
,
g_coords
,
smem
,
smem
,
kargs
.
hidden_size
,
kargs
.
hidden_size
,
BlockShape
::
Block_K0
,
// tile offset for B matrix each unroll
BlockShape
::
Block_K0
,
// tile offset for B matrix each unroll
BlockShape
::
Block_Kr0
*
BlockShape
::
Block_Kr0
*
BlockShape
::
Block_W0
);
// tile offset for B matrix each unroll
BlockShape
::
Block_W0
);
// tile offset for B matrix each unroll
sweep_tile
(
// fast GeLu
acc_0
,
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
GateActivation
,
[
&
](
auto
idx0
,
auto
idx1
)
{
ck_tile
::
element_wise
::
FastGeluAsm
>
)
fp32x2_t
v_
{
acc_0
(
idx0
),
acc_0
(
idx1
)};
{
typename
Problem
::
GateActivation
{}(
v_
,
v_
);
sweep_tile
(
acc_0
(
idx0
)
=
v_
.
x
;
acc_0
,
acc_0
(
idx1
)
=
v_
.
y
;
[
&
](
auto
idx0
,
auto
idx1
)
{
},
fp32x2_t
v_
{
acc_0
(
idx0
),
acc_0
(
idx1
)};
sequence
<
1
,
2
>
{});
typename
Problem
::
GateActivation
{}(
v_
,
v_
);
acc_0
(
idx0
)
=
v_
.
x
;
auto
y_pre
=
cast_tile
<
YDataType
>
(
acc_0
);
acc_0
(
idx1
)
=
v_
.
y
;
},
sequence
<
1
,
2
>
{});
}
else
{
sweep_tile
(
acc_0
,
[
&
](
auto
idx0
)
{
typename
Problem
::
GateActivation
{}(
acc_0
(
idx0
),
acc_0
(
idx0
));
},
sequence
<
1
,
1
>
{});
}
auto
y_pre
=
acc_0
;
block_sync_lds
();
block_sync_lds
();
store_tile
(
bridge_sst_win
,
y_pre
);
// up
if
(
!
IsGateOnly
)
{
// up ptr. add hafl expoert_stride_0 as offset.
auto
u_win
=
gu_win_gen
(
shared_intermediate_size_0
*
kargs
.
hidden_size
);
auto
u_res
=
u_win
.
get_bottom_tensor_view
().
get_buffer_view
().
cached_buf_res_
;
auto
u_coords
=
generate_tuple
([
&
](
auto
i
)
{
return
u_win
.
cached_coords_
[
i
].
get_offset
();
},
number
<
decltype
(
u_win
)
::
NumAccess_NonLinear
>
{});
// reuse UK0
auto
uk_0_u
=
Policy
::
template
GetUK_0
<
Problem
>();
auto
acc_0_u
=
uk_0_u
(
a_res
,
a_coords
,
u_res
,
u_coords
,
smem
,
kargs
.
hidden_size
,
BlockShape
::
Block_K0
,
// tile offset for B matrix each unroll
BlockShape
::
Block_Kr0
*
BlockShape
::
Block_W0
);
// tile offset for B matrix each unroll
// elementwise mul gate*up.
sweep_tile
(
y_pre
,
[
&
](
auto
idx0
)
{
y_pre
(
idx0
)
=
y_pre
(
idx0
)
*
acc_0_u
(
idx0
);
},
sequence
<
1
,
1
>
{});
block_sync_lds
();
}
store_tile
(
bridge_sst_win
,
cast_tile
<
YDataType
>
(
y_pre
));
block_sync_lds
();
block_sync_lds
();
auto
uk_1
=
Policy
::
template
GetUK_1
<
Problem
>();
auto
uk_1
=
Policy
::
template
GetUK_1
<
Problem
>();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment