change / sglang · Commits · 53a7ebd8

Commit 53a7ebd8 (unverified)
Authored Jun 17, 2024 by Lianmin Zheng; committed by GitHub on Jun 17, 2024
Update fused_moe (#553)
Parent: ad5f04d6

Showing 1 changed file with 220 additions and 183 deletions:
python/sglang/srt/layers/fused_moe.py (+220, -183)
@@ -9,9 +9,9 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 import triton
 import triton.language as tl

 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.utils import is_hip

 logger = init_logger(__name__)
@@ -108,16 +108,12 @@ def fused_moe_kernel(
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)

-    a_ptrs = a_ptr + (
-        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
-    )
+    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
+                      offs_k[None, :] * stride_ak)

     off_experts = tl.load(expert_ids_ptr + pid_m)
-    b_ptrs = (
-        b_ptr
-        + off_experts * stride_be
-        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
-    )
+    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
+                                                offs_bn[None, :] * stride_bn)

     if use_fp8:
         a_scale = tl.load(a_scale_ptr)
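A note on the a_ptrs indexing above: sorted_token_ids enumerates flattened (token, expert) slots, so offs_token[:, None] // top_k maps each slot back to the row of A (the token's hidden state) it reads, while b_ptrs selects the weight matrix of the expert that owns the current block. A minimal illustration of the slot-to-row mapping (values chosen only for the example):

    # With top_k = 2, flattened routing slots 0..5 read these rows of A:
    top_k = 2
    slots = [0, 1, 2, 3, 4, 5]
    rows = [s // top_k for s in slots]  # -> [0, 0, 1, 1, 2, 2]
    # each token's hidden state is loaded once per expert it is routed to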
@@ -133,12 +129,13 @@ def fused_moe_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        a = tl.load(
-            a_ptrs,
-            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-            other=0.0,
-        )
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        a = tl.load(a_ptrs,
+                    mask=token_mask[:, None] &
+                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                    other=0.0)
+        b = tl.load(b_ptrs,
+                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
+                    other=0.0)
         # We accumulate along the K dimension.
         if use_fp8:
             accumulator = tl.dot(a, b, acc=accumulator)
@@ -149,7 +146,9 @@ def fused_moe_kernel(
         b_ptrs += BLOCK_SIZE_K * stride_bk

     if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
+        moe_weight = tl.load(topk_weights_ptr + offs_token,
+                             mask=token_mask,
+                             other=0)
         accumulator = accumulator * moe_weight[:, None]

     if use_fp8:
@@ -159,14 +158,15 @@ def fused_moe_kernel(
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
+        None, :]
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
     tl.store(c_ptrs, accumulator, mask=c_mask)


-def moe_align_block_size(
-    topk_ids: torch.Tensor, block_size: int, num_experts: int
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+def moe_align_block_size(
+        topk_ids: torch.Tensor, block_size: int,
+        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block
     size for matrix multiplication.
@@ -205,38 +205,32 @@ def moe_align_block_size(
     by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids = torch.empty(
-        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
-    )
+    sorted_ids = torch.empty((max_num_tokens_padded, ),
+                             dtype=torch.int32,
+                             device=topk_ids.device)
     sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
-    expert_ids = torch.empty(
-        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
-    )
-    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    ops.moe_align_block_size(
-        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
-    )
+    expert_ids = torch.empty((max_num_m_blocks, ),
+                             dtype=torch.int32,
+                             device=topk_ids.device)
+    num_tokens_post_pad = torch.empty((1),
+                                      dtype=torch.int32,
+                                      device=topk_ids.device)
+    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
+                             expert_ids, num_tokens_post_pad)
     return sorted_ids, expert_ids, num_tokens_post_pad


-def invoke_fused_moe_kernel(
-    A: torch.Tensor,
-    B: torch.Tensor,
-    C: torch.Tensor,
-    A_scale: Optional[torch.Tensor],
-    B_scale: Optional[torch.Tensor],
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    sorted_token_ids: torch.Tensor,
-    expert_ids: torch.Tensor,
-    num_tokens_post_padded: torch.Tensor,
-    mul_routed_weight: bool,
-    top_k: int,
-    config: Dict[str, Any],
-    compute_type: tl.dtype,
-    use_fp8: bool,
-) -> None:
+def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
+                            A_scale: Optional[torch.Tensor],
+                            B_scale: Optional[torch.Tensor],
+                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+                            sorted_token_ids: torch.Tensor,
+                            expert_ids: torch.Tensor,
+                            num_tokens_post_padded: torch.Tensor,
+                            mul_routed_weight: bool, top_k: int,
+                            config: Dict[str, Any], compute_type: tl.dtype,
+                            use_fp8: bool) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
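For readers unfamiliar with ops.moe_align_block_size, the alignment it performs can be sketched in plain PyTorch roughly as follows. This is an unoptimized illustration of the semantics, not the CUDA op itself; the real op writes into the fixed-size buffers allocated above and may order padding differently.

    import torch

    def moe_align_block_size_ref(topk_ids, block_size, num_experts):
        # Group the flattened (token, expert) slot indices by expert and pad each
        # group up to a multiple of block_size with the out-of-range value
        # topk_ids.numel(), mirroring sorted_ids.fill_(topk_ids.numel()) above.
        flat = topk_ids.flatten()
        pad_val = flat.numel()
        sorted_chunks, expert_chunks = [], []
        for e in range(num_experts):
            idx = (flat == e).nonzero(as_tuple=False).flatten().to(torch.int32)
            num_blocks = (idx.numel() + block_size - 1) // block_size
            padded = torch.full((num_blocks * block_size,), pad_val, dtype=torch.int32)
            padded[: idx.numel()] = idx
            sorted_chunks.append(padded)
            expert_chunks.append(torch.full((num_blocks,), e, dtype=torch.int32))
        sorted_ids = torch.cat(sorted_chunks)
        expert_ids = torch.cat(expert_chunks)
        num_tokens_post_pad = torch.tensor([sorted_ids.numel()], dtype=torch.int32)
        return sorted_ids, expert_ids, num_tokens_post_pad

Each BLOCK_SIZE_M-sized block of sorted_ids then belongs to exactly one expert, which is what lets each Triton program load a single expert's weights.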
@@ -247,10 +241,8 @@ def invoke_fused_moe_kernel(
         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None

-    grid = lambda META: (
-        triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
-        * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
-    )
+    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
+        'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )

     fused_moe_kernel[grid](
         A,
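The grid lambda sizes a 1-D launch: one program per (BLOCK_SIZE_M block of padded token slots) x (BLOCK_SIZE_N block of output columns). An illustrative helper, not part of the file, with assumed numbers:

    from math import ceil

    def num_programs(em, n, block_size_m, block_size_n):
        # em = sorted_token_ids.shape[0] (token slots after padding)
        # n = B.shape[1] (output features per expert)
        return ceil(em / block_size_m) * ceil(n / block_size_n)

    num_programs(4096, 14336, 64, 64)  # -> 64 * 224 = 14336 programs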
@@ -288,7 +280,8 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:


 @functools.lru_cache
-def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
+def get_moe_configs(E: int, N: int,
+                    dtype: Optional[str]) -> Optional[Dict[int, Any]]:
     """
     Return optimized configurations for the fused MoE kernel.
@@ -303,11 +296,11 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int,
     json_file_name = get_config_file_name(E, N, dtype)

     config_file_path = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
-    )
+        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
     if os.path.exists(config_file_path):
         with open(config_file_path) as f:
-            logger.info("Using configuration from %s for MoE layer.", config_file_path)
+            logger.info("Using configuration from %s for MoE layer.",
+                        config_file_path)
             # If a configuration has been found, return it
             return {int(key): val for key, val in json.load(f).items()}
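The JSON files map benchmarked batch sizes to kernel configs, and callers pick the entry whose key is closest to the current M (see fused_experts below). A small illustration with made-up values:

    # Hypothetical contents of a tuned-config file after the int(key) conversion.
    configs = {1: {"BLOCK_SIZE_M": 16}, 16: {"BLOCK_SIZE_M": 32}, 64: {"BLOCK_SIZE_M": 64}}

    M = 24  # current number of tokens
    config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
    # |1 - 24| = 23, |16 - 24| = 8, |64 - 24| = 40, so the entry tuned for M = 16 wins.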
@@ -316,6 +309,165 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int,
     return None
+
+
+def get_default_config(
+    M: int,
+    E: int,
+    N: int,
+    K: int,
+    topk: int,
+    dtype: Optional[str],
+) -> Dict[str, int]:
+    config = {
+        'BLOCK_SIZE_M': 64,
+        'BLOCK_SIZE_N': 64,
+        'BLOCK_SIZE_K': 32,
+        'GROUP_SIZE_M': 8
+    }
+    if M <= E:
+        config = {
+            'BLOCK_SIZE_M': 16,
+            'BLOCK_SIZE_N': 32,
+            'BLOCK_SIZE_K': 64,
+            'GROUP_SIZE_M': 1
+        }
+    return config
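For example, with E = 8 experts a small decode batch falls into the M <= E branch (the N and K values below are illustrative only):

    get_default_config(M=4, E=8, N=14336, K=4096, topk=2, dtype=None)
    # -> {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1}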
+
+
+def fused_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], (
+        "Number of tokens mismatch")
+
+    M, _ = hidden_states.shape
+
+    topk_weights = torch.empty(M,
+                               topk,
+                               dtype=torch.float32,
+                               device=hidden_states.device)
+    topk_ids = torch.empty(M,
+                           topk,
+                           dtype=torch.int32,
+                           device=hidden_states.device)
+    token_expert_indicies = torch.empty(M,
+                                        topk,
+                                        dtype=torch.int32,
+                                        device=hidden_states.device)
+    ops.topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),  # TODO(woosuk): Optimize this.
+    )
+    del token_expert_indicies  # Not used. Will be used in the future.
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights, topk_ids
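The routing math that ops.topk_softmax performs here is, up to tie-breaking and dtype details, the same as this plain PyTorch sketch (an illustration, not the fused op):

    import torch

    def fused_topk_ref(gating_output, topk, renormalize):
        # Softmax over the expert dimension, then keep the top-k experts per token.
        scores = torch.softmax(gating_output.float(), dim=-1)
        topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_ids.to(torch.int32)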
+
+
+def fused_experts(hidden_states: torch.Tensor,
+                  w1: torch.Tensor,
+                  w2: torch.Tensor,
+                  topk_weights: torch.Tensor,
+                  topk_ids: torch.Tensor,
+                  inplace: bool = False,
+                  override_config: Optional[Dict[str, Any]] = None,
+                  use_fp8: bool = False,
+                  w1_scale: Optional[torch.Tensor] = None,
+                  w2_scale: Optional[torch.Tensor] = None,
+                  a1_scale: Optional[torch.Tensor] = None,
+                  a2_scale: Optional[torch.Tensor] = None):
+    # Check constraints.
+    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
+    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
+
+    M, _ = hidden_states.shape
+    E, N, _ = w1.shape
+
+    if override_config:
+        config = override_config
+    else:
+        # First try to load optimal config from the file
+        configs = get_moe_configs(E, w2.shape[2],
+                                  "float8" if use_fp8 else None)
+
+        if configs:
+            # If an optimal configuration map has been found, look up the
+            # optimal config
+            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+        else:
+            # Else use the default config
+            config = get_default_config(M, E, N, w1.shape[2],
+                                        topk_ids.shape[1],
+                                        "float8" if use_fp8 else None)
+
+    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
+                                      device=hidden_states.device,
+                                      dtype=hidden_states.dtype)
+    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
+                                      device=hidden_states.device,
+                                      dtype=hidden_states.dtype)
+    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
+                                      device=hidden_states.device,
+                                      dtype=hidden_states.dtype)
+
+    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+        topk_ids, config['BLOCK_SIZE_M'], E)
+    compute_type = (tl.bfloat16
+                    if hidden_states.dtype == torch.bfloat16 else tl.float16)
+
+    invoke_fused_moe_kernel(hidden_states,
+                            w1,
+                            intermediate_cache1,
+                            a1_scale,
+                            w1_scale,
+                            topk_weights,
+                            topk_ids,
+                            sorted_token_ids,
+                            expert_ids,
+                            num_tokens_post_padded,
+                            False,
+                            topk_ids.shape[1],
+                            config,
+                            compute_type=compute_type,
+                            use_fp8=use_fp8)
+
+    ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+
+    invoke_fused_moe_kernel(intermediate_cache2,
+                            w2,
+                            intermediate_cache3,
+                            a2_scale,
+                            w2_scale,
+                            topk_weights,
+                            topk_ids,
+                            sorted_token_ids,
+                            expert_ids,
+                            num_tokens_post_padded,
+                            True,
+                            1,
+                            config,
+                            compute_type=compute_type,
+                            use_fp8=use_fp8)
+
+    if inplace:
+        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
+                         dim=1,
+                         out=hidden_states)
+    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
+                     dim=1)
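The tensor shapes fused_experts expects follow from its assertions and cache allocations; a Mixtral-like example (the concrete numbers are assumptions, only the relationships matter):

    # hidden_states:          (M, K)              e.g. (num_tokens, 4096)
    # w1:                     (E, N, K)           N = 2 * intermediate size (gate and up fused)
    # w2:                     (E, K, N // 2)
    # topk_weights, topk_ids: (M, top_k)
    # intermediate_cache1:    (M, top_k, N)       first grouped GEMM
    # intermediate_cache2:    (M * top_k, N // 2) after ops.gelu_and_mul
    # intermediate_cache3:    (M, top_k, K)       second grouped GEMM
    # return value:           (M, K)              summed over the top_k experts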


 def fused_moe(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -358,134 +510,19 @@ def fused_moe(
     - torch.Tensor: The output tensor after applying the MoE layer.
     """
     # Check constraints.
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
-    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
-
-    M, _ = hidden_states.shape
-    E, N, _ = w1.shape
-
-    if is_hip():
-        # The MoE kernels are not yet supported on ROCm.
-        routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
-        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
-    else:
-        import vllm._moe_C as moe_kernels
-
-        topk_weights = torch.empty(
-            M, topk, dtype=torch.float32, device=hidden_states.device
-        )
-        topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-        token_expert_indicies = torch.empty(
-            M, topk, dtype=torch.int32, device=hidden_states.device
-        )
-        moe_kernels.topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),  # TODO(woosuk): Optimize this.
-        )
-        del token_expert_indicies  # Not used. Will be used in the future.
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    if override_config:
-        config = override_config
-    else:
-        # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
-
-        if configs:
-            # If an optimal configuration map has been found, look up the
-            # optimal config
-            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-        else:
-            # Else use the default config
-            config = {
-                "BLOCK_SIZE_M": 128,
-                "BLOCK_SIZE_N": 64,
-                "BLOCK_SIZE_K": 128,
-                "GROUP_SIZE_M": 1,
-                "num_warps": 4,
-                "num_stages": 4,
-            }
-
-            if M <= E:
-                config = {
-                    "BLOCK_SIZE_M": 128,
-                    "BLOCK_SIZE_N": 256,
-                    "BLOCK_SIZE_K": 128,
-                    "GROUP_SIZE_M": 16,
-                    "num_warps": 8,
-                    "num_stages": 4,
-                }
-
-    intermediate_cache1 = torch.empty(
-        (M, topk_ids.shape[1], N),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache2 = torch.empty(
-        (M * topk_ids.shape[1], N // 2),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache3 = torch.empty(
-        (M, topk_ids.shape[1], w2.shape[1]),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-
-    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config["BLOCK_SIZE_M"], E
-    )
-    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
-
-    invoke_fused_moe_kernel(
-        hidden_states,
-        w1,
-        intermediate_cache1,
-        a1_scale,
-        w1_scale,
-        topk_weights,
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        False,
-        topk_ids.shape[1],
-        config,
-        compute_type=compute_type,
-        use_fp8=use_fp8,
-    )
-
-    ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
-
-    invoke_fused_moe_kernel(
-        intermediate_cache2,
-        w2,
-        intermediate_cache3,
-        a2_scale,
-        w2_scale,
-        topk_weights,
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        True,
-        1,
-        config,
-        compute_type=compute_type,
-        use_fp8=use_fp8,
-    )
-
-    if inplace:
-        return torch.sum(
-            intermediate_cache3.view(*intermediate_cache3.shape),
-            dim=1,
-            out=hidden_states,
-        )
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
+    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
+                                        renormalize)
+    return fused_experts(hidden_states,
+                         w1,
+                         w2,
+                         topk_weights,
+                         topk_ids,
+                         inplace=inplace,
+                         override_config=override_config,
+                         use_fp8=use_fp8,
+                         w1_scale=w1_scale,
+                         w2_scale=w2_scale,
+                         a1_scale=a1_scale,
+                         a2_scale=a2_scale)
\ No newline at end of file
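With this change fused_moe becomes a thin wrapper, and callers can also run the two stages themselves. A hedged usage sketch (tensor setup omitted; shapes as in the note after fused_experts):

    # One-call form (public API unchanged):
    out = fused_moe(hidden_states, w1, w2, gating_output, topk=2, renormalize=True)

    # Equivalent two-step form exposed by this commit:
    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk=2, renormalize=True)
    out = fused_experts(hidden_states, w1, w2, topk_weights, topk_ids)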