Commit eedc12e1 in sglang (unverified)
Authored by Liangsheng Yin on Jul 21, 2024; committed via GitHub on Jul 21, 2024
Parent: 5a4ef2b5

Support Deepseek MoE Model (#689)

Showing 5 changed files with 519 additions and 23 deletions (+519 -23)
python/sglang/srt/managers/controller/cuda_graph_runner.py   +38  -19
python/sglang/srt/managers/controller/model_runner.py         +4   -3
python/sglang/srt/models/deepseek.py (new file)             +430   -0
python/sglang/srt/server.py                                   +1   -1
python/sglang/srt/utils.py                                   +46   -0
python/sglang/srt/managers/controller/cuda_graph_runner.py (view file @ eedc12e1)

 """Run the model with cuda graph."""

 import bisect
+from contextlib import contextmanager

 import torch
 from flashinfer import BatchDecodeWithPagedKVCacheWrapper
...
@@ -15,9 +16,10 @@ from sglang.srt.managers.controller.infer_batch import (
     InputMetadata,
     init_flashinfer_args,
 )
+from sglang.srt.utils import monkey_patch_vllm_all_gather


-def _to_torch(model: torch.nn.Module, reverse=False):
+def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
...
@@ -28,13 +30,26 @@ def _to_torch(model: torch.nn.Module, reverse=False):
                 _to_torch(sub, reverse)


-def get_forward(model: torch.nn.Module, use_torch: bool):
-    if use_torch:
-        _to_torch(model, reverse=False)
-        return torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
-    else:
-        _to_torch(model, reverse=True)
-        return model.forward
+@contextmanager
+def patch_model(
+    model: torch.nn.Module, use_compile: bool, tp_group: "GroupCoordinator"
+):
+    backup_ca_comm = None
+
+    try:
+        if use_compile:
+            _to_torch(model)
+            monkey_patch_vllm_all_gather()
+            backup_ca_comm = tp_group.ca_comm
+            tp_group.ca_comm = None
+            yield torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
+        else:
+            yield model.forward
+    finally:
+        if use_compile:
+            _to_torch(model, reverse=True)
+            monkey_patch_vllm_all_gather(reverse=True)
+            tp_group.ca_comm = backup_ca_comm


 class CudaGraphRunner:
...
@@ -86,17 +101,21 @@ class CudaGraphRunner:
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
             for bs in batch_size_list:
-                forward = get_forward(self.model_runner.model, bs in self.compile_bs)
-                (
-                    graph,
-                    input_buffers,
-                    output_buffers,
-                    flashinfer_handler,
-                ) = self.capture_one_batch_size(bs, forward)
-                self.graphs[bs] = graph
-                self.input_buffers[bs] = input_buffers
-                self.output_buffers[bs] = output_buffers
-                self.flashinfer_handlers[bs] = flashinfer_handler
+                with patch_model(
+                    self.model_runner.model,
+                    bs in self.compile_bs,
+                    self.model_runner.tp_group,
+                ) as forward:
+                    (
+                        graph,
+                        input_buffers,
+                        output_buffers,
+                        flashinfer_handler,
+                    ) = self.capture_one_batch_size(bs, forward)
+                    self.graphs[bs] = graph
+                    self.input_buffers[bs] = input_buffers
+                    self.output_buffers[bs] = output_buffers
+                    self.flashinfer_handlers[bs] = flashinfer_handler

     def capture_one_batch_size(self, bs, forward):
         graph = torch.cuda.CUDAGraph()
...
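The new patch_model helper follows the standard contextmanager recipe: the temporary patches (_to_torch, the all-gather monkey patch, clearing tp_group.ca_comm) are applied before yield and undone in the finally block, so the model is restored even if capture raises. A minimal, self-contained sketch of the same back-up-and-restore idiom, with hypothetical arguments and not part of this commit:

from contextlib import contextmanager


@contextmanager
def patched_attr(obj, name, value):
    """Temporarily replace obj.<name> with value, restoring the original on exit."""
    backup = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield obj
    finally:
        setattr(obj, name, backup)


# Usage sketch: disable the custom-allreduce communicator only during capture,
# the same way patch_model backs up and restores tp_group.ca_comm.
# with patched_attr(tp_group, "ca_comm", None):
#     ...  # capture while the attribute is patched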
python/sglang/srt/managers/controller/model_runner.py (view file @ eedc12e1)

...
@@ -22,7 +22,6 @@ from vllm.distributed import (
     init_distributed_environment,
     initialize_model_parallel,
 )
-from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

 from sglang.global_config import global_config
...
@@ -241,7 +240,9 @@ class ModelRunner:
             self.cuda_graph_runner = None
             return

-        logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
+        )
         batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
         self.cuda_graph_runner = CudaGraphRunner(
             self,
...
@@ -252,7 +253,7 @@ class ModelRunner:
             self.cuda_graph_runner.capture(batch_size_list)
         except RuntimeError as e:
             raise Exception(
-                f"Capture cuda graph failed {e}. Possible solutions:\n"
+                f"Capture cuda graph failed: {e}. Possible solutions:\n"
                 f"1. disable cuda graph by --disable-cuda-graph\n"
                 f"2. set --mem-fraction-static to a smaller value\n"
                 f"Open an issue on GitHub with reproducible scripts if you need help.\n"
...
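For reference, the capture list above expands to 1, 2, 4 and every multiple of 8 up to 128, i.e. 19 graph sizes in total. A quick, purely illustrative check:

batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
print(batch_size_list)
# [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128]
print(len(batch_size_list))  # 19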
python/sglang/srt/models/deepseek.py (new file, 0 → 100644, view file @ eedc12e1)

# Adapted from:
# https://github.com/vllm-project/vllm/blob/14f91fe67c2342f2fe859dc6a5c40810df0e1c61/vllm/model_executor/models/deepseek.py
"""Inference-only Deepseek model."""
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.config import CacheConfig
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.infer_batch import InputMetadata


class DeepseekMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
class DeepseekMoE(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.n_routed_experts = config.n_routed_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.n_routed_experts}."
            )

        self.experts = nn.ModuleList(
            [
                DeepseekMLP(
                    hidden_size=config.hidden_size,
                    intermediate_size=config.moe_intermediate_size,
                    hidden_act=config.hidden_act,
                    quant_config=quant_config,
                    reduce_results=False,
                )
                for idx in range(self.n_routed_experts)
            ]
        )
        self.pack_params()

        self.gate = ReplicatedLinear(
            config.hidden_size, self.n_routed_experts, bias=False, quant_config=None
        )

        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
            )

    def pack_params(self):
        w1 = []
        w2 = []
        for expert in self.experts:
            w1.append(expert.gate_up_proj.weight)
            w2.append(expert.down_proj.weight)
        self.w1 = torch._utils._flatten_dense_tensors(w1)
        w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
        for data, param in zip(w1s, w1):
            param.data = data
        self.w1 = self.w1.view(len(w1), *w1s[0].shape)

        self.w2 = torch._utils._flatten_dense_tensors(w2)
        w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
        for data, param in zip(w2s, w2):
            param.data = data

        self.w2 = self.w2.view(len(w2), *w2s[0].shape)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        if self.config.n_shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(
            hidden_states,
            self.w1,
            self.w2,
            router_logits,
            self.top_k,
            renormalize=self.config.norm_topk_prob,
            inplace=True,
        )

        if self.config.n_shared_experts is not None:
            final_hidden_states = final_hidden_states + shared_output
        final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_dim)
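A note on the class above: pack_params flattens the per-expert gate_up_proj and down_proj weights into two contiguous (n_experts, ...) tensors (w1 and w2) that vLLM's fused_moe Triton kernel expects, while re-pointing each expert's parameter at a view into the packed buffer so weight loading still works per expert. As a mental model only (not part of deepseek.py, and far slower than the fused kernel), the routing fused_moe performs is roughly equivalent to this dense sketch; renormalize corresponds to config.norm_topk_prob:

import torch


def moe_reference(
    hidden_states: torch.Tensor,  # (num_tokens, hidden_dim)
    router_logits: torch.Tensor,  # (num_tokens, n_experts)
    experts: list,                # per-expert callables: (hidden_dim,) -> (hidden_dim,)
    top_k: int,
    renormalize: bool,
) -> torch.Tensor:
    # Softmax over experts, then keep the top-k routing weights per token.
    weights = torch.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(weights, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    # Weighted sum of the selected experts' outputs, token by token.
    out = torch.zeros_like(hidden_states)
    for t in range(hidden_states.shape[0]):
        for w, eid in zip(topk_weights[t], topk_ids[t]):
            out[t] += w.to(hidden_states.dtype) * experts[eid](hidden_states[t])
    return out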
class DeepseekAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output
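To make the head bookkeeping above concrete, here is the arithmetic for a hypothetical configuration (illustrative values only, not DeepSeek's actual config): hidden_size 2048, 16 attention heads, 16 KV heads, tensor parallel size 2.

hidden_size, total_num_heads, total_num_kv_heads, tp_size = 2048, 16, 16, 2

head_dim = hidden_size // total_num_heads             # 128
num_heads = total_num_heads // tp_size                # 8 query heads per rank
num_kv_heads = max(1, total_num_kv_heads // tp_size)  # 8 KV heads per rank
q_size = num_heads * head_dim                         # 1024
kv_size = num_kv_heads * head_dim                     # 1024
scaling = head_dim**-0.5                              # ~0.088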
class DeepseekDecoderLayer(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.self_attn = DeepseekAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        if (
            config.n_routed_experts is not None
            and layer_id >= config.first_k_dense_replace
            and layer_id % config.moe_layer_freq == 0
        ):
            self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
        else:
            self.mlp = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            input_metadata=input_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
class DeepseekModel(nn.Module):
    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [
                DeepseekDecoderLayer(
                    config, layer_id, cache_config, quant_config=quant_config
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions, hidden_states, input_metadata, residual
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class DeepseekForCausalLM(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = DeepseekModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(
            config.vocab_size, config.hidden_size, quant_config=quant_config
        )
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, input_metadata)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip experts that are not assigned to this worker.
                if (
                    "mlp.experts." in name or "mlp.shared_experts." in name
                ) and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip experts that are not assigned to this worker.
                if (
                    "mlp.experts." in name or "mlp.shared_experts." in name
                ) and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


EntryClass = DeepseekForCausalLM
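The stacked_params_mapping in load_weights renames checkpoint weights for the separate q/k/v and gate/up projections onto the fused qkv_proj and gate_up_proj parameters, passing the shard id so the parallel weight loader writes each piece into the right slice. A hypothetical name walk-through (illustrative, not part of the commit):

stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

checkpoint_name = "model.layers.0.self_attn.k_proj.weight"  # assumed example name
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in checkpoint_name:
        target = checkpoint_name.replace(weight_name, param_name)
        print(target, shard_id)
        # model.layers.0.self_attn.qkv_proj.weight k
        break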
python/sglang/srt/server.py (view file @ eedc12e1)

...
@@ -167,7 +167,7 @@ def _set_torch_compile_config():
     torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future

     # FIXME: tmp workaround
-    torch._dynamo.config.accumulated_cache_size_limit = 128
+    torch._dynamo.config.accumulated_cache_size_limit = 256


 def launch_server(
...
python/sglang/srt/utils.py (view file @ eedc12e1)

...
@@ -411,6 +411,52 @@ def monkey_patch_vllm_dummy_weight_loader():
     setattr(DummyModelLoader, "load_model", load_model)


+vllm_all_gather_backup = None
+
+
+def monkey_patch_vllm_all_gather(reverse: bool = False):
+    """Monkey patch all-gather to remove in-place operations."""
+    from torch.distributed import _functional_collectives as funcol
+    from vllm.distributed.parallel_state import GroupCoordinator
+
+    global vllm_all_gather_backup
+    if vllm_all_gather_backup is None:
+        vllm_all_gather_backup = GroupCoordinator.all_gather
+
+    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        world_size = self.world_size
+        # Bypass the function if we are using only 1 GPU.
+        if world_size == 1:
+            return input_
+        assert (
+            -input_.dim() <= dim < input_.dim()
+        ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += input_.dim()
+        input_size = input_.size()
+        # Allocate output tensor.
+        output_tensor = torch.empty(
+            (world_size,) + input_size, dtype=input_.dtype, device=input_.device
+        )
+
+        output_tensor = funcol.all_gather_tensor(
+            input_, gather_dim=0, group=self.device_group
+        ).view((world_size,) + input_size)
+
+        # Reshape
+        output_tensor = output_tensor.movedim(0, dim)
+        output_tensor = output_tensor.reshape(
+            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
+        )
+        return output_tensor
+
+    if reverse:
+        setattr(GroupCoordinator, "all_gather", vllm_all_gather_backup)
+    else:
+        setattr(GroupCoordinator, "all_gather", all_gather)
+
+
 API_KEY_HEADER_NAME = "X-API-Key"
...
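The replacement all_gather builds its result with torch.distributed._functional_collectives rather than the in-place collective the docstring refers to, which is presumably what lets the tensor-parallel path be traced when torch.compile is enabled in cuda_graph_runner.py. The module-level backup lets the patch be undone with reverse=True; the surrounding bookkeeping is the usual monkey-patch idiom, sketched here on a hypothetical class (not part of this commit):

class Greeter:
    def greet(self, name: str) -> str:
        return f"hello, {name}"


_greet_backup = None


def monkey_patch_greeter(reverse: bool = False):
    """Swap Greeter.greet for a patched version, or restore the original."""
    global _greet_backup
    if _greet_backup is None:
        _greet_backup = Greeter.greet

    def greet(self, name: str) -> str:
        return f"[patched] hello, {name}"

    setattr(Greeter, "greet", _greet_backup if reverse else greet)


# monkey_patch_greeter()              -> Greeter().greet("x") == "[patched] hello, x"
# monkey_patch_greeter(reverse=True)  -> original behavior restored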