Commit 11383cec (unverified)
Authored Apr 30, 2025 by Ying Sheng; committed by GitHub on Apr 30, 2025

[PP] Add pipeline parallelism (#5724)

Parent: e97e57e6
Showing 20 changed files with 894 additions and 303 deletions (+894, -303).
python/sglang/bench_one_batch.py                                  +2    -0
python/sglang/srt/entrypoints/engine.py                           +36   -19
python/sglang/srt/layers/dp_attention.py                          +5    -2
python/sglang/srt/layers/utils.py                                 +35   -0
python/sglang/srt/managers/data_parallel_controller.py            +52   -34
python/sglang/srt/managers/schedule_batch.py                      +25   -15
python/sglang/srt/managers/scheduler.py                           +262  -59
python/sglang/srt/managers/scheduler_output_processor_mixin.py    +1    -1
python/sglang/srt/managers/tp_worker.py                           +50   -16
python/sglang/srt/managers/tp_worker_overlap_thread.py            +9    -3
python/sglang/srt/mem_cache/memory_pool.py                        +70   -36
python/sglang/srt/model_executor/cuda_graph_runner.py             +67   -18
python/sglang/srt/model_executor/forward_batch_info.py            +31   -1
python/sglang/srt/model_executor/model_runner.py                  +101  -54
python/sglang/srt/models/llama.py                                 +92   -30
python/sglang/srt/models/llama4.py                                +2    -1
python/sglang/srt/models/llama_eagle.py                           +4    -1
python/sglang/srt/models/llama_eagle3.py                          +4    -1
python/sglang/srt/server_args.py                                  +43   -10
python/sglang/srt/speculative/eagle_worker.py                     +3    -2
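The changes below thread a new pipeline-parallel degree (server_args.pp_size) and a per-process pp_rank through the launch path, the TP workers, the model runner, the KV-cache pools, and the CUDA-graph runner. As a rough usage sketch (an assumption based on the new ServerArgs field, not something documented in this commit), pipeline parallelism would be enabled by passing pp_size alongside tp_size when creating the engine:

# Hedged usage sketch (not from this commit's docs): `pp_size` is the new
# ServerArgs field added here; the model path and sizes are only examples.
import sglang as sgl

llm = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # example model
    tp_size=2,   # tensor-parallel degree within each pipeline stage
    pp_size=2,   # assumed: number of pipeline stages introduced by this commit
)
print(llm.generate("The capital of France is"))

With pp_size=2, the launch code below starts one scheduler process per (pp_rank, tp_rank) pair, and each pipeline stage holds only its own slice of the transformer layers (see the start_layer/end_layer and PPMissingLayer changes further down).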
python/sglang/bench_one_batch.py

@@ -154,6 +154,8 @@ def load_model(server_args, port_args, tp_rank):
         gpu_id=tp_rank,
         tp_rank=tp_rank,
         tp_size=server_args.tp_size,
+        pp_rank=0,
+        pp_size=1,
         nccl_port=port_args.nccl_port,
         server_args=server_args,
     )
python/sglang/srt/entrypoints/engine.py

@@ -126,7 +126,6 @@ class Engine(EngineBase):
             server_args=server_args,
             port_args=port_args,
         )
         self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
         self.scheduler_info = scheduler_info

@@ -301,7 +300,6 @@ class Engine(EngineBase):
         internal_states = loop.run_until_complete(
             self.tokenizer_manager.get_internal_state()
         )
         return {
             **dataclasses.asdict(self.tokenizer_manager.server_args),
             **self.scheduler_info,

@@ -520,25 +518,44 @@ def _launch_subprocesses(
         )
         scheduler_pipe_readers = []
-        tp_size_per_node = server_args.tp_size // server_args.nnodes
+
+        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
+        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
         tp_rank_range = range(
-            tp_size_per_node * server_args.node_rank,
-            tp_size_per_node * (server_args.node_rank + 1),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
         )
-        for tp_rank in tp_rank_range:
-            reader, writer = mp.Pipe(duplex=False)
-            gpu_id = (
-                server_args.base_gpu_id
-                + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
-            )
-            proc = mp.Process(
-                target=run_scheduler_process,
-                args=(server_args, port_args, gpu_id, tp_rank, None, writer),
-            )
-            with memory_saver_adapter.configure_subprocess():
-                proc.start()
-            scheduler_procs.append(proc)
-            scheduler_pipe_readers.append(reader)
+
+        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
+        pp_rank_range = range(
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
+        )
+
+        for pp_rank in pp_rank_range:
+            for tp_rank in tp_rank_range:
+                reader, writer = mp.Pipe(duplex=False)
+                gpu_id = (
+                    server_args.base_gpu_id
+                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
+                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
+                )
+                proc = mp.Process(
+                    target=run_scheduler_process,
+                    args=(
+                        server_args,
+                        port_args,
+                        gpu_id,
+                        tp_rank,
+                        pp_rank,
+                        None,
+                        writer,
+                    ),
+                )
+                with memory_saver_adapter.configure_subprocess():
+                    proc.start()
+                scheduler_procs.append(proc)
+                scheduler_pipe_readers.append(reader)
     else:
         # Launch the data parallel controller
         reader, writer = mp.Pipe(duplex=False)
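To make the new launch arithmetic concrete, here is a small self-contained sketch (my own illustration with made-up sizes, mirroring the formulas in the hunk above) of which (pp_rank, tp_rank) pairs and GPU ids a given node launches:

# Standalone sketch of the scheduler-launch arithmetic above, using
# made-up sizes (not part of the commit). With nnodes=2, tp_size=4,
# pp_size=2, each node hosts one full TP group for one pipeline stage.
def ranks_for_node(node_rank, nnodes=2, tp_size=4, pp_size=2,
                   base_gpu_id=0, gpu_id_step=1):
    nnodes_per_tp_group = max(nnodes // pp_size, 1)
    tp_size_per_node = tp_size // nnodes_per_tp_group
    tp_rank_range = range(
        tp_size_per_node * (node_rank % nnodes_per_tp_group),
        tp_size_per_node * (node_rank % nnodes_per_tp_group + 1),
    )
    pp_size_per_node = max(pp_size // nnodes, 1)
    pp_rank_range = range(
        pp_size_per_node * (node_rank // nnodes_per_tp_group),
        pp_size_per_node * (node_rank // nnodes_per_tp_group + 1),
    )
    out = []
    for pp_rank in pp_rank_range:
        for tp_rank in tp_rank_range:
            gpu_id = (
                base_gpu_id
                + ((pp_rank % pp_size_per_node) * tp_size_per_node)
                + (tp_rank % tp_size_per_node) * gpu_id_step
            )
            out.append((pp_rank, tp_rank, gpu_id))
    return out

print(ranks_for_node(0))  # [(0, 0, 0), (0, 1, 1), (0, 2, 2), (0, 3, 3)]
print(ranks_for_node(1))  # [(1, 0, 0), (1, 1, 1), (1, 2, 2), (1, 3, 3)]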
python/sglang/srt/layers/dp_attention.py

@@ -43,6 +43,7 @@ def initialize_dp_attention(
     tp_rank: int,
     tp_size: int,
     dp_size: int,
+    pp_size: int,
 ):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE

@@ -53,17 +54,19 @@ def initialize_dp_attention(
     )

     if enable_dp_attention:
+        local_rank = tp_rank % (tp_size // dp_size)
         _DP_SIZE = dp_size
     else:
+        local_rank = tp_rank
         _DP_SIZE = 1

     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
         [
             list(range(head, head + _ATTN_TP_SIZE))
-            for head in range(0, tp_size, _ATTN_TP_SIZE)
+            for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE)
         ],
-        tp_group.local_rank,
+        local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,
python/sglang/srt/layers/utils.py (new file, mode 100644)

import logging
import re

import torch

logger = logging.getLogger(__name__)


def get_layer_id(weight_name):
    # example weight name: model.layers.10.self_attn.qkv_proj.weight
    match = re.search(r"layers\.(\d+)\.", weight_name)
    if match:
        return int(match.group(1))
    return None


class PPMissingLayer(torch.nn.Identity):
    # Adapted from
    # https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
    """
    A placeholder layer for missing layers in a pipeline parallel model.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.return_tuple = kwargs.get("return_tuple", False)

    def forward(self, *args, **kwargs):
        """
        Return the first arg from args or the first value from kwargs.

        Wraps the input in a tuple if `self.return_tuple` is True.
        """
        input = args[0] if args else next(iter(kwargs.values()))
        return (input,) if self.return_tuple else input
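A short illustration of how the two helpers in this new module behave (the weight names, flag, and inputs below are examples, not part of the diff):

# Illustration only: get_layer_id extracts the layer index from a checkpoint
# weight name, which lets a pipeline stage skip weights for layers it does
# not own; PPMissingLayer stands in for layers owned by other stages and
# simply passes its input through (optionally wrapped in a tuple).
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id

assert get_layer_id("model.layers.10.self_attn.qkv_proj.weight") == 10
assert get_layer_id("model.embed_tokens.weight") is None  # no layer index

layer = PPMissingLayer(return_tuple=True)
hidden = object()  # stands in for a hidden-states tensor
assert layer(hidden) == (hidden,)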
python/sglang/srt/managers/data_parallel_controller.py

@@ -181,44 +181,62 @@ class DataParallelController:
             enable=server_args.enable_memory_saver
         )

         # Launch tensor parallel scheduler processes
         scheduler_pipe_readers = []
-        tp_size_per_node = server_args.tp_size // server_args.nnodes
+
+        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
+        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
         tp_rank_range = range(
-            tp_size_per_node * server_args.node_rank,
-            tp_size_per_node * (server_args.node_rank + 1),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
         )
-        for tp_rank in tp_rank_range:
-            rank_port_args = port_args
-
-            if server_args.enable_dp_attention:
-                # dp attention has different sharding logic
-                _, _, dp_rank = compute_dp_attention_world_info(
-                    server_args.enable_dp_attention,
-                    tp_rank,
-                    server_args.tp_size,
-                    server_args.dp_size,
-                )
-                # compute zmq ports for this dp rank
-                rank_port_args = PortArgs.init_new(server_args, dp_rank)
-                # Data parallelism resues the tensor parallelism group,
-                # so all dp ranks should use the same nccl port.
-                rank_port_args.nccl_port = port_args.nccl_port
-
-            reader, writer = mp.Pipe(duplex=False)
-            gpu_id = (
-                server_args.base_gpu_id
-                + base_gpu_id
-                + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
-            )
-            proc = mp.Process(
-                target=run_scheduler_process,
-                args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
-            )
-            with memory_saver_adapter.configure_subprocess():
-                proc.start()
-
-            self.scheduler_procs.append(proc)
-            scheduler_pipe_readers.append(reader)
+
+        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
+        pp_rank_range = range(
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
+        )
+
+        for pp_rank in pp_rank_range:
+            for tp_rank in tp_rank_range:
+                rank_port_args = port_args
+
+                if server_args.enable_dp_attention:
+                    # dp attention has different sharding logic
+                    _, _, dp_rank = compute_dp_attention_world_info(
+                        server_args.enable_dp_attention,
+                        tp_rank,
+                        server_args.tp_size,
+                        server_args.dp_size,
+                    )
+                    # compute zmq ports for this dp rank
+                    rank_port_args = PortArgs.init_new(server_args, dp_rank)
+                    # Data parallelism resues the tensor parallelism group,
+                    # so all dp ranks should use the same nccl port.
+                    rank_port_args.nccl_port = port_args.nccl_port
+
+                reader, writer = mp.Pipe(duplex=False)
+                gpu_id = (
+                    server_args.base_gpu_id
+                    + base_gpu_id
+                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
+                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
+                )
+                proc = mp.Process(
+                    target=run_scheduler_process,
+                    args=(
+                        server_args,
+                        rank_port_args,
+                        gpu_id,
+                        tp_rank,
+                        pp_rank,
+                        dp_rank,
+                        writer,
+                    ),
+                )
+                with memory_saver_adapter.configure_subprocess():
+                    proc.start()
+
+                self.scheduler_procs.append(proc)
+                scheduler_pipe_readers.append(reader)

         # Wait for model to finish loading
         scheduler_info = []
python/sglang/srt/managers/schedule_batch.py

@@ -66,23 +66,24 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 # Put some global args for easy access
 global_server_args_dict = {
     "attention_backend": ServerArgs.attention_backend,
-    "sampling_backend": ServerArgs.sampling_backend,
-    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
-    "torchao_config": ServerArgs.torchao_config,
-    "enable_nan_detection": ServerArgs.enable_nan_detection,
-    "enable_dp_attention": ServerArgs.enable_dp_attention,
-    "enable_ep_moe": ServerArgs.enable_ep_moe,
-    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
-    "deepep_mode": ServerArgs.deepep_mode,
-    "device": ServerArgs.device,
-    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
-    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
-    "disable_radix_cache": ServerArgs.disable_radix_cache,
-    "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
-    "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
-    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
-    "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
-    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
+    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
+    "deepep_mode": ServerArgs.deepep_mode,
+    "device": ServerArgs.device,
+    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
+    "disable_radix_cache": ServerArgs.disable_radix_cache,
+    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
+    "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_ep_moe": ServerArgs.enable_ep_moe,
+    "enable_nan_detection": ServerArgs.enable_nan_detection,
+    "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
+    "max_micro_batch_size": ServerArgs.max_micro_batch_size,
+    "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
+    "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
+    "sampling_backend": ServerArgs.sampling_backend,
+    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
+    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
+    "torchao_config": ServerArgs.torchao_config,
+    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
 }

 logger = logging.getLogger(__name__)

@@ -728,6 +729,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Events
     launch_done: Optional[threading.Event] = None

+    # For chunked prefill in PP
+    chunked_req: Optional[Req] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None

@@ -761,7 +765,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # For extend and mixed chunekd prefill
     prefix_lens: List[int] = None
     extend_lens: List[int] = None
-    extend_num_tokens: int = None
+    extend_num_tokens: Optional[int] = None
     decoding_reqs: List[Req] = None
     extend_logprob_start_lens: List[int] = None
     # It comes empty list if logprob is not required.

@@ -803,6 +807,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
         enable_custom_logit_processor: bool,
+        chunked_req: Optional[Req] = None,
     ):
         return_logprob = any(req.return_logprob for req in reqs)

@@ -820,6 +825,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             spec_algorithm=spec_algorithm,
             enable_custom_logit_processor=enable_custom_logit_processor,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
+            chunked_req=chunked_req,
         )

     def batch_size(self):

@@ -1236,7 +1242,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def retract_decode(self, server_args: ServerArgs):
         """Retract the decoding requests when there is not enough memory."""
-        sorted_indices = [i for i in range(len(self.reqs))]
+        sorted_indices = list(range(len(self.reqs)))

         # TODO(lsyin): improve retraction policy for radix cache
         # For spec decoding, filter_batch API can only filter

@@ -1413,15 +1419,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def filter_batch(
         self,
-        chunked_req_to_exclude: Optional[Req] = None,
+        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
+            if isinstance(chunked_req_to_exclude, Req):
+                chunked_req_to_exclude = [chunked_req_to_exclude]
+            elif chunked_req_to_exclude is None:
+                chunked_req_to_exclude = []
             keep_indices = [
                 i
                 for i in range(len(self.reqs))
                 if not self.reqs[i].finished()
-                and self.reqs[i] is not chunked_req_to_exclude
+                and not self.reqs[i] in chunked_req_to_exclude
             ]

         if keep_indices is None or len(keep_indices) == 0:
python/sglang/srt/managers/scheduler.py (+262, -59)

(This diff is collapsed in the original view and is not reproduced here.)
python/sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -278,7 +278,7 @@ class SchedulerOutputProcessorMixin:
                 self.attn_tp_rank == 0
                 and self.forward_ct_decode % self.server_args.decode_log_interval == 0
             ):
-                self.log_decode_stats()
+                self.log_decode_stats(running_batch=batch)

     def add_input_logprob_return_values(
         self: Scheduler,
python/sglang/srt/managers/tp_worker.py

@@ -15,11 +15,12 @@
 import logging
 import threading
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch

 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (

@@ -31,7 +32,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed

@@ -47,6 +48,7 @@ class TpModelWorker:
         server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
+        pp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
         is_draft_worker: bool = False,

@@ -54,7 +56,9 @@ class TpModelWorker:
         token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
+        self.tp_size = server_args.tp_size
         self.tp_rank = tp_rank
+        self.pp_rank = pp_rank

         # Init model and tokenizer
         self.model_config = ModelConfig(

@@ -73,12 +77,15 @@ class TpModelWorker:
             quantization=server_args.quantization,
             is_draft_model=is_draft_worker,
         )
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
+            pp_rank=pp_rank,
+            pp_size=server_args.pp_size,
             nccl_port=nccl_port,
             server_args=server_args,
             is_draft_worker=is_draft_worker,

@@ -105,6 +112,10 @@ class TpModelWorker:
         )
         self.device = self.model_runner.device

+        # Init nccl groups
+        self.pp_group = get_pp_group()
+        self.world_group = get_world_group()
+
         # Profile number of tokens
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = server_args.max_prefill_tokens

@@ -130,8 +141,9 @@ class TpModelWorker:
         # Sync random seed across TP workers
         self.random_seed = broadcast_pyobj(
             [server_args.random_seed],
-            self.tp_rank,
-            self.model_runner.tp_group.cpu_group,
+            self.tp_size * self.pp_rank + tp_rank,
+            self.world_group.cpu_group,
+            src=self.world_group.ranks[0],
         )[0]
         set_random_seed(self.random_seed)

@@ -156,11 +168,14 @@ class TpModelWorker:
     def get_pad_input_ids_func(self):
         return getattr(self.model_runner.model, "pad_input_ids", None)

-    def get_tp_cpu_group(self):
-        return self.model_runner.tp_group.cpu_group
+    def get_tp_group(self):
+        return self.model_runner.tp_group
+
+    def get_attention_tp_group(self):
+        return self.model_runner.attention_tp_group

     def get_attention_tp_cpu_group(self):
-        return self.model_runner.attention_tp_group.cpu_group
+        return getattr(self.model_runner.attention_tp_group, "cpu_group", None)

     def get_memory_pool(self):
         return (

@@ -172,19 +187,38 @@ class TpModelWorker:
         self,
         model_worker_batch: ModelWorkerBatch,
         skip_sample: bool = False,
-    ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
+    ) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor]]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
-        if model_worker_batch.launch_done is not None:
-            model_worker_batch.launch_done.set()

-        if skip_sample:
-            next_token_ids = None
-        else:
-            next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+        pp_proxy_tensors = None
+        if not self.pp_group.is_first_rank:
+            pp_proxy_tensors = PPProxyTensors(
+                self.pp_group.recv_tensor_dict(
+                    all_gather_group=self.get_attention_tp_group()
+                )
+            )

-        return logits_output, next_token_ids
+        if self.pp_group.is_last_rank:
+            logits_output = self.model_runner.forward(
+                forward_batch, pp_proxy_tensors=pp_proxy_tensors
+            )
+            if model_worker_batch.launch_done is not None:
+                model_worker_batch.launch_done.set()
+
+            if skip_sample:
+                next_token_ids = None
+            else:
+                next_token_ids = self.model_runner.sample(
+                    logits_output, model_worker_batch
+                )
+
+            return logits_output, next_token_ids
+        else:
+            pp_proxy_tensors = self.model_runner.forward(
+                forward_batch,
+                pp_proxy_tensors=pp_proxy_tensors,
+            )
+            return pp_proxy_tensors.tensors, None

     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
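Schematically, the new forward_batch_generation implements the usual point-to-point pipeline hand-off; the sketch below is a simplified restatement (not the actual sglang code) of the branch structure above:

# Simplified schematic of the per-batch pipeline flow above: every stage
# except the first receives proxy tensors from its predecessor, and every
# stage except the last returns proxy tensors instead of logits.
def run_stage(pp_group, model_runner, forward_batch, attn_tp_group):
    pp_proxy_tensors = None
    if not pp_group.is_first_rank:
        # hidden_states / residual produced by the previous pipeline stage
        pp_proxy_tensors = PPProxyTensors(
            pp_group.recv_tensor_dict(all_gather_group=attn_tp_group)
        )
    out = model_runner.forward(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
    if pp_group.is_last_rank:
        return out          # LogitsProcessorOutput -> sample next tokens here
    return out.tensors      # dict of tensors handed to the next stage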
python/sglang/srt/managers/tp_worker_overlap_thread.py

@@ -56,11 +56,14 @@ class TpModelWorkerClient:
         server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
+        pp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
     ):
         # Load the model
-        self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
+        self.worker = TpModelWorker(
+            server_args, gpu_id, tp_rank, pp_rank, dp_rank, nccl_port
+        )
         self.max_running_requests = self.worker.max_running_requests
         self.device = self.worker.device
         self.gpu_id = gpu_id

@@ -91,8 +94,11 @@ class TpModelWorkerClient:
     def get_pad_input_ids_func(self):
         return self.worker.get_pad_input_ids_func()

-    def get_tp_cpu_group(self):
-        return self.worker.get_tp_cpu_group()
+    def get_tp_group(self):
+        return self.worker.get_tp_group()
+
+    def get_attention_tp_group(self):
+        return self.worker.get_attention_tp_group()

     def get_attention_tp_cpu_group(self):
         return self.worker.get_attention_tp_cpu_group()
python/sglang/srt/mem_cache/memory_pool.py

@@ -214,6 +214,8 @@ class MHATokenToKVPool(KVCache):
         layer_num: int,
         device: str,
         enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
     ):
         self.size = size
         self.page_size = page_size

@@ -232,6 +234,8 @@ class MHATokenToKVPool(KVCache):
         self.head_dim = head_dim
         self.layer_num = layer_num
         self._create_buffers()
+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1

         self.layer_transfer_counter = None
         self.capture_mode = False

@@ -281,6 +285,8 @@ class MHATokenToKVPool(KVCache):
     # for disagg
     def get_contiguous_buf_infos(self):
-        # layer_num x [seq_len, head_num, head_dim]
+        # layer_num x [page_num, page_size, head_num, head_dim]
         kv_data_ptrs = [
             self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
         ] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]

@@ -320,24 +326,24 @@ class MHATokenToKVPool(KVCache):
         # transfer prepared data from host to device
         flat_data = flat_data.to(device=self.device, non_blocking=False)
         k_data, v_data = flat_data[0], flat_data[1]
-        self.k_buffer[layer_id][indices] = k_data
-        self.v_buffer[layer_id][indices] = v_data
+        self.k_buffer[layer_id - self.start_layer][indices] = k_data
+        self.v_buffer[layer_id - self.start_layer][indices] = v_data

     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id)
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

         if self.store_dtype != self.dtype:
-            return self.k_buffer[layer_id].view(self.dtype)
-        return self.k_buffer[layer_id]
+            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.k_buffer[layer_id - self.start_layer]

     def get_value_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id)
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

         if self.store_dtype != self.dtype:
-            return self.v_buffer[layer_id].view(self.dtype)
-        return self.v_buffer[layer_id]
+            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.v_buffer[layer_id - self.start_layer]

     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

@@ -369,12 +375,12 @@ class MHATokenToKVPool(KVCache):
             current_stream = self.device_module.current_stream()
             self.alt_stream.wait_stream(current_stream)
             with self.device_module.stream(self.alt_stream):
-                self.k_buffer[layer_id][loc] = cache_k
-                self.v_buffer[layer_id][loc] = cache_v
+                self.k_buffer[layer_id - self.start_layer][loc] = cache_k
+                self.v_buffer[layer_id - self.start_layer][loc] = cache_v
             current_stream.wait_stream(self.alt_stream)
         else:
-            self.k_buffer[layer_id][loc] = cache_k
-            self.v_buffer[layer_id][loc] = cache_v
+            self.k_buffer[layer_id - self.start_layer][loc] = cache_k
+            self.v_buffer[layer_id - self.start_layer][loc] = cache_v


 @torch.compile

@@ -484,6 +490,8 @@ class MLATokenToKVPool(KVCache):
         layer_num: int,
         device: str,
         enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
     ):
         self.size = size
         self.page_size = page_size

@@ -497,6 +505,8 @@ class MLATokenToKVPool(KVCache):
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
         self.layer_num = layer_num
+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1

         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver

@@ -540,19 +550,21 @@ class MLATokenToKVPool(KVCache):
     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id)
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

         if self.store_dtype != self.dtype:
-            return self.kv_buffer[layer_id].view(self.dtype)
-        return self.kv_buffer[layer_id]
+            return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.kv_buffer[layer_id - self.start_layer]

     def get_value_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id)
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

         if self.store_dtype != self.dtype:
-            return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
-        return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
+            return self.kv_buffer[layer_id - self.start_layer][
+                ..., : self.kv_lora_rank
+            ].view(self.dtype)
+        return self.kv_buffer[layer_id - self.start_layer][..., : self.kv_lora_rank]

     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

@@ -568,9 +580,11 @@ class MLATokenToKVPool(KVCache):
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
         if self.store_dtype != self.dtype:
-            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k.view(
+                self.store_dtype
+            )
         else:
-            self.kv_buffer[layer_id][loc] = cache_k
+            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k

     def set_mla_kv_buffer(
         self,

@@ -605,7 +619,7 @@ class MLATokenToKVPool(KVCache):
     def transfer_per_layer(self, indices, flat_data, layer_id):
         # transfer prepared data from host to device
         flat_data = flat_data.to(device=self.device, non_blocking=False)
-        self.kv_buffer[layer_id][indices] = flat_data
+        self.kv_buffer[layer_id - self.start_layer][indices] = flat_data


 class DoubleSparseTokenToKVPool(KVCache):

@@ -620,6 +634,8 @@ class DoubleSparseTokenToKVPool(KVCache):
         device: str,
         heavy_channel_num: int,
         enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
     ):
         self.size = size
         self.page_size = page_size

@@ -657,17 +673,23 @@ class DoubleSparseTokenToKVPool(KVCache):
             for _ in range(layer_num)
         ]

+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1
+
     def get_key_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id]
+        return self.k_buffer[layer_id - self.start_layer]

     def get_value_buffer(self, layer_id: int):
-        return self.v_buffer[layer_id]
+        return self.v_buffer[layer_id - self.start_layer]

     def get_label_buffer(self, layer_id: int):
-        return self.label_buffer[layer_id]
+        return self.label_buffer[layer_id - self.start_layer]

     def get_kv_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+        return (
+            self.k_buffer[layer_id - self.start_layer],
+            self.v_buffer[layer_id - self.start_layer],
+        )

     def set_kv_buffer(
         self,

@@ -679,9 +701,9 @@ class DoubleSparseTokenToKVPool(KVCache):
     ):
         # NOTE(Andy): ignore the dtype check
         layer_id = layer.layer_id
-        self.k_buffer[layer_id][loc] = cache_k
-        self.v_buffer[layer_id][loc] = cache_v
-        self.label_buffer[layer_id][loc] = cache_label
+        self.k_buffer[layer_id - self.start_layer][loc] = cache_k
+        self.v_buffer[layer_id - self.start_layer][loc] = cache_v
+        self.label_buffer[layer_id - self.start_layer][loc] = cache_label

     def get_flat_data(self, indices):
         pass

@@ -930,7 +952,7 @@ class MHATokenToKVPoolHost(HostKVCache):
         return self.kv_buffer[:, :, indices]

     def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[:, layer_id, indices]
+        return self.kv_buffer[:, layer_id - self.start_layer, indices]

     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, :, indices] = flat_data

@@ -955,12 +977,20 @@ class MHATokenToKVPoolHost(HostKVCache):
             for i in range(len(device_indices_cpu)):
                 h_index = host_indices[i * self.page_size]
                 d_index = device_indices_cpu[i]
-                device_pool.k_buffer[layer_id][d_index : d_index + self.page_size].copy_(
-                    self.kv_buffer[0, layer_id, h_index : h_index + self.page_size],
+                device_pool.k_buffer[layer_id - self.start_layer][
+                    d_index : d_index + self.page_size
+                ].copy_(
+                    self.kv_buffer[
+                        0, layer_id - self.start_layer, h_index : h_index + self.page_size
+                    ],
                     non_blocking=True,
                 )
-                device_pool.v_buffer[layer_id][d_index : d_index + self.page_size].copy_(
-                    self.kv_buffer[1, layer_id, h_index : h_index + self.page_size],
+                device_pool.v_buffer[layer_id - self.start_layer][
+                    d_index : d_index + self.page_size
+                ].copy_(
+                    self.kv_buffer[
+                        1, layer_id - self.start_layer, h_index : h_index + self.page_size
+                    ],
                     non_blocking=True,
                 )

@@ -1015,7 +1045,7 @@ class MLATokenToKVPoolHost(HostKVCache):
         return self.kv_buffer[:, indices]

     def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[layer_id, indices]
+        return self.kv_buffer[layer_id - self.start_layer, indices]

     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, indices] = flat_data

@@ -1036,7 +1066,11 @@ class MLATokenToKVPoolHost(HostKVCache):
             for i in range(len(device_indices_cpu)):
                 h_index = host_indices[i * self.page_size]
                 d_index = device_indices_cpu[i]
-                device_pool.kv_buffer[layer_id][d_index : d_index + self.page_size].copy_(
-                    self.kv_buffer[layer_id, h_index : h_index + self.page_size],
+                device_pool.kv_buffer[layer_id - self.start_layer][
+                    d_index : d_index + self.page_size
+                ].copy_(
+                    self.kv_buffer[
+                        layer_id - self.start_layer, h_index : h_index + self.page_size
+                    ],
                     non_blocking=True,
                 )
View file @
11383cec
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
from
__future__
import
annotations
from
__future__
import
annotations
import
bisect
import
bisect
import
inspect
import
os
import
os
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Callable
from
typing
import
TYPE_CHECKING
,
Callable
...
@@ -33,12 +34,14 @@ from sglang.srt.model_executor.forward_batch_info import (
...
@@ -33,12 +34,14 @@ from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode
,
CaptureHiddenMode
,
ForwardBatch
,
ForwardBatch
,
ForwardMode
,
ForwardMode
,
PPProxyTensors
,
)
)
from
sglang.srt.patch_torch
import
monkey_patch_torch_compile
from
sglang.srt.patch_torch
import
monkey_patch_torch_compile
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
get_available_gpu_memory
,
get_available_gpu_memory
,
get_device_memory_capacity
,
get_device_memory_capacity
,
is_hip
,
is_hip
,
rank0_log
,
)
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -188,10 +191,11 @@ class CudaGraphRunner:
...
@@ -188,10 +191,11 @@ class CudaGraphRunner:
self
.
speculative_algorithm
=
model_runner
.
server_args
.
speculative_algorithm
self
.
speculative_algorithm
=
model_runner
.
server_args
.
speculative_algorithm
self
.
tp_size
=
model_runner
.
server_args
.
tp_size
self
.
tp_size
=
model_runner
.
server_args
.
tp_size
self
.
dp_size
=
model_runner
.
server_args
.
dp_size
self
.
dp_size
=
model_runner
.
server_args
.
dp_size
self
.
pp_size
=
model_runner
.
server_args
.
pp_size
# Batch sizes to capture
# Batch sizes to capture
self
.
capture_bs
,
self
.
compile_bs
=
get_batch_sizes_to_capture
(
model_runner
)
self
.
capture_bs
,
self
.
compile_bs
=
get_batch_sizes_to_capture
(
model_runner
)
rank0_log
(
f
"Capture cuda graph bs
{
self
.
capture_bs
}
"
)
self
.
capture_forward_mode
=
ForwardMode
.
DECODE
self
.
capture_forward_mode
=
ForwardMode
.
DECODE
self
.
capture_hidden_mode
=
CaptureHiddenMode
.
NULL
self
.
capture_hidden_mode
=
CaptureHiddenMode
.
NULL
self
.
num_tokens_per_bs
=
1
self
.
num_tokens_per_bs
=
1
...
@@ -234,6 +238,19 @@ class CudaGraphRunner:
...
@@ -234,6 +238,19 @@ class CudaGraphRunner:
self
.
positions
=
torch
.
zeros
((
self
.
max_num_token
,),
dtype
=
torch
.
int64
)
self
.
positions
=
torch
.
zeros
((
self
.
max_num_token
,),
dtype
=
torch
.
int64
)
self
.
mrope_positions
=
torch
.
zeros
((
3
,
self
.
max_bs
),
dtype
=
torch
.
int64
)
self
.
mrope_positions
=
torch
.
zeros
((
3
,
self
.
max_bs
),
dtype
=
torch
.
int64
)
# pipeline parallelism
if
self
.
pp_size
>
1
:
self
.
pp_proxy_tensors
=
{
"hidden_states"
:
torch
.
zeros
(
(
self
.
max_bs
,
self
.
model_runner
.
model_config
.
hidden_size
),
dtype
=
torch
.
bfloat16
,
),
"residual"
:
torch
.
zeros
(
(
self
.
max_bs
,
self
.
model_runner
.
model_config
.
hidden_size
),
dtype
=
torch
.
bfloat16
,
),
}
# Speculative_inference
# Speculative_inference
if
(
if
(
model_runner
.
spec_algorithm
.
is_eagle3
()
model_runner
.
spec_algorithm
.
is_eagle3
()
...
@@ -384,6 +401,12 @@ class CudaGraphRunner:
...
@@ -384,6 +401,12 @@ class CudaGraphRunner:
encoder_lens
=
None
encoder_lens
=
None
mrope_positions
=
self
.
mrope_positions
[:,
:
bs
]
mrope_positions
=
self
.
mrope_positions
[:,
:
bs
]
# pipeline parallelism
if
self
.
pp_size
>
1
:
pp_proxy_tensors
=
PPProxyTensors
(
{
k
:
v
[:
num_tokens
]
for
k
,
v
in
self
.
pp_proxy_tensors
.
items
()}
)
if
self
.
enable_dp_attention
or
self
.
enable_sp_layernorm
:
if
self
.
enable_dp_attention
or
self
.
enable_sp_layernorm
:
self
.
global_num_tokens_gpu
.
copy_
(
self
.
global_num_tokens_gpu
.
copy_
(
torch
.
tensor
(
torch
.
tensor
(
...
@@ -456,8 +479,20 @@ class CudaGraphRunner:
...
@@ -456,8 +479,20 @@ class CudaGraphRunner:
# Clean intermediate result cache for DP attention
# Clean intermediate result cache for DP attention
forward_batch
.
dp_local_start_pos
=
forward_batch
.
dp_local_num_tokens
=
None
forward_batch
.
dp_local_start_pos
=
forward_batch
.
dp_local_num_tokens
=
None
logits_output
=
forward
(
input_ids
,
forward_batch
.
positions
,
forward_batch
)
kwargs
=
{}
return
logits_output
.
next_token_logits
,
logits_output
.
hidden_states
if
(
self
.
pp_size
>
1
and
"pp_proxy_tensors"
in
inspect
.
signature
(
forward
).
parameters
):
kwargs
[
"pp_proxy_tensors"
]
=
pp_proxy_tensors
logits_output_or_pp_proxy_tensors
=
forward
(
input_ids
,
forward_batch
.
positions
,
forward_batch
,
**
kwargs
,
)
return
logits_output_or_pp_proxy_tensors
for
_
in
range
(
2
):
for
_
in
range
(
2
):
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
...
@@ -490,7 +525,11 @@ class CudaGraphRunner:
...
@@ -490,7 +525,11 @@ class CudaGraphRunner:
self
.
capture_hidden_mode
=
hidden_mode_from_spec_info
self
.
capture_hidden_mode
=
hidden_mode_from_spec_info
self
.
capture
()
self
.
capture
()
def
replay_prepare
(
self
,
forward_batch
:
ForwardBatch
):
def
replay_prepare
(
self
,
forward_batch
:
ForwardBatch
,
pp_proxy_tensors
:
Optional
[
PPProxyTensors
]
=
None
,
):
self
.
recapture_if_needed
(
forward_batch
)
self
.
recapture_if_needed
(
forward_batch
)
raw_bs
=
forward_batch
.
batch_size
raw_bs
=
forward_batch
.
batch_size
...
@@ -519,6 +558,11 @@ class CudaGraphRunner:
...
@@ -519,6 +558,11 @@ class CudaGraphRunner:
self
.
seq_lens_cpu
.
fill_
(
1
)
self
.
seq_lens_cpu
.
fill_
(
1
)
self
.
seq_lens_cpu
[:
raw_bs
].
copy_
(
forward_batch
.
seq_lens_cpu
)
self
.
seq_lens_cpu
[:
raw_bs
].
copy_
(
forward_batch
.
seq_lens_cpu
)
if
pp_proxy_tensors
:
for
key
in
self
.
pp_proxy_tensors
.
keys
():
dim
=
pp_proxy_tensors
[
key
].
shape
[
0
]
self
.
pp_proxy_tensors
[
key
][:
dim
].
copy_
(
pp_proxy_tensors
[
key
])
if
self
.
is_encoder_decoder
:
if
self
.
is_encoder_decoder
:
self
.
encoder_lens
[:
raw_bs
].
copy_
(
forward_batch
.
encoder_lens
)
self
.
encoder_lens
[:
raw_bs
].
copy_
(
forward_batch
.
encoder_lens
)
if
forward_batch
.
mrope_positions
is
not
None
:
if
forward_batch
.
mrope_positions
is
not
None
:
...
@@ -547,10 +591,13 @@ class CudaGraphRunner:
...
@@ -547,10 +591,13 @@ class CudaGraphRunner:
self
.
bs
=
bs
self
.
bs
=
bs
def
replay
(
def
replay
(
self
,
forward_batch
:
ForwardBatch
,
skip_attn_backend_init
:
bool
=
False
self
,
)
->
LogitsProcessorOutput
:
forward_batch
:
ForwardBatch
,
skip_attn_backend_init
:
bool
=
False
,
pp_proxy_tensors
:
Optional
[
PPProxyTensors
]
=
None
,
)
->
Union
[
LogitsProcessorOutput
,
PPProxyTensors
]:
if
not
skip_attn_backend_init
:
if
not
skip_attn_backend_init
:
self
.
replay_prepare
(
forward_batch
)
self
.
replay_prepare
(
forward_batch
,
pp_proxy_tensors
)
else
:
else
:
# In speculative decoding, these two fields are still needed.
# In speculative decoding, these two fields are still needed.
self
.
input_ids
[:
self
.
raw_num_token
].
copy_
(
forward_batch
.
input_ids
)
self
.
input_ids
[:
self
.
raw_num_token
].
copy_
(
forward_batch
.
input_ids
)
...
@@ -558,17 +605,19 @@ class CudaGraphRunner:
...
@@ -558,17 +605,19 @@ class CudaGraphRunner:
# Replay
# Replay
self
.
graphs
[
self
.
bs
].
replay
()
self
.
graphs
[
self
.
bs
].
replay
()
next_token_logits
,
hidden_states
=
self
.
output_buffers
[
self
.
bs
]
output
=
self
.
output_buffers
[
self
.
bs
]
if
isinstance
(
output
,
LogitsProcessorOutput
):
logits_output
=
LogitsProcessorOutput
(
return
LogitsProcessorOutput
(
next_token_logits
=
next_token_logits
[:
self
.
raw_num_token
],
next_token_logits
=
output
.
next_token_logits
[:
self
.
raw_num_token
],
hidden_states
=
(
hidden_states
=
(
hidden_states
[:
self
.
raw_num_token
]
output
.
hidden_states
[:
self
.
raw_num_token
]
if
hidden_states
is
not
None
if
output
.
hidden_states
is
not
None
else
None
else
None
),
),
)
)
return
logits_output
else
:
assert
isinstance
(
output
,
PPProxyTensors
)
return
PPProxyTensors
({
k
:
v
[:
self
.
bs
]
for
k
,
v
in
output
.
tensors
.
items
()})
def
get_spec_info
(
self
,
num_tokens
:
int
):
def
get_spec_info
(
self
,
num_tokens
:
int
):
spec_info
=
None
spec_info
=
None
...
...
python/sglang/srt/model_executor/forward_batch_info.py

@@ -31,7 +31,7 @@ from __future__ import annotations
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union

 import torch
 import triton

@@ -585,6 +585,36 @@ class ForwardBatch:
         self.prepare_chunked_kv_indices(device)

+
+class PPProxyTensors:
+    # adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
+    tensors: Dict[str, torch.Tensor]
+
+    def __init__(self, tensors):
+        # manually define this function, so that
+        # Dynamo knows `IntermediateTensors()` comes from this file.
+        # Otherwise, dataclass will generate this function by evaluating
+        # a string, and we will lose the information about the source file.
+        self.tensors = tensors
+
+    def __getitem__(self, key: Union[str, slice]):
+        if isinstance(key, str):
+            return self.tensors[key]
+        elif isinstance(key, slice):
+            return self.__class__({k: v[key] for k, v in self.tensors.items()})
+
+    def __setitem__(self, key: str, value: torch.Tensor):
+        self.tensors[key] = value
+
+    def __len__(self):
+        return len(self.tensors)
+
+    def __eq__(self, other: object):
+        return isinstance(other, self.__class__) and self
+
+    def __repr__(self) -> str:
+        return f"PPProxyTensors(tensors={self.tensors})"
+
+
 def compute_position_triton(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
 ):
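A minimal usage sketch of the PPProxyTensors container added above (tensor shapes are illustrative; the "hidden_states"/"residual" keys match the ones the CUDA-graph runner allocates in this commit):

# Illustrative only: exercise the dict-like and slicing behavior of
# PPProxyTensors, the container each pipeline stage sends to the next.
import torch

from sglang.srt.model_executor.forward_batch_info import PPProxyTensors

proxy = PPProxyTensors(
    {
        "hidden_states": torch.zeros(8, 4096),  # example (tokens, hidden_size)
        "residual": torch.zeros(8, 4096),
    }
)
print(proxy["hidden_states"].shape)        # torch.Size([8, 4096])
half = proxy[:4]                           # slicing returns a new PPProxyTensors
print(len(half), half["residual"].shape)   # 2 torch.Size([4, 4096])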
python/sglang/srt/model_executor/model_runner.py  View file @ 11383cec

@@ -13,8 +13,10 @@
# ==============================================================================
"""ModelRunner runs the forward passes of the models."""

import collections
import datetime
import gc
import inspect
import json
import logging
import os

@@ -59,7 +61,7 @@ from sglang.srt.mem_cache.memory_pool import (
)
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import (
    DefaultModelLoader,

@@ -111,6 +113,8 @@ class ModelRunner:
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        pp_rank: int,
        pp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
        is_draft_worker: bool = False,

@@ -124,6 +128,8 @@ class ModelRunner:
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.pp_rank = pp_rank
        self.pp_size = pp_size
        self.dist_port = nccl_port
        self.server_args = server_args
        self.is_draft_worker = is_draft_worker

@@ -149,24 +155,24 @@ class ModelRunner:
        global_server_args_dict.update(
            {
                "attention_backend": server_args.attention_backend,
                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
                "deepep_mode": server_args.deepep_mode,
                "device": server_args.device,
                "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
                "disable_radix_cache": server_args.disable_radix_cache,
                "enable_nan_detection": server_args.enable_nan_detection,
                "enable_dp_attention": server_args.enable_dp_attention,
                "enable_ep_moe": server_args.enable_ep_moe,
                "enable_deepep_moe": server_args.enable_deepep_moe,
                "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                "moe_dense_tp_size": server_args.moe_dense_tp_size,
                "n_share_experts_fusion": server_args.n_share_experts_fusion,
                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                "torchao_config": server_args.torchao_config,
                "sampling_backend": server_args.sampling_backend,
                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                "use_mla_backend": self.use_mla_backend,
            }
        )

@@ -184,6 +190,11 @@ class ModelRunner:
        # If it is a draft model, tp_group can be different
        self.initialize(min_per_gpu_memory)

        # temporary cached values
        self.support_pp = (
            "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
        )

    def initialize(self, min_per_gpu_memory: float):
        server_args = self.server_args
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(

@@ -194,6 +205,12 @@ class ModelRunner:
        self.sampler = Sampler()
        self.load_model()

        self.start_layer = getattr(self.model, "start_layer", 0)
        self.end_layer = getattr(
            self.model, "end_layer", self.model_config.num_hidden_layers
        )
        self.num_effective_layers = self.end_layer - self.start_layer

        # Apply torchao quantization
        torchao_applied = getattr(self.model, "torchao_applied", False)
        # In layered loading, torchao may have been applied

@@ -360,18 +377,22 @@ class ModelRunner:
            # Only initialize the distributed environment on the target model worker.
            init_distributed_environment(
                backend=backend,
                world_size=self.tp_size * self.pp_size,
                rank=self.tp_size * self.pp_rank + self.tp_rank,
                local_rank=self.gpu_id,
                distributed_init_method=dist_init_method,
                timeout=self.server_args.dist_timeout,
            )
            initialize_model_parallel(
                tensor_model_parallel_size=self.tp_size,
                pipeline_model_parallel_size=self.pp_size,
            )
            initialize_dp_attention(
                enable_dp_attention=self.server_args.enable_dp_attention,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
                pp_size=self.server_args.pp_size,
            )
        min_per_gpu_memory = get_available_gpu_memory(

@@ -698,6 +719,8 @@ class ModelRunner:
                if not self.is_draft_worker
                else self.model_config.hf_config.num_nextn_predict_layers
            )
            # FIXME: pipeline parallelism is not compatible with mla backend
            assert self.pp_size == 1
            cell_size = (
                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                * num_layers

@@ -707,7 +730,7 @@ class ModelRunner:
            cell_size = (
                self.model_config.get_num_kv_heads(get_attention_tp_size())
                * self.model_config.head_dim
                * self.num_effective_layers
                * 2
                * torch._utils._element_size(self.kv_cache_dtype)
            )

@@ -819,9 +842,11 @@ class ModelRunner:
                    self.model_config.num_hidden_layers
                    if not self.is_draft_worker
                    else self.model_config.hf_config.num_nextn_predict_layers
                ),  # PP is not compatible with mla backend
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(

@@ -830,10 +855,12 @@ class ModelRunner:
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(

@@ -842,9 +869,11 @@ class ModelRunner:
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                head_dim=self.model_config.head_dim,
                layer_num=self.num_effective_layers,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                start_layer=self.start_layer,
                end_layer=self.end_layer,
            )
        if self.token_to_kv_pool_allocator is None:

@@ -957,7 +986,7 @@ class ModelRunner:
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

        for i in range(self.start_layer, self.end_layer):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
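The hunks above change the process-group arithmetic and the per-stage KV-cache sizing: the world size becomes tp_size * pp_size, the global rank is tp_size * pp_rank + tp_rank, and each stage only allocates cache for its num_effective_layers = end_layer - start_layer slice. The sketch below is illustrative only and not part of the diff; global_rank and stage_layers are hypothetical helpers that mirror those formulas, and the even layer split is an assumption (make_layers does the real partitioning).

# Illustrative sketch of the PP rank layout and per-stage KV-cache sizing.
# Helper names are hypothetical; only the formulas mirror the diff
# (world_size, rank, num_effective_layers, cell_size).
def global_rank(tp_rank: int, pp_rank: int, tp_size: int) -> int:
    return tp_size * pp_rank + tp_rank

def stage_layers(num_hidden_layers: int, pp_rank: int, pp_size: int) -> tuple:
    # Assumed even split of layers across stages.
    per_stage = num_hidden_layers // pp_size
    start = pp_rank * per_stage
    end = num_hidden_layers if pp_rank == pp_size - 1 else start + per_stage
    return start, end

# Example: tp_size=2, pp_size=2 -> 4 GPUs.
# (pp_rank, tp_rank) -> global rank: (0,0)->0, (0,1)->1, (1,0)->2, (1,1)->3
for pp_rank in range(2):
    for tp_rank in range(2):
        print((pp_rank, tp_rank), "->", global_rank(tp_rank, pp_rank, tp_size=2))

# Per-stage MHA KV cell size, matching the formula in the diff:
# kv_heads * head_dim * num_effective_layers * 2 (K and V) * element_size
start, end = stage_layers(num_hidden_layers=32, pp_rank=0, pp_size=2)
num_effective_layers = end - start  # 16 layers on this stage
cell_size = 8 * 128 * num_effective_layers * 2 * 2  # e.g. 8 KV heads, fp16
print(num_effective_layers, cell_size)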
@@ -997,64 +1026,82 @@ class ModelRunner:
            device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
            tensor_parallel(self.model, device_mesh)

    def forward_decode(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        self.attn_backend.init_forward_metadata(forward_batch)
        # FIXME: add pp_proxy_tensors arg to all models
        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs
        )

    def forward_extend(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors=None,
    ) -> LogitsProcessorOutput:
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)

        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        if forward_batch.input_embeds is not None:
            kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
        if not self.is_generation:
            kwargs["get_embedding"] = True
        return self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
            **kwargs,
        )

    def forward_idle(
        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
    ) -> LogitsProcessorOutput:
        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        return self.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs
        )

    def forward(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
        can_run_cuda_graph = bool(
            forward_batch.forward_mode.is_cuda_graph()
            and self.cuda_graph_runner
            and self.cuda_graph_runner.can_run(forward_batch)
        )
        if can_run_cuda_graph:
            return self.cuda_graph_runner.replay(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )

        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
        elif forward_batch.forward_mode.is_extend():
            return self.forward_extend(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,
                pp_proxy_tensors=pp_proxy_tensors,
            )
        elif forward_batch.forward_mode.is_idle():
            return self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
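With these changes, ModelRunner.forward may return either a LogitsProcessorOutput (last stage) or a PPProxyTensors bundle to hand to the next stage, and the pp_proxy_tensors keyword is only forwarded when inspect.signature shows the model accepts it. The toy classes below are a hedged sketch of that contract, not the sglang API; only the inspect-based detection and the kwargs plumbing mirror the diff.

# Toy sketch of the forward contract under pipeline parallelism.
# PPProxyTensors here is a stand-in dict wrapper; the real class lives in
# sglang.srt.model_executor.forward_batch_info.
import inspect
from typing import Optional

class PPProxyTensors(dict):
    """Tensors handed from one pipeline stage to the next (stand-in)."""

class ToyModel:
    def forward(self, input_ids, positions, batch, pp_proxy_tensors=None):
        # The first stage receives None; later stages receive the previous
        # stage's hidden_states/residual and keep passing them forward.
        hidden = pp_proxy_tensors or PPProxyTensors({"hidden_states": input_ids})
        return hidden  # the last stage would return logits instead

class ToyRunner:
    def __init__(self, model):
        self.model = model
        # Same detection trick as the diff: only pass the kwarg if supported.
        self.support_pp = (
            "pp_proxy_tensors" in inspect.signature(model.forward).parameters
        )

    def forward(self, batch, pp_proxy_tensors: Optional[PPProxyTensors] = None):
        kwargs = {}
        if self.support_pp:
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        return self.model.forward(batch["input_ids"], batch["positions"], batch, **kwargs)

runner = ToyRunner(ToyModel())
out = runner.forward({"input_ids": [1, 2], "positions": [0, 1]})
print(isinstance(out, PPProxyTensors))  # True: this toy stage is not the last one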
python/sglang/srt/models/llama.py  View file @ 11383cec

@@ -17,13 +17,14 @@
"""Inference-only LLaMA model compatible with HuggingFace weights."""

import logging
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union

import torch
from torch import nn
from transformers import LlamaConfig

from sglang.srt.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)

@@ -39,11 +40,12 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    kv_cache_scales_loader,

@@ -275,21 +277,31 @@ class LlamaModel(nn.Module):
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.pp_group = get_pp_group()
        if self.pp_group.is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("embed_tokens", prefix),
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: LlamaDecoderLayer(
                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
            ),
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
            prefix="model.layers",
        )

        if self.pp_group.is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer(return_tuple=True)

        self.layers_to_capture = []

    def forward(

@@ -298,14 +310,23 @@ class LlamaModel(nn.Module):
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]], PPProxyTensors]:
        if self.pp_group.is_first_rank:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds
            residual = None
        else:
            assert pp_proxy_tensors is not None
            # FIXME(@ying): reduce the number of proxy tensors by not fusing layer norms
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]
            deferred_norm = None

        aux_hidden_states = []
        for i in range(self.start_layer, self.end_layer):
            if i in self.layers_to_capture:
                aux_hidden_states.append(hidden_states + residual)
            layer = self.layers[i]

@@ -315,7 +336,16 @@ class LlamaModel(nn.Module):
                forward_batch,
                residual,
            )

        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )
        else:
            hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) == 0:
            return hidden_states
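In LlamaModel, only the first stage owns embed_tokens and only the last stage owns the final norm; other stages hold PPMissingLayer placeholders, make_layers hands back the [start_layer, end_layer) slice this stage instantiates, and non-last stages return a PPProxyTensors with hidden_states and residual instead of running the norm. A minimal sketch of that construction pattern follows; PPMissingLayer and make_layers here are simplified stand-ins for the real helpers in sglang.srt.layers.utils and sglang.srt.utils, and the even layer split is an assumption.

# Sketch of per-stage module construction under PP, with stand-in helpers.
import torch.nn as nn

class PPMissingLayer(nn.Identity):
    """Placeholder for a module owned by another pipeline stage (stand-in)."""

def make_layers(num_layers, layer_fn, pp_rank, pp_size, prefix=""):
    per_stage = num_layers // pp_size
    start = pp_rank * per_stage
    end = num_layers if pp_rank == pp_size - 1 else start + per_stage
    # Layers are indexed globally as self.layers[i], so pad the missing slots.
    layers = nn.ModuleList(
        [layer_fn(i, f"{prefix}.{i}") if start <= i < end else PPMissingLayer()
         for i in range(num_layers)]
    )
    return layers, start, end

class ToyStage(nn.Module):
    def __init__(self, num_layers, hidden, pp_rank, pp_size):
        super().__init__()
        is_first, is_last = pp_rank == 0, pp_rank == pp_size - 1
        self.embed_tokens = nn.Embedding(100, hidden) if is_first else PPMissingLayer()
        self.layers, self.start_layer, self.end_layer = make_layers(
            num_layers, lambda i, p: nn.Linear(hidden, hidden), pp_rank, pp_size
        )
        self.norm = nn.LayerNorm(hidden) if is_last else PPMissingLayer()

stage1 = ToyStage(num_layers=8, hidden=16, pp_rank=1, pp_size=2)
print(stage1.start_layer, stage1.end_layer)  # 4 8: the second stage owns layers 4..7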
@@ -376,6 +406,7 @@ class LlamaForCausalLM(nn.Module):
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.pp_group = get_pp_group()
        self.config = config
        self.quant_config = quant_config
        self.model = self._init_model(config, quant_config, add_prefix("model", prefix))

@@ -419,23 +450,41 @@ class LlamaForCausalLM(nn.Module):
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> LogitsProcessorOutput:
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )

        aux_hidden_states = None
        if self.capture_aux_hidden_states:
            hidden_states, aux_hidden_states = hidden_states

        if self.pp_group.is_last_rank:
            if not get_embedding:
                return self.logits_processor(
                    input_ids,
                    hidden_states,
                    self.lm_head,
                    forward_batch,
                    aux_hidden_states,
                )
            else:
                return self.pooler(hidden_states, forward_batch)
        else:
            return hidden_states

    @property
    def start_layer(self):
        return self.model.start_layer

    @property
    def end_layer(self):
        return self.model.end_layer

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.embed_tokens

@@ -491,6 +540,16 @@ class LlamaForCausalLM(nn.Module):
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            layer_id = get_layer_id(name)
            if (
                layer_id is not None
                and hasattr(self.model, "start_layer")
                and (
                    layer_id < self.model.start_layer
                    or layer_id >= self.model.end_layer
                )
            ):
                continue

            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:

@@ -637,6 +696,9 @@ class LlamaForCausalLM(nn.Module):
        self.model.load_kv_cache_scales(quantization_param_path)

    def set_eagle3_layers_to_capture(self):
        if not self.pp_group.is_last_rank:
            return

        self.capture_aux_hidden_states = True
        num_layers = self.config.num_hidden_layers
        self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
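Because each stage only instantiates its own layer slice, load_weights now skips any checkpoint tensor whose layer index falls outside [start_layer, end_layer). The snippet below sketches that filter; this get_layer_id is a plausible reimplementation of the new helper in sglang.srt.layers.utils, not a copy of it.

# Sketch of the per-stage weight filter used in load_weights above.
import re
from typing import Optional

def get_layer_id(weight_name: str) -> Optional[int]:
    # Assumed behavior: pull the integer out of "...layers.<i>...." names.
    match = re.search(r"layers\.(\d+)\.", weight_name)
    return int(match.group(1)) if match else None

def should_load(weight_name: str, start_layer: int, end_layer: int) -> bool:
    layer_id = get_layer_id(weight_name)
    if layer_id is None:
        return True  # embeddings, lm_head, norm, etc. are handled elsewhere
    return start_layer <= layer_id < end_layer

# A stage owning layers [16, 32) skips the first half of the checkpoint.
names = [
    "model.embed_tokens.weight",
    "model.layers.3.self_attn.qkv_proj.weight",
    "model.layers.20.mlp.gate_up_proj.weight",
]
print([n for n in names if should_load(n, 16, 32)])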
python/sglang/srt/models/llama4.py  View file @ 11383cec

@@ -46,7 +46,7 @@ from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers

@@ -431,6 +431,7 @@ class Llama4Model(nn.Module):
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
python/sglang/srt/models/llama_eagle.py  View file @ 11383cec

@@ -25,13 +25,14 @@ import torch
from torch import nn
from transformers import LlamaConfig

from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM

@@ -86,6 +87,7 @@ class LlamaModel(nn.Module):
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)

@@ -118,6 +120,7 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
        nn.Module.__init__(self)
        self.config = config
        self.quant_config = quant_config
        self.pp_group = get_pp_group()
        self.model = LlamaModel(
            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
        )
python/sglang/srt/models/llama_eagle3.py  View file @ 11383cec

@@ -25,6 +25,7 @@ import torch
from torch import nn
from transformers import LlamaConfig

from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor

@@ -33,7 +34,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM

@@ -118,6 +119,7 @@ class LlamaModel(nn.Module):
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            embeds = self.embed_tokens(input_ids)

@@ -155,6 +157,7 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
        nn.Module.__init__(self)
        self.config = config
        self.quant_config = quant_config
        self.pp_group = get_pp_group()

        if self.config.num_hidden_layers != 1:
            raise ValueError("EAGLE3 currently only supports 1 layer")
python/sglang/srt/server_args.py  View file @ 11383cec

@@ -78,6 +78,8 @@ class ServerArgs:
    # Other runtime options
    tp_size: int = 1
    pp_size: int = 1
    max_micro_batch_size: Optional[int] = None
    stream_interval: int = 1
    stream_output: bool = False
    random_seed: Optional[int] = None

@@ -222,14 +224,18 @@ class ServerArgs:
        # Set mem fraction static, which depends on the tensor parallelism size
        if self.mem_fraction_static is None:
            parallel_size = self.tp_size * self.pp_size
            if gpu_mem <= 81920:
                if parallel_size >= 16:
                    self.mem_fraction_static = 0.79
                elif parallel_size >= 8:
                    self.mem_fraction_static = 0.81
                elif parallel_size >= 4:
                    self.mem_fraction_static = 0.85
                elif parallel_size >= 2:
                    self.mem_fraction_static = 0.87
                else:
                    self.mem_fraction_static = 0.88
            else:
                self.mem_fraction_static = 0.88

        if gpu_mem > 96 * 1024:

@@ -244,6 +250,8 @@ class ServerArgs:
        if self.chunked_prefill_size is None:
            if gpu_mem is not None and gpu_mem < 25_000:
                self.chunked_prefill_size = 2048
            elif self.disaggregation_mode != "null":
                self.chunked_prefill_size = 16384
            else:
                self.chunked_prefill_size = 8192
        assert self.chunked_prefill_size % self.page_size == 0

@@ -643,6 +651,19 @@ class ServerArgs:
            default=ServerArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--pipeline-parallel-size",
            "--pp-size",
            type=int,
            default=ServerArgs.pp_size,
            help="The pipeline parallelism size.",
        )
        parser.add_argument(
            "--max-micro-batch-size",
            type=int,
            default=ServerArgs.max_micro_batch_size,
            help="The maximum micro batch size in pipeline parallelism.",
        )
        parser.add_argument(
            "--stream-interval",
            type=int,

@@ -1232,6 +1253,7 @@ class ServerArgs:
    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        args.tp_size = args.tensor_parallel_size
        args.pp_size = args.pipeline_parallel_size
        args.dp_size = args.data_parallel_size
        args.ep_size = args.expert_parallel_size
        attrs = [attr.name for attr in dataclasses.fields(cls)]

@@ -1245,8 +1267,19 @@ class ServerArgs:
    def check_server_args(self):
        assert (
            self.tp_size * self.pp_size
        ) % self.nnodes == 0, "tp_size must be divisible by number of nodes"

        # FIXME pp constraints
        if self.pp_size > 1:
            logger.warning(f"Turn off overlap scheule for pipeline parallelism.")
            self.disable_overlap_schedule = True
            assert (
                self.disable_overlap_schedule
                and self.speculative_algorithm is None
                and not self.enable_mixed_chunk
            ), "Pipeline parallelism is not compatible with overlap schedule, speculative decoding, mixed chunked prefill."

        assert not (
            self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
        ), "multi-node data parallel is not supported unless dp attention!"
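With the new flags, mem_fraction_static is derived from the combined parallel size tp_size * pp_size rather than tp_size alone, and pp_size > 1 forces the overlap scheduler off. The function below is a hedged restatement of that selection logic for a quick sanity check; the --pp-size and --max-micro-batch-size flags in the comment come from this diff, but the concrete values used here are only an example.

# Restatement of the default mem_fraction_static selection shown above.
# CLI equivalent (values illustrative): --tp-size 2 --pp-size 2 --max-micro-batch-size 8
def default_mem_fraction(tp_size: int, pp_size: int, gpu_mem_mib: int) -> float:
    parallel_size = tp_size * pp_size
    if gpu_mem_mib <= 81920:  # cards with at most 80 GB
        if parallel_size >= 16:
            return 0.79
        elif parallel_size >= 8:
            return 0.81
        elif parallel_size >= 4:
            return 0.85
        elif parallel_size >= 2:
            return 0.87
        return 0.88
    return 0.88

print(default_mem_fraction(tp_size=2, pp_size=2, gpu_mem_mib=81920))  # 0.85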
python/sglang/srt/speculative/eagle_worker.py  View file @ 11383cec

@@ -106,11 +106,12 @@ class EAGLEWorker(TpModelWorker):
        # Init draft worker
        with empty_context():
            super().__init__(
                server_args=server_args,
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                pp_rank=0,  # FIXME
                dp_rank=dp_rank,
                nccl_port=nccl_port,
                is_draft_worker=True,
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,