change / sglang / Commits / b509db58

Commit b509db58 (Unverified)
Authored Nov 24, 2024 by Yineng Zhang; committed by GitHub on Nov 24, 2024

feat: remove the dependency on FusedMoE (#2153)
parent dbe17293

Showing 7 changed files with 1602 additions and 7 deletions (+1602, -7)
python/sglang/srt/layers/quantization/__init__.py          +15   -5
python/sglang/srt/layers/triton_fused_moe/__init__.py      +44   -0
python/sglang/srt/layers/triton_fused_moe/configs/README   +10   -0
python/sglang/srt/layers/triton_fused_moe/fused_moe.py     +858  -0
python/sglang/srt/layers/triton_fused_moe/layer.py         +631  -0
python/sglang/srt/models/deepseek_v2.py                    +1    -1
python/sglang/srt/utils.py                                 +43   -1
python/sglang/srt/layers/quantization/__init__.py  (+15, -5)

@@ -57,12 +57,23 @@ __all__ = [
     "QUANTIZATION_METHODS",
 ]

-"""
-def fp8_get_quant_method(self, layer, prefix):
+def fp8_get_quant_method(
+    self, layer: torch.nn.Module, prefix: str
+) -> Optional["QuantizeMethodBase"]:
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.fp8 import (
+        Fp8LinearMethod,
+        Fp8MoEMethod,
+    )
+    from vllm.model_executor.layers.quantization.utils.quant_utils import (
+        is_layer_skipped,
+    )
+    from sglang.srt.layers.triton_fused_moe.layer import FusedMoE
+
     if isinstance(layer, LinearBase):
         if is_layer_skipped(prefix, self.ignored_layers):
+            from sglang.srt.layers.linear import UnquantizedLinearMethod
+
             return UnquantizedLinearMethod()
         return Fp8LinearMethod(self)
     elif isinstance(layer, FusedMoE):
...
@@ -71,4 +82,3 @@ def fp8_get_quant_method(
 setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
-"""
python/sglang/srt/layers/triton_fused_moe/__init__.py  (new file, mode 100644, +44)

from contextlib import contextmanager
from typing import Any, Dict, Optional

import sglang.srt.layers.triton_fused_moe.fused_moe  # noqa
from sglang.srt.layers.triton_fused_moe.fused_moe import (
    fused_experts,
    fused_topk,
    get_config_file_name,
    grouped_topk,
)
from sglang.srt.layers.triton_fused_moe.layer import (
    FusedMoE,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)

_config: Optional[Dict[str, Any]] = None


@contextmanager
def override_config(config):
    global _config
    old_config = _config
    _config = config
    yield
    _config = old_config


def get_config() -> Optional[Dict[str, Any]]:
    return _config


__all__ = [
    "FusedMoE",
    "FusedMoEMethodBase",
    "FusedMoeWeightScaleSupported",
    "override_config",
    "get_config",
    "fused_moe",
    "fused_topk",
    "fused_experts",
    "get_config_file_name",
    "grouped_topk",
]
python/sglang/srt/layers/triton_fused_moe/configs/README  (new file, mode 100644, +10)
This directory contains tuned configurations for different settings of the fused_moe kernel.
For different settings of
- E (number of experts)
- N (intermediate size)
- device_name (torch.cuda.get_device_name())
the JSON file contains a mapping from M (batch size) to the chosen configuration.
The example configurations provided are for the Mixtral model for TP2 on H100
and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have
N = 7168 and for TP4 we have N = 3584.
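
For illustration only (this JSON is hypothetical, not a config shipped by the commit), a tuned file named along the lines of E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json would parse into a mapping from batch size M to a Triton kernel configuration:

example_configs = {
    # M: kernel config chosen by the tuner for that batch size (values made up)
    1:    {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32,  "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1},
    64:   {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
    1024: {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8},
}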
python/sglang/srt/layers/triton_fused_moe/fused_moe.py  (new file, mode 100644, +858)

"""Fused MoE kernel."""

import functools
import json
import logging
import os
from typing import Any, Callable, Dict, Optional, Tuple

import torch
import triton
import triton.language as tl
from vllm import _custom_ops as ops

from sglang.srt.utils import direct_register_custom_op, get_device_name

logger = logging.getLogger(__name__)


@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.

    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
        be any shape representing batches and K is the feature dimension of
        each token.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output cache tensor with shape (M, topk, N), where M is the
        total number of tokens post padding, topk is the number of times
        each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A.

    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    token_mask = offs_token < num_valid_tokens

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = (
        b_ptr
        + off_experts * stride_be
        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    )
    if use_int8_w8a16:
        b_scale_ptrs = (
            b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
        )
        b_scale = tl.load(b_scale_ptrs)

    if use_fp8_w8a8:
        a_scale = tl.load(a_scale_ptr)
        b_scale = tl.load(b_scale_ptr + off_experts)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        if use_int8_w8a16:
            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
        elif use_fp8_w8a8:
            accumulator = tl.dot(a, b, acc=accumulator)
        else:
            accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]
    if use_int8_w8a16:
        accumulator = (accumulator * b_scale).to(compute_type)
    elif use_fp8_w8a8:
        accumulator = (accumulator * a_scale * b_scale).to(compute_type)
    else:
        accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
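
To see what the grouped ordering at the top of the kernel buys, here is a small standalone recap (plain Python, not part of the file) of the pid -> (pid_m, pid_n) mapping for a 4 x 4 grid of output blocks with GROUP_SIZE_M = 2; consecutive program ids stay inside a narrow band of rows, which is what promotes L2 reuse of the loaded tiles:

num_pid_m, num_pid_n, GROUP_SIZE_M = 4, 4, 2

def map_pid(pid):
    # Mirrors the index arithmetic at the top of fused_moe_kernel.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n

print([map_pid(pid) for pid in range(num_pid_m * num_pid_n)])
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3),
#  (2, 0), (3, 0), (2, 1), (3, 1), (2, 2), (3, 2), (2, 3), (3, 3)]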

def moe_align_block_size(
    topk_ids: torch.Tensor, block_size: int, num_experts: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Aligns the token distribution across experts to be compatible with block
    size for matrix multiplication.

    Parameters:
    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
        top-k expert indices for each token.
    - block_size: The block size used in block matrix multiplication.
    - num_experts: The total number of experts.

    Returns:
    - sorted_token_ids: A tensor containing the sorted token indices according
        to their allocated expert.
    - expert_ids: A tensor indicating the assigned expert index for each block.
    - num_tokens_post_padded: The total number of tokens after padding,
        ensuring divisibility by block_size.

    This function pads the number of tokens that each expert needs to process
    so that it is divisible by block_size.
    Padding ensures that during block matrix multiplication, the dimensions
    align correctly.

    Example:
    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
    block_size = 4, and num_experts = 4:
    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
        with each expert needing to process 3 tokens.
    - As block_size is 4, we pad 1 token for each expert.
    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
    - Then append padding tokens [12, 12, 12, 12] for each block.
    - After sorting by expert index, we obtain token_ids
        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
        Tokens 12 are non-existent (padding) and are ignored in
        the subsequent matrix multiplication.
    - The padding ensures that the total number of tokens is now divisible
        by block_size for proper block matrix operations.
    """
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
    ops.moe_align_block_size(
        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
    )
    return sorted_ids, expert_ids, num_tokens_post_pad
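
Because the actual alignment runs inside a vLLM CUDA op (ops.moe_align_block_size), here is a naive CPU-only reference (a sketch, not the op's exact buffer layout: the real op pre-sizes sorted_ids to the worst case) that reproduces the docstring example above:

import torch

def moe_align_block_size_ref(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    # Group the flattened token slots by expert, then pad every expert's
    # group to a multiple of block_size with the sentinel topk_ids.numel().
    flat = topk_ids.flatten()
    pad_value = flat.numel()
    sorted_ids, expert_ids = [], []
    for e in range(num_experts):
        slots = (flat == e).nonzero(as_tuple=True)[0].tolist()
        n_blocks = (len(slots) + block_size - 1) // block_size
        sorted_ids += slots + [pad_value] * (n_blocks * block_size - len(slots))
        expert_ids += [e] * n_blocks
    return (
        torch.tensor(sorted_ids, dtype=torch.int32),
        torch.tensor(expert_ids, dtype=torch.int32),
        torch.tensor([len(sorted_ids)], dtype=torch.int32),
    )

topk_ids = torch.tensor([[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]])
# The example ids run 0..4, so pass 5 experts here.
sorted_ids, expert_ids, num_post_pad = moe_align_block_size_ref(topk_ids, 4, 5)
print(sorted_ids.tolist())
# [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]  (matches the docstring)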

def invoke_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: Dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
) -> None:
    assert topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

    if use_fp8_w8a8:
        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
        assert B_scale is not None
    elif use_int8_w8a16:
        assert B_scale is not None
    else:
        assert A_scale is None
        assert B_scale is None

    grid = lambda META: (
        triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
        * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
    )

    fused_moe_kernel[grid](
        A,
        B,
        C,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.shape[1],
        B.shape[2],
        sorted_token_ids.shape[0],
        topk_ids.numel(),
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
        B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        **config,
    )


def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
    device_name = get_device_name().replace(" ", "_")
    dtype_selector = "" if not dtype else f",dtype={dtype}"
    return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"


@functools.lru_cache
def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
    """
    Return optimized configurations for the fused MoE kernel.

    The return value will be a dictionary that maps an irregular grid of
    batch sizes to configurations of the fused_moe kernel. To evaluate the
    kernel on a given batch size bs, the closest batch size in the grid should
    be picked and the associated configuration chosen to invoke the kernel.
    """

    # First look up if an optimized configuration is available in the configs
    # directory
    json_file_name = get_config_file_name(E, N, dtype)

    config_file_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
    )
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
            logger.info("Using configuration from %s for MoE layer.", config_file_path)
            # If a configuration has been found, return it
            return {int(key): val for key, val in json.load(f).items()}

    # If no optimized configuration is available, we will use the default
    # configuration
    logger.warning(
        (
            "Using default MoE config. Performance might be sub-optimal! "
            "Config file not found at %s"
        ),
        config_file_path,
    )
    return None


def get_default_config(
    M: int,
    E: int,
    N: int,
    K: int,
    topk: int,
    dtype: Optional[str],
    is_marlin: bool,
) -> Dict[str, int]:
    config = {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 32,
        "GROUP_SIZE_M": 8,
    }
    # A heuristic: fused marlin works faster with this config for small M
    if M <= E or (is_marlin and M <= 32):
        config = {
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 32,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 1,
        }
    return config


def try_get_optimal_moe_config(
    w1_shape: Tuple[int, ...],
    w2_shape: Tuple[int, ...],
    top_k: int,
    dtype: Optional[str],
    M: int,
    is_marlin: bool = False,
):
    from sglang.srt.layers.triton_fused_moe import get_config

    override_config = get_config()
    if override_config:
        config = override_config
    else:
        # First try to load optimal config from the file
        E, _, N = w2_shape
        configs = get_moe_configs(E, N, dtype)

        if configs:
            # If an optimal configuration map has been found, look up the
            # optimal config
            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
        else:
            # Else use the default config
            config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
    return config
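
When an override is active, the file/default lookup above is skipped entirely. A hedged usage sketch of that path, using the override_config context manager from this package's __init__ (it needs sglang importable together with its Triton and vLLM dependencies; the pinned values are arbitrary):

from sglang.srt.layers.triton_fused_moe import get_config, override_config

pinned = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}
with override_config(pinned):
    print(get_config())  # -> the pinned dict; try_get_optimal_moe_config returns it as-is
print(get_config())      # -> None again (assuming no other override), so file/default lookup applies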

def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    M, _ = hidden_states.shape

    topk_weights = torch.empty(
        M, topk, dtype=torch.float32, device=hidden_states.device
    )
    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
    token_expert_indicies = torch.empty(
        M, topk, dtype=torch.int32, device=hidden_states.device
    )

    ops.topk_softmax(
        topk_weights,
        topk_ids,
        token_expert_indicies,
        gating_output.float(),  # TODO(woosuk): Optimize this.
    )
    del token_expert_indicies  # Not used. Will be used in the future.

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids
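
The routing math itself is simple; the CUDA op exists for speed. Up to tie-breaking, the call above is equivalent to this plain-PyTorch sketch (standalone, not part of the file):

import torch

def fused_topk_ref(gating_output: torch.Tensor, topk: int, renormalize: bool):
    scores = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)

logits = torch.randn(4, 8)                       # 4 tokens, 8 experts
w, ids = fused_topk_ref(logits, topk=2, renormalize=True)
print(w.sum(dim=-1))                             # ~1.0 per token after renormalization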

# This is used by the Deepseek-V2 model
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
):
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    scores = torch.softmax(gating_output, dim=-1)
    num_token = scores.shape[0]
    group_scores = (
        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
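
The piece worth spelling out is the group restriction: experts are split into num_expert_group groups, each group is scored by its best expert, only the topk_group best groups survive, and the final top-k is taken over the surviving experts. A tiny standalone run with hypothetical sizes (2 tokens, 8 experts in 4 groups of 2) makes that visible:

import torch

torch.manual_seed(0)
num_token, num_experts, num_expert_group, topk_group, topk = 2, 8, 4, 2, 2
scores = torch.softmax(torch.randn(num_token, num_experts), dim=-1)

group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = (
    group_mask.unsqueeze(-1)
    .expand(num_token, num_expert_group, num_experts // num_expert_group)
    .reshape(num_token, -1)
)
masked = scores.masked_fill(~score_mask.bool(), 0.0)
print(torch.topk(masked, k=topk, dim=-1, sorted=False)[1])
# Every selected expert lies inside one of the 2 highest-scoring groups.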

def get_config_dtype_str(
    dtype: torch.dtype,
    use_int8_w8a16: Optional[bool] = False,
    use_fp8_w8a8: Optional[bool] = False,
):
    if use_fp8_w8a8:
        return "fp8_w8a8"
    elif use_int8_w8a16:
        return "int8_w8a16"
    elif dtype == torch.float:
        # avoiding cases where kernel fails when float32 MoE
        # use fp16/bfloat16 configs
        return "float32"
    return None


def inplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> None:
    fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        True,
        use_fp8_w8a8,
        use_int8_w8a16,
        w1_scale,
        w2_scale,
        a1_scale,
        a2_scale,
    )


def inplace_fused_experts_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> None:
    pass


direct_register_custom_op(
    op_name="inplace_fused_experts",
    op_func=inplace_fused_experts,
    mutates_args=["hidden_states"],
    fake_impl=inplace_fused_experts_fake,
)


def outplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        False,
        use_fp8_w8a8,
        use_int8_w8a16,
        w1_scale,
        w2_scale,
        a1_scale,
        a2_scale,
    )


def outplace_fused_experts_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    return torch.empty_like(hidden_states)


direct_register_custom_op(
    op_name="outplace_fused_experts",
    op_func=outplace_fused_experts,
    mutates_args=[],
    fake_impl=outplace_fused_experts_fake,
)


def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
):
    if inplace:
        torch.ops.sglang.inplace_fused_experts(
            hidden_states,
            w1,
            w2,
            topk_weights,
            topk_ids,
            use_fp8_w8a8,
            use_int8_w8a16,
            w1_scale,
            w2_scale,
            a1_scale,
            a2_scale,
        )
        return hidden_states
    else:
        return torch.ops.sglang.outplace_fused_experts(
            hidden_states,
            w1,
            w2,
            topk_weights,
            topk_ids,
            use_fp8_w8a8,
            use_int8_w8a16,
            w1_scale,
            w2_scale,
            a1_scale,
            a2_scale,
        )


def fused_experts_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
):
    # Check constraints.
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    num_tokens, _ = hidden_states.shape
    E, N, _ = w1.shape
    # We execute the fused_moe kernel in chunks to circumvent this issue:
    # https://github.com/vllm-project/vllm/issues/5938
    CHUNK_SIZE = 64 * 1024
    M = min(num_tokens, CHUNK_SIZE)
    config_dtype = get_config_dtype_str(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        dtype=hidden_states.dtype,
    )

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.shape,
        w2.shape,
        topk_ids.shape[1],
        config_dtype,
    )

    config = get_config_func(M)

    intermediate_cache1 = torch.empty(
        (M, topk_ids.shape[1], N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache2 = torch.empty(
        (M * topk_ids.shape[1], N // 2),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache3 = torch.empty(
        (M, topk_ids.shape[1], w2.shape[1]),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16

    if inplace:
        out_hidden_states = hidden_states
    else:
        out_hidden_states = torch.empty_like(hidden_states)

    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin_chunk_idx, end_chunk_idx = (
            chunk * CHUNK_SIZE,
            min((chunk + 1) * CHUNK_SIZE, num_tokens),
        )
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.shape

        if tokens_in_chunk == 0:
            break

        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            # Adjust the intermediate cache size and config for the last
            # chunk. Note that in most cases we only have one chunk
            # so the cache size and config are already set correctly and
            # do not need to be adjusted.
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
            config = get_config_func(tokens_in_chunk)

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            curr_topk_ids, config["BLOCK_SIZE_M"], E
        )

        invoke_fused_moe_kernel(
            curr_hidden_states,
            w1,
            intermediate_cache1,
            a1_scale,
            w1_scale,
            curr_topk_weights,
            curr_topk_ids,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            False,
            topk_ids.shape[1],
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
        )

        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

        invoke_fused_moe_kernel(
            intermediate_cache2,
            w2,
            intermediate_cache3,
            a2_scale,
            w2_scale,
            curr_topk_weights,
            curr_topk_ids,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            True,
            1,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
        )

        ops.moe_sum(
            intermediate_cache3.view(*intermediate_cache3.shape),
            out_hidden_states[begin_chunk_idx:end_chunk_idx],
        )
    return out_hidden_states
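
Ignoring chunking, quantization, and the Triton tiling, the two invoke_fused_moe_kernel calls plus silu_and_mul and moe_sum above compute the following naive per-token loop. This CPU reference sketch (standalone, written for clarity rather than speed) is useful for eyeballing what the fused path produces:

import torch
import torch.nn.functional as F

def fused_experts_reference(hidden_states, w1, w2, topk_weights, topk_ids):
    # hidden_states: [M, K]; w1: [E, 2*I, K]; w2: [E, K, I]
    # topk_weights / topk_ids: [M, top_k]
    out = torch.zeros_like(hidden_states)
    for m in range(hidden_states.shape[0]):
        for w, e in zip(topk_weights[m].tolist(), topk_ids[m].tolist()):
            gate_up = w1[e] @ hidden_states[m]      # first grouped GEMM
            gate, up = gate_up.chunk(2)
            act = F.silu(gate) * up                 # silu_and_mul
            out[m] += w * (w2[e] @ act)             # second GEMM, routed weight, moe_sum
    return out

E, M, K, I, top_k = 4, 3, 8, 16, 2
x = torch.randn(M, K)
w1 = torch.randn(E, 2 * I, K) * 0.1
w2 = torch.randn(E, K, I) * 0.1
weights, ids = torch.topk(torch.softmax(torch.randn(M, E), dim=-1), top_k, dim=-1)
print(fused_experts_reference(x, w1, w2, weights, ids).shape)  # torch.Size([3, 8])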

def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of
    weights, w1 and w2, and a top-k gating mechanism.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - inplace (bool): If True, perform the operation in-place.
        Defaults to False.
    - num_expert_group: Optional[int]: additional parameter for grouped_topk
    - topk_group: Optional[int]: additional parameter for grouped_topk
    - use_grouped_topk: If True, use grouped_topk instead of fused_topk
        note: the DeepSeek-V2 model uses grouped_topk
    - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
        products for w1 and w2. Defaults to False.
    - use_int8_w8a16 (bool): If True, use int8 weights with 16-bit activations
        to compute the inner products for w1 and w2. Defaults to False.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Check constraints.
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

    if use_grouped_topk:
        assert num_expert_group is not None and topk_group is not None
        topk_weights, topk_ids = grouped_topk(
            hidden_states,
            gating_output,
            topk,
            renormalize,
            num_expert_group,
            topk_group,
        )
    elif custom_routing_function is None:
        topk_weights, topk_ids = fused_topk(
            hidden_states, gating_output, topk, renormalize
        )
    else:
        topk_weights, topk_ids = custom_routing_function(
            hidden_states, gating_output, topk, renormalize
        )

    return fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=inplace,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )
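
A hedged end-to-end usage sketch of fused_moe (it needs a CUDA device plus vLLM's _custom_ops, which supply topk_softmax, silu_and_mul, moe_sum, and moe_align_block_size; the sizes below are arbitrary):

import torch
from sglang.srt.layers.triton_fused_moe.fused_moe import fused_moe

E, M, K, I = 8, 16, 1024, 2048     # experts, tokens, hidden size, intermediate size
dev, dtype = "cuda", torch.float16

hidden_states = torch.randn(M, K, device=dev, dtype=dtype)
w1 = torch.randn(E, 2 * I, K, device=dev, dtype=dtype) * 0.01   # gate + up, stacked
w2 = torch.randn(E, K, I, device=dev, dtype=dtype) * 0.01       # down projection
gating_output = torch.randn(M, E, device=dev, dtype=dtype)

out = fused_moe(hidden_states, w1, w2, gating_output, topk=2, renormalize=True)
print(out.shape)  # torch.Size([16, 1024])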
python/sglang/srt/layers/triton_fused_moe/layer.py  (new file, mode 100644, +631)

from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple

import torch
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from vllm.model_executor.custom_op import CustomOp

from sglang.srt.layers.custom_op_util import register_custom_op
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.utils import set_weight_attrs

if torch.cuda.is_available() or torch.hip.is_available():
    from .fused_moe import fused_experts
else:
    fused_experts = None  # type: ignore

import logging

logger = logging.getLogger(__name__)


class FusedMoeWeightScaleSupported(Enum):
    TENSOR = "tensor"
    CHANNEL = "channel"
    GROUP = "group"


class FusedMoEMethodBase(QuantizeMethodBase):

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
    ) -> torch.Tensor:
        raise NotImplementedError


@register_custom_op("sglang_unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts, hidden_size, intermediate_size, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        return self.forward(
            x=x,
            layer=layer,
            router_logits=router_logits,
            top_k=top_k,
            renormalize=renormalize,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
        )

    def forward_cuda(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
        )

        return fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
        )

    def forward_cpu(self, *args, **kwargs):
        raise NotImplementedError("The CPU backend currently does not support MoE.")

    def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("The TPU backend currently does not support MoE.")

    forward_native = forward_cuda


class FusedMoE(torch.nn.Module):
    """FusedMoE layer for MoE models.

    This layer contains both MergedColumnParallel weights (gate_up_proj /
    w13) and RowParallelLinear weights (down_proj / w2).

    Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
    copy that naming convention here and handle any remapping in the
    load_weights function in each model implementation.

    Args:
        num_experts: Number of experts in the model
        top_k: Number of experts selected for each token
        hidden_size: Input hidden state size of the transformer
        intermediate_size: Intermediate size of the experts
        params_dtype: Data type for the parameters.
        reduce_results: Whether to all_reduce on the output of the layer
        renormalize: Whether to renormalize the logits in the fused_moe kernel
        quant_config: Quantization configuration.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = False,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        prefix: str = "",
        custom_routing_function: Optional[Callable] = None,
    ):
        super().__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        self.tp_size = (
            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
        )
        self.top_k = top_k
        self.num_experts = num_experts
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
        self.reduce_results = reduce_results
        self.renormalize = renormalize
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.custom_routing_function = custom_routing_function

        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = (
                UnquantizedFusedMoEMethod()
            )
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix)
        assert self.quant_method is not None

        self.quant_method.create_weights(
            layer=self,
            num_experts=num_experts,
            hidden_size=hidden_size,
            intermediate_size=self.intermediate_size_per_partition,
            params_dtype=params_dtype,
            weight_loader=self.weight_loader,
        )

    def _load_per_tensor_weight_scale(
        self,
        shard_id: str,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        expert_id: int,
    ):
        param_data = param.data
        # for per tensor weight quantization
        if shard_id in ("w1", "w3"):
            # We have to keep the weight scales of w1 and w3 because
            # we need to re-quantize w1/w3 weights after weight loading.
            idx = 0 if shard_id == "w1" else 1
            param_data[expert_id][idx] = loaded_weight
        # If we are in the row parallel case (down_proj)
        elif shard_id == "w2":
            param_data[expert_id] = loaded_weight

    def _load_model_weight_or_group_weight_scale(
        self,
        shard_dim: int,
        expert_data: torch.Tensor,
        shard_id: str,
        loaded_weight: torch.tensor,
        tp_rank: int,
    ):
        # Load grouped weight scales for group quantization
        # or model weights
        if shard_id == "w2":
            self._load_w2(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )
        elif shard_id in ("w1", "w3"):
            self._load_w13(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )

    def _load_per_channel_weight_scale(
        self,
        expert_data: torch.Tensor,
        shard_dim: int,
        shard_id: str,
        loaded_weight: torch.tensor,
        tp_rank: int,
    ):
        # for per channel weight quantization
        if shard_id == "w2":
            expert_data.copy_(loaded_weight)
        elif shard_id in ("w1", "w3"):
            self._load_w13(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )

    def _load_w13(
        self,
        expert_data: torch.Tensor,
        shard_dim: int,
        shard_id: str,
        loaded_weight: torch.tensor,
        tp_rank: int,
    ):
        # Index the loaded weight for tp sharding.
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
        shard_size = expert_data.shape[shard_dim] // 2
        loaded_weight = loaded_weight.narrow(
            shard_dim, shard_size * tp_rank, shard_size
        )
        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        if shard_id == "w1":
            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
        # w3, up_proj: Load into second logical weight of w13.
        else:
            assert shard_id == "w3"
            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
        expert_data.copy_(loaded_weight)

    def _load_w2(
        self,
        expert_data: torch.Tensor,
        shard_dim: int,
        shard_id: str,
        loaded_weight: torch.tensor,
        tp_rank: int,
    ):
        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel" so tp sharding on input_dim
        # Narrow parameter and load.
        shard_size = expert_data.shape[shard_dim]
        loaded_weight = loaded_weight.narrow(
            shard_dim, shard_size * tp_rank, shard_size
        )
        # w2, down_proj: Load into only logical weight of w2.
        expert_data.copy_(loaded_weight)

    def _load_single_value(
        self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
    ):
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        param_data[expert_id] = loaded_weight

    def _load_g_idx(
        self,
        shard_id: str,
        expert_data: torch.Tensor,
        shard_dim: int,
        loaded_weight: torch.tensor,
        tp_rank: int,
    ):
        if shard_id == "w2":
            self._load_w2(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )
        else:
            assert shard_id in ("w1", "w3")
            expert_data.copy_(loaded_weight)

    def weight_loader(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
    ) -> None:
        # compressed-tensors checkpoints with packed weights are stored flipped
        # TODO (mgoin): check self.quant_method.quant_config.quant_format
        # against known CompressionFormat enum values that have this quality
        loaded_weight = (
            loaded_weight.t().contiguous()
            if (
                self.quant_method.__class__.__name__
                == "CompressedTensorsWNA16MoEMethod"
            )
            else loaded_weight
        )

        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError(
                f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
            )

        WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
        # Fetch the dim to shard the parameter/loaded weight
        # based on the shard id. This will be whatever
        # dimension intermediate_size is used.
        SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

        expert_data = param.data[expert_id]
        tp_rank = get_tensor_model_parallel_rank()

        # is_transposed: if the dim to shard the weight
        # should be flipped. Required by GPTQ, compressed-tensors
        # should be whatever dimension intermediate_size is
        is_transposed = getattr(param, "is_transposed", False)
        shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
        if is_transposed:
            shard_dim = ~shard_dim

        # Case input scale: input_scale loading is only supported for fp8
        if "input_scale" in weight_name:
            # this is needed for compressed-tensors only
            loaded_weight = loaded_weight.to(param.data.device)

            if (
                param.data[expert_id] != 1
                and (param.data[expert_id] - loaded_weight).abs() > 1e-5
            ):
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param.data[expert_id]} "
                    f"vs. {loaded_weight}"
                )

            self._load_single_value(
                param=param, loaded_weight=loaded_weight, expert_id=expert_id
            )
            return

        # Case g_idx
        if "g_idx" in weight_name:
            self._load_g_idx(
                shard_dim=0,
                shard_id=shard_id,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )
            return

        # Case weight scales and zero_points
        if "scale" in weight_name or "zero" in weight_name:
            # load the weight scales and zp based on the quantization scheme
            # supported weight scales/zp can be found in
            # FusedMoeWeightScaleSupported
            # TODO @dsikka: once hardened, refactor to use vLLM Parameters
            # specific to each case
            quant_method = getattr(param, "quant_method", None)
            if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
                self._load_per_channel_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=tp_rank,
                )
            elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
                self._load_model_weight_or_group_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=tp_rank,
                )
            elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                self._load_per_tensor_weight_scale(
                    shard_id=shard_id,
                    param=param,
                    loaded_weight=loaded_weight,
                    expert_id=expert_id,
                )
            else:
                raise ValueError(
                    f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}"
                )
            return

        # Case weight_shape
        if "weight_shape" in weight_name:
            # only required by compressed-tensors
            self._load_single_value(
                param=param, loaded_weight=loaded_weight, expert_id=expert_id
            )
            return

        # Case model weights
        if "weight" in weight_name:
            self._load_model_weight_or_group_weight_scale(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
            )
            return

    @staticmethod
    def select_experts(
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        use_grouped_topk: bool,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ):
        from vllm.model_executor.layers.fused_moe.fused_moe import (
            fused_topk,
            grouped_topk,
        )

        # DeepSeek-V2 uses grouped_topk
        if use_grouped_topk:
            assert topk_group is not None
            assert num_expert_group is not None
            topk_weights, topk_ids = grouped_topk(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
            )
        elif custom_routing_function is None:
            topk_weights, topk_ids = fused_topk(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
            )
        else:
            topk_weights, topk_ids = custom_routing_function(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
            )

        return topk_weights, topk_ids

    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        assert self.quant_method is not None

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
        )

        if self.reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states

    @classmethod
    def make_expert_params_mapping(
        cls,
        ckpt_gate_proj_name: str,
        ckpt_down_proj_name: str,
        ckpt_up_proj_name: str,
        num_experts: int,
    ) -> List[Tuple[str, str, int, str]]:

        return [
            # (param_name, weight_name, expert_id, shard_id)
            (
                (
                    "experts.w13_"
                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
                    else "experts.w2_"
                ),
                f"experts.{expert_id}.{weight_name}.",
                expert_id,
                shard_id,
            )
            for expert_id in range(num_experts)
            for shard_id, weight_name in [
                ("w1", ckpt_gate_proj_name),
                ("w2", ckpt_down_proj_name),
                ("w3", ckpt_up_proj_name),
            ]
        ]

    def _load_fp8_scale(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
    ) -> None:
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        if "input_scale" in weight_name:
            if (
                param_data[expert_id] != 1
                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
            ):
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param_data[expert_id]} "
                    f"vs. {loaded_weight}"
                )
            param_data[expert_id] = loaded_weight
        # Weight scales
        elif "weight_scale" in weight_name:
            # If we are in merged column case (gate_up_proj)
            if shard_id in ("w1", "w3"):
                # We have to keep the weight scales of w1 and w3 because
                # we need to re-quantize w1/w3 weights after weight loading.
                idx = 0 if shard_id == "w1" else 1
                param_data[expert_id][idx] = loaded_weight
            # If we are in the row parallel case (down_proj)
            else:
                param_data[expert_id] = loaded_weight
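
For a concrete picture of what make_expert_params_mapping returns, the same comprehension can be run standalone (mirroring the classmethod rather than importing it, since constructing the layer itself needs a tensor-parallel environment) for a hypothetical 2-expert checkpoint whose weights are stored as experts.<i>.w1/w2/w3:

ckpt_gate, ckpt_down, ckpt_up = "w1", "w2", "w3"
mapping = [
    (
        "experts.w13_" if weight_name in [ckpt_gate, ckpt_up] else "experts.w2_",
        f"experts.{expert_id}.{weight_name}.",
        expert_id,
        shard_id,
    )
    for expert_id in range(2)
    for shard_id, weight_name in [("w1", ckpt_gate), ("w2", ckpt_down), ("w3", ckpt_up)]
]
print(mapping[:3])
# [('experts.w13_', 'experts.0.w1.', 0, 'w1'),
#  ('experts.w2_', 'experts.0.w2.', 0, 'w2'),
#  ('experts.w13_', 'experts.0.w3.', 0, 'w3')]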
python/sglang/srt/models/deepseek_v2.py  (+1, -1)

@@ -27,7 +27,6 @@ from vllm.distributed import (
     get_tp_group,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -42,6 +41,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.triton_fused_moe import FusedMoE
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
...
python/sglang/srt/utils.py  (+43, -1)

@@ -31,7 +31,7 @@ import time
 import warnings
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
-from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union

 import numpy as np
 import psutil
...
@@ -45,6 +45,7 @@ from packaging import version as pkg_version
 from starlette.routing import Mount
 from torch import nn
 from torch.func import functional_call
+from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from triton.runtime.cache import (
     FileCacheManager,
...
@@ -930,3 +931,44 @@ def get_nvgpu_memory_capacity():
 def crash_on_warnings():
     # Crash on warning if we are running CI tests
     return os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+
+
+def get_device_name(device_id: int = 0) -> str:
+    if hasattr(torch, "cuda") and torch.cuda.is_available():
+        return torch.cuda.get_device_name(device_id)
+
+    if hasattr(torch, "hip") and torch.hip.is_available():
+        return torch.hip.get_device_name(device_id)
+
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        return torch.xpu.get_device_name(device_id)
+
+    if hasattr(torch, "hpu") and torch.hpu.is_available():
+        return torch.hpu.get_device_name(device_id)
+
+
+sglang_lib = Library("sglang", "FRAGMENT")  # noqa
+
+
+def direct_register_custom_op(
+    op_name: str,
+    op_func: Callable,
+    mutates_args: List[str],
+    fake_impl: Optional[Callable] = None,
+    target_lib: Optional[Library] = None,
+):
+    import torch.library
+
+    if hasattr(torch.library, "infer_schema"):
+        schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
+    else:
+        # for pytorch 2.4
+        import torch._custom_op.impl
+
+        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
+
+    my_lib = target_lib or sglang_lib
+    my_lib.define(op_name + schema_str)
+    my_lib.impl(op_name, op_func, "CUDA")
+    if fake_impl is not None:
+        my_lib._register_fake(op_name, fake_impl)
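
A minimal sketch of how direct_register_custom_op is meant to be used; this mirrors the pattern fused_moe.py follows for inplace_fused_experts, but the toy op below and its name are hypothetical. Note that the implementation is registered only under the CUDA dispatch key:

import torch
from sglang.srt.utils import direct_register_custom_op

def scale_rows(x: torch.Tensor, factor: float) -> None:
    # In-place toy op: it mutates x, so x must be listed in mutates_args.
    x.mul_(factor)

def scale_rows_fake(x: torch.Tensor, factor: float) -> None:
    # Metadata-only stand-in used when torch.compile traces the op.
    pass

direct_register_custom_op(
    op_name="scale_rows_demo",
    op_func=scale_rows,
    mutates_args=["x"],
    fake_impl=scale_rows_fake,
)

# The op is now reachable through the dispatcher under the sglang namespace
# (CUDA tensors only, since the implementation above is registered for "CUDA"):
#   t = torch.ones(4, device="cuda")
#   torch.ops.sglang.scale_rows_demo(t, 2.0)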