sglang / Commits / 63051738

Enable CPU device on SGLang (#2806)

Unverified commit 63051738, authored Jan 17, 2025 by Chunyuan WU; committed by GitHub Jan 16, 2025. Parent: a8ccacc8.

Showing 13 changed files with 376 additions and 9 deletions (+376, -9):
- python/pyproject.toml (+6, -0)
- python/sglang/srt/configs/device_config.py (+1, -1)
- python/sglang/srt/layers/moe/fused_moe_native.py (+69, -0)
- python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py (+4, -1)
- python/sglang/srt/layers/moe/fused_moe_triton/layer.py (+26, -2)
- python/sglang/srt/layers/rotary_embedding.py (+248, -0)
- python/sglang/srt/managers/schedule_batch.py (+1, -0)
- python/sglang/srt/managers/scheduler.py (+2, -0)
- python/sglang/srt/managers/tp_worker_overlap_thread.py (+2, -0)
- python/sglang/srt/model_executor/model_runner.py (+9, -3)
- python/sglang/srt/models/deepseek_v2.py (+3, -1)
- python/sglang/srt/server_args.py (+1, -1)
- python/sglang/srt/utils.py (+4, -0)
python/pyproject.toml

@@ -40,6 +40,10 @@ srt_xpu = ["sglang[runtime_common]"]
 #For Intel Gaudi(device : hpu) follow the installation guide
 #https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html
 srt_hpu = ["sglang[runtime_common]"]
+# CPU: currently, there are no pre-built vllm wheels for CPU.
+# To install vllm for CPU, please follow the instruction here:
+# https://docs.vllm.ai/en/latest/getting_started/installation/cpu/index.html
+srt_cpu = ["sglang[runtime_common]", "torch"]
 openai = ["openai>=1.0", "tiktoken"]
 anthropic = ["anthropic>=0.20.0"]

@@ -57,11 +61,13 @@ all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 all_hpu = ["sglang[srt_hpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
+all_cpu = ["sglang[srt_cpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]

 dev = ["sglang[all]", "sglang[test]"]
 dev_hip = ["sglang[all_hip]", "sglang[test]"]
 dev_xpu = ["sglang[all_xpu]", "sglang[test]"]
 dev_hpu = ["sglang[all_hpu]", "sglang[test]"]
+dev_cpu = ["sglang[all_cpu]", "sglang[test]"]

 [project.urls]
 "Homepage" = "https://github.com/sgl-project/sglang"
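A quick orientation note on these extras (not part of the diff): installing the new CPU path amounts to pip install "sglang[all_cpu]" once vllm has been built for CPU per the linked guide. The hypothetical helper below only illustrates how the extras line up with the local torch build; the ROCm/CUDA mapping is our assumption, and only srt_cpu/all_cpu come from this commit.

import torch

def pick_sglang_extra() -> str:
    """Map the local torch build to an sglang extra (illustrative sketch)."""
    if getattr(torch.version, "hip", None):  # ROCm builds expose a HIP version
        return "sglang[all_hip]"
    if torch.cuda.is_available():
        return "sglang[all]"
    return "sglang[all_cpu]"  # the extra added by this commit

print(pick_sglang_extra())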
python/sglang/srt/configs/device_config.py

@@ -10,7 +10,7 @@ class DeviceConfig:
     device: Optional[torch.device]

     def __init__(self, device: str = "cuda") -> None:
-        if device in ["cuda", "xpu", "hpu"]:
+        if device in ["cuda", "xpu", "hpu", "cpu"]:
            self.device_type = device
         else:
             raise RuntimeError(f"Not supported device type: {device}")
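A minimal usage sketch of the widened check (assuming the import path matches the file location):

from sglang.srt.configs.device_config import DeviceConfig

cfg = DeviceConfig("cpu")
assert cfg.device_type == "cpu"

try:
    DeviceConfig("tpu")  # still rejected: not in ["cuda", "xpu", "hpu", "cpu"]
except RuntimeError as err:
    print(err)  # Not supported device type: tpu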
python/sglang/srt/layers/moe/fused_moe_native.py

@@ -8,6 +8,7 @@ from typing import Callable, Optional
 import torch
 from torch.nn import functional as F

+from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.topk import select_experts

@@ -44,3 +45,71 @@ def fused_moe_forward_native(
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
+
+
+def moe_forward_native(
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    use_grouped_topk: bool,
+    top_k: int,
+    router_logits: torch.Tensor,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    correction_bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    topk_weights, topk_ids = select_experts(
+        hidden_states=x,
+        router_logits=router_logits,
+        use_grouped_topk=use_grouped_topk,
+        top_k=top_k,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+        correction_bias=correction_bias,
+        torch_native=True,
+    )
+
+    # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
+    len_experts = layer.num_experts
+
+    cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
+    cnts.scatter_(1, topk_ids.to(torch.int64), 1)
+    tokens_per_expert = cnts.sum(dim=0)
+    idxs = topk_ids.view(-1).argsort()
+
+    sorted_tokens = x[idxs // topk_ids.shape[1]]
+    tokens_per_expert = tokens_per_expert.cpu().numpy()
+
+    outputs = []
+    start_idx = 0
+    for i, num_tokens in enumerate(tokens_per_expert):
+        end_idx = start_idx + num_tokens
+        if num_tokens == 0:
+            continue
+        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+
+        layer_w13_weight = layer.w13_weight[i]
+        layer_w2_weight = layer.w2_weight[i]
+
+        gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
+        gate_up = SiluAndMul()(gate_up)
+        expert_out = F.linear(gate_up, layer_w2_weight)
+        outputs.append(expert_out)
+        start_idx = end_idx
+
+    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
+
+    new_x = torch.empty_like(outs)
+    new_x[idxs] = outs
+    final_out = (
+        new_x.view(*topk_ids.shape, -1)
+        .type(topk_weights.dtype)
+        .mul_(topk_weights.unsqueeze(dim=-1))
+        .sum(dim=1)
+        .type(new_x.dtype)
+    )
+    return final_out
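The heart of moe_forward_native is the argsort trick borrowed from the DeepSeek reference code: flattening topk_ids and sorting groups every token copy bound for the same expert into one contiguous slice, so each expert runs a single dense F.linear. A self-contained toy run (shapes and values are ours, not the commit's):

import torch

num_experts = 3
topk_ids = torch.tensor([[0, 2], [1, 0], [2, 1]])  # 3 tokens, top-2 experts each
x = torch.randn(3, 4)                              # 3 tokens, hidden size 4

# Sorting the flattened expert ids groups same-expert token copies together.
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // topk_ids.shape[1]]       # (6, 4), expert-ordered

# Histogram of tokens per expert, built with scatter_ exactly as above.
cnts = topk_ids.new_zeros((topk_ids.shape[0], num_experts))
cnts.scatter_(1, topk_ids.to(torch.int64), 1)
print(cnts.sum(dim=0))  # tensor([2, 2, 2]) -> slice bounds for each expert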
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -19,7 +19,10 @@ from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip

 is_hip_flag = False
 if not is_hip():
-    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+    if torch.cuda.is_available():
+        from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+    else:
+        sgl_moe_align_block_size = None

     is_hip_flag = False
 else:
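The point of the new guard: sgl_kernel ships CUDA kernels, so importing it on a CPU-only host fails at module load; binding the name to None defers the failure to call time. A standalone sketch of the idiom (the wrapper is hypothetical, not the commit's code):

import torch

if torch.cuda.is_available():
    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
else:
    sgl_moe_align_block_size = None  # CPU-only host: fused kernel unavailable

def moe_align_or_raise(*args, **kwargs):
    # Hypothetical call-site guard; real callers branch on the flag similarly.
    if sgl_moe_align_block_size is None:
        raise RuntimeError("moe_align_block_size needs the CUDA sgl_kernel build")
    return sgl_moe_align_block_size(*args, **kwargs)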
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -13,6 +13,7 @@ from vllm.distributed import (
 from vllm.model_executor.custom_op import CustomOp

 from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,

@@ -185,8 +186,31 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             inplace=True,
         )

-    def forward_cpu(self, *args, **kwargs):
-        raise NotImplementedError("The CPU backend currently does not support MoE.")
+    def forward_cpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return moe_forward_native(
+            layer,
+            x,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+            custom_routing_function,
+            correction_bias,
+        )

     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")
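forward_cpu is never called directly: vllm's CustomOp base class routes forward() to a backend-specific method at dispatch time. A simplified stand-in for that mechanism, under the assumption that it keys on the active platform (the real logic lives in vllm.model_executor.custom_op):

import torch

class CustomOpSketch:
    """Rough, assumed model of vllm's CustomOp dispatch."""

    def forward(self, *args, **kwargs):
        if torch.cuda.is_available():
            return self.forward_cuda(*args, **kwargs)
        return self.forward_cpu(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cpu(self, *args, **kwargs):
        # This commit turns sglang's override from "raise NotImplementedError"
        # into a real call to moe_forward_native.
        raise NotImplementedError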
python/sglang/srt/layers/rotary_embedding.py

@@ -15,6 +15,15 @@
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
+from vllm.model_executor.layers.rotary_embedding import (
+    RotaryEmbedding,
+    _rotate_gptj,
+    _rotate_neox,
+    _yarn_find_correction_range,
+    _yarn_linear_ramp_mask,
+    get_rope,
+    yarn_get_mscale,
+)


 class MRotaryEmbedding:

@@ -110,3 +119,242 @@ class MRotaryEmbedding:
             )
             for _ in range(3)
         ]
+
+
+# TODO: in the DeepseekScalingRotaryEmbedding class defined in vllm,
+# the device has been hard-coded to "cuda" in these two places:
+# https://github.com/vllm-project/vllm/blob/8a1f938e6f02052df0f4953c149410605a2d56d8/vllm/model_executor/layers/rotary_embedding.py#L646
+# https://github.com/vllm-project/vllm/blob/8a1f938e6f02052df0f4953c149410605a2d56d8/vllm/model_executor/layers/rotary_embedding.py#L665
+# We port the related code to this file to make it compatible with the CPU version.
+# We will add an optimized rotary embedding kernel for CPU and will remove the ported code then.
+class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with YaRN method.
+
+    Credits to Peng et al. github.com/jquesnelle/yarn
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+        dtype: torch.dtype,
+        *,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
+        mscale: float = 1,
+        mscale_all_dim: float = 0,
+        device: Optional[str] = None,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        # Get n-d magnitude scaling corrected for interpolation.
+        self.mscale = float(
+            yarn_get_mscale(self.scaling_factor, float(mscale))
+            / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
+            * attn_factor
+        )
+        self.device = device
+        super().__init__(
+            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+        )
+
+    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
+        pos_freqs = self.base ** (
+            torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device)
+            / self.rotary_dim
+        )
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
+
+        low, high = _yarn_find_correction_range(
+            self.beta_fast,
+            self.beta_slow,
+            self.rotary_dim,
+            self.base,
+            self.max_position_embeddings,
+        )
+        # Get n-d rotational scaling corrected for extrapolation
+        inv_freq_mask = (
+            1
+            - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
+        ) * self.extrapolation_factor
+        inv_freq = (
+            inv_freq_interpolation * (1 - inv_freq_mask)
+            + inv_freq_extrapolation * inv_freq_mask
+        )
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.scaling_factor)
+        t = torch.arange(
+            self.max_position_embeddings * self.scaling_factor,
+            device=self.device,
+            dtype=torch.float32,
+        )
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos() * self.mscale
+        sin = freqs.sin() * self.mscale
+        cache = torch.cat((cos, sin), dim=-1)
+        print("Cache shape", cache.shape)
+        return cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """PyTorch-native implementation equivalent to forward()."""
+        query_rot = query[..., : self.rotary_dim]
+        key_rot = key[..., : self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim :]
+            key_pass = key[..., self.rotary_dim :]
+
+        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
+        cos_sin = self.cos_sin_cache[
+            torch.add(positions, offsets) if offsets is not None else positions
+        ]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if self.is_neox_style:
+            # NOTE(woosuk): Here we assume that the positions tensor has the
+            # shape [batch_size, seq_len].
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+
+        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
+        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
+
+        if self.rotary_dim < self.head_size:
+            query = torch.cat((query_rot, query_pass), dim=-1)
+            key = torch.cat((key_rot, key_pass), dim=-1)
+        else:
+            query = query_rot
+            key = key_rot
+        return query, key
+
+
+_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
+
+
+def get_rope_cpu(
+    head_size: int,
+    rotary_dim: int,
+    max_position: int,
+    base: int,
+    is_neox_style: bool = True,
+    rope_scaling: Optional[Dict[str, Any]] = None,
+    dtype: Optional[torch.dtype] = None,
+    partial_rotary_factor: float = 1.0,
+    device: Optional[str] = None,
+) -> RotaryEmbedding:
+    if dtype is None:
+        dtype = torch.get_default_dtype()
+    if rope_scaling is not None:
+        # Transforms every value that is a list into a tuple for caching calls
+        rope_scaling_tuple = {
+            k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items()
+        }
+        rope_scaling_args = tuple(rope_scaling_tuple.items())
+    else:
+        rope_scaling_args = None
+    if partial_rotary_factor < 1.0:
+        rotary_dim = int(rotary_dim * partial_rotary_factor)
+
+    key = (
+        head_size,
+        rotary_dim,
+        max_position,
+        base,
+        is_neox_style,
+        rope_scaling_args,
+        dtype,
+    )
+    if key in _ROPE_DICT:
+        return _ROPE_DICT[key]
+
+    assert rope_scaling is not None
+    scaling_type = rope_scaling["rope_type"]
+    assert (
+        scaling_type == "deepseek_yarn"
+    ), "Only deepseek_yarn is supported for CPU for now"
+
+    scaling_factor = rope_scaling["factor"]
+    original_max_position = rope_scaling["original_max_position_embeddings"]
+    # assert max_position == original_max_position * scaling_factor
+    extra_kwargs = {
+        k: v
+        for k, v in rope_scaling.items()
+        if k
+        in (
+            "extrapolation_factor",
+            "attn_factor",
+            "beta_fast",
+            "beta_slow",
+            "mscale",
+            "mscale_all_dim",
+        )
+    }
+    extra_kwargs["device"] = device
+    rotary_emb = DeepseekScalingRotaryEmbedding(
+        head_size,
+        rotary_dim,
+        original_max_position,
+        base,
+        is_neox_style,
+        scaling_factor,
+        dtype,
+        **extra_kwargs,
+    )
+
+    _ROPE_DICT[key] = rotary_emb
+    return rotary_emb
+
+
+def get_rope_wrapper(
+    head_size: int,
+    rotary_dim: int,
+    max_position: int,
+    base: int,
+    is_neox_style: bool = True,
+    rope_scaling: Optional[Dict[str, Any]] = None,
+    dtype: Optional[torch.dtype] = None,
+    partial_rotary_factor: float = 1.0,
+    device: Optional[str] = None,
+):
+    if device != "cpu":
+        return get_rope(
+            head_size,
+            rotary_dim,
+            max_position,
+            base,
+            is_neox_style,
+            rope_scaling,
+            dtype,
+            partial_rotary_factor,
+        )
+
+    return get_rope_cpu(
+        head_size,
+        rotary_dim,
+        max_position,
+        base,
+        is_neox_style,
+        rope_scaling,
+        dtype,
+        partial_rotary_factor,
+        device,
+    )
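A hedged usage sketch of the new entry point: with a deepseek_yarn rope_scaling dict, get_rope_wrapper routes to get_rope_cpu on CPU and to vllm's get_rope everywhere else. Field values below are illustrative only (chosen so original_max_position * factor equals max_position, matching the commented assert):

import torch
from sglang.srt.layers.rotary_embedding import get_rope_wrapper

rope_scaling = {
    "rope_type": "deepseek_yarn",
    "factor": 40,
    "original_max_position_embeddings": 4096,
    "beta_fast": 32,
    "beta_slow": 1,
    "mscale": 1.0,
    "mscale_all_dim": 1.0,
}
rotary_emb = get_rope_wrapper(
    head_size=64,
    rotary_dim=64,
    max_position=4096 * 40,
    base=10000,
    is_neox_style=False,
    rope_scaling=rope_scaling,
    dtype=torch.float32,
    device="cpu",  # anything else falls through to vllm's get_rope
)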
python/sglang/srt/managers/schedule_batch.py

@@ -65,6 +65,7 @@ global_server_args_dict = {
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
+    "device": ServerArgs.device,
 }
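This one-line addition is what lets model code discover the configured device; for example (sketch):

from sglang.srt.managers.schedule_batch import global_server_args_dict

device = global_server_args_dict["device"]  # "cpu" when launched with --device cpu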
python/sglang/srt/managers/scheduler.py

@@ -317,6 +317,8 @@ class Scheduler:
         self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
         self.current_stream = torch.get_device_module(self.device).current_stream()
+        if self.device == "cpu":
+            self.current_stream.synchronize = lambda: None  # No-op for CPU

         # Session info
         self.sessions: Dict[str, Session] = {}
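Why the lambda works: torch.get_device_module("cpu") resolves to the torch.cpu module, and there is no device stream to drain on CPU, so replacing synchronize with a no-op keeps later self.current_stream.synchronize() call sites device-agnostic. Standalone illustration (assumes a torch recent enough to provide both APIs, as the commit itself does):

import torch

stream = torch.get_device_module("cpu").current_stream()
stream.synchronize = lambda: None  # same monkey-patch as the hunk above
stream.synchronize()               # now a harmless no-op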
python/sglang/srt/managers/tp_worker_overlap_thread.py

@@ -82,6 +82,8 @@ class TpModelWorkerClient:
         self.forward_thread.start()
         self.parent_process = psutil.Process().parent()
         self.scheduler_stream = torch.get_device_module(self.device).current_stream()
+        if self.device == "cpu":
+            self.scheduler_stream.synchronize = lambda: None  # No-op for CPU

     def get_worker_info(self):
         return self.worker.get_worker_info()
python/sglang/srt/model_executor/model_runner.py

@@ -106,6 +106,8 @@ class ModelRunner:
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
         ):
-            logger.info("MLA optimization is turned on. Use triton backend.")
-            self.server_args.attention_backend = "triton"
+            # TODO: add MLA optimization on CPU
+            if self.server_args.device != "cpu":
+                logger.info("MLA optimization is turned on. Use triton backend.")
+                self.server_args.attention_backend = "triton"

@@ -164,6 +166,7 @@ class ModelRunner:
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
+                "device": server_args.device,
             }
         )

@@ -221,6 +224,8 @@ class ModelRunner:
             backend = "gloo"
         elif self.device == "hpu":
             backend = "hccl"
+        elif self.device == "cpu":
+            backend = "gloo"

         if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)

@@ -269,6 +274,7 @@ class ModelRunner:
         )

         # This can reduce thread conflicts and speed up weight loading.
-        torch.set_num_threads(1)
+        if self.device != "cpu":
+            torch.set_num_threads(1)
         if self.device == "cuda":
             if torch.cuda.get_device_capability()[0] < 8:
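Taken together, the distributed process-group backend now resolves per device. A compact restatement (the helper is ours; "nccl" for CUDA is assumed from the surrounding code, which this hunk does not show):

def pick_dist_backend(device: str) -> str:
    if device == "cuda":
        return "nccl"  # assumption: pre-existing default, not visible in the hunk
    if device in ("xpu", "cpu"):
        return "gloo"
    if device == "hpu":
        return "hccl"
    raise ValueError(f"Not supported device type: {device}")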
python/sglang/srt/models/deepseek_v2.py

@@ -49,6 +49,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope_wrapper
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,

@@ -271,13 +272,14 @@ class DeepseekV2Attention(nn.Module):
             quant_config=quant_config,
         )
         rope_scaling["rope_type"] = "deepseek_yarn"
-        self.rotary_emb = get_rope(
+        self.rotary_emb = get_rope_wrapper(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
             is_neox_style=False,
+            device=global_server_args_dict["device"],
         )
         if rope_scaling:
python/sglang/srt/server_args.py

@@ -392,7 +392,7 @@ class ServerArgs:
         "--device",
         type=str,
         default="cuda",
-        choices=["cuda", "xpu", "hpu"],
+        choices=["cuda", "xpu", "hpu", "cpu"],
         help="The device type.",
     )
     parser.add_argument(
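With "cpu" accepted by the parser, a CPU run is launched like any other device, e.g. python -m sglang.launch_server --model-path <model-path> --device cpu (remaining flags unchanged; launcher module per sglang's usual launch instructions).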
python/sglang/srt/utils.py

@@ -223,6 +223,10 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
         free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info()
+    elif device == "cpu":
+        # TODO: rename the variables in the current function to be not GPU specific
+        free_gpu_memory = psutil.virtual_memory().available
+
     if distributed:
         tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
             torch.device(device, gpu_id)
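For the CPU branch, psutil.virtual_memory().available reports free system RAM in bytes, standing in for the GPU's free memory. Quick standalone check:

import psutil

free_bytes = psutil.virtual_memory().available
print(f"{free_bytes / (1 << 30):.1f} GiB of system RAM available")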