change / sglang · Commits · 11383cec

Commit 11383cec (unverified)
Authored Apr 30, 2025 by Ying Sheng; committed by GitHub on Apr 30, 2025.
Parent: e97e57e6

[PP] Add pipeline parallelism (#5724)

Changes: 25 · Showing 20 changed files with 894 additions and 303 deletions (+894, -303)
Changed files:

python/sglang/bench_one_batch.py (+2, -0)
python/sglang/srt/entrypoints/engine.py (+36, -19)
python/sglang/srt/layers/dp_attention.py (+5, -2)
python/sglang/srt/layers/utils.py (+35, -0)
python/sglang/srt/managers/data_parallel_controller.py (+52, -34)
python/sglang/srt/managers/schedule_batch.py (+25, -15)
python/sglang/srt/managers/scheduler.py (+262, -59)
python/sglang/srt/managers/scheduler_output_processor_mixin.py (+1, -1)
python/sglang/srt/managers/tp_worker.py (+50, -16)
python/sglang/srt/managers/tp_worker_overlap_thread.py (+9, -3)
python/sglang/srt/mem_cache/memory_pool.py (+70, -36)
python/sglang/srt/model_executor/cuda_graph_runner.py (+67, -18)
python/sglang/srt/model_executor/forward_batch_info.py (+31, -1)
python/sglang/srt/model_executor/model_runner.py (+101, -54)
python/sglang/srt/models/llama.py (+92, -30)
python/sglang/srt/models/llama4.py (+2, -1)
python/sglang/srt/models/llama_eagle.py (+4, -1)
python/sglang/srt/models/llama_eagle3.py (+4, -1)
python/sglang/srt/server_args.py (+43, -10)
python/sglang/srt/speculative/eagle_worker.py (+3, -2)
python/sglang/bench_one_batch.py

@@ -154,6 +154,8 @@ def load_model(server_args, port_args, tp_rank):
         gpu_id=tp_rank,
         tp_rank=tp_rank,
         tp_size=server_args.tp_size,
+        pp_rank=0,
+        pp_size=1,
         nccl_port=port_args.nccl_port,
         server_args=server_args,
     )
python/sglang/srt/entrypoints/engine.py

@@ -126,7 +126,6 @@ class Engine(EngineBase):
             server_args=server_args,
             port_args=port_args,
         )
-        self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
         self.scheduler_info = scheduler_info

@@ -301,7 +300,6 @@ class Engine(EngineBase):
         internal_states = loop.run_until_complete(
             self.tokenizer_manager.get_internal_state()
         )
         return {
             **dataclasses.asdict(self.tokenizer_manager.server_args),
             **self.scheduler_info,

@@ -520,25 +518,44 @@ def _launch_subprocesses(
         )
         scheduler_pipe_readers = []
-        tp_size_per_node = server_args.tp_size // server_args.nnodes
+
+        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
+        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
         tp_rank_range = range(
-            tp_size_per_node * server_args.node_rank,
-            tp_size_per_node * (server_args.node_rank + 1),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
         )
-        for tp_rank in tp_rank_range:
-            reader, writer = mp.Pipe(duplex=False)
-            gpu_id = (
-                server_args.base_gpu_id
-                + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
-            )
-            proc = mp.Process(
-                target=run_scheduler_process,
-                args=(server_args, port_args, gpu_id, tp_rank, None, writer),
-            )
-            with memory_saver_adapter.configure_subprocess():
-                proc.start()
-            scheduler_procs.append(proc)
-            scheduler_pipe_readers.append(reader)
+
+        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
+        pp_rank_range = range(
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
+        )
+
+        for pp_rank in pp_rank_range:
+            for tp_rank in tp_rank_range:
+                reader, writer = mp.Pipe(duplex=False)
+                gpu_id = (
+                    server_args.base_gpu_id
+                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
+                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
+                )
+                proc = mp.Process(
+                    target=run_scheduler_process,
+                    args=(
+                        server_args,
+                        port_args,
+                        gpu_id,
+                        tp_rank,
+                        pp_rank,
+                        None,
+                        writer,
+                    ),
+                )
+                with memory_saver_adapter.configure_subprocess():
+                    proc.start()
+                scheduler_procs.append(proc)
+                scheduler_pipe_readers.append(reader)
    else:
        # Launch the data parallel controller
        reader, writer = mp.Pipe(duplex=False)
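The rank and GPU bookkeeping in the engine.py hunk above is easier to follow outside diff form. The following standalone sketch (not sglang code; the parameter names mirror the server arguments of the same names) reproduces the same arithmetic and shows which (pp_rank, tp_rank, gpu_id) triples a given node launches.

# Standalone sketch of the rank layout used in _launch_subprocesses above.
def scheduler_ranks_for_node(nnodes, node_rank, tp_size, pp_size, base_gpu_id=0, gpu_id_step=1):
    nnodes_per_tp_group = max(nnodes // pp_size, 1)
    tp_size_per_node = tp_size // nnodes_per_tp_group
    tp_rank_range = range(
        tp_size_per_node * (node_rank % nnodes_per_tp_group),
        tp_size_per_node * (node_rank % nnodes_per_tp_group + 1),
    )
    pp_size_per_node = max(pp_size // nnodes, 1)
    pp_rank_range = range(
        pp_size_per_node * (node_rank // nnodes_per_tp_group),
        pp_size_per_node * (node_rank // nnodes_per_tp_group + 1),
    )
    for pp_rank in pp_rank_range:
        for tp_rank in tp_rank_range:
            gpu_id = (
                base_gpu_id
                + ((pp_rank % pp_size_per_node) * tp_size_per_node)
                + (tp_rank % tp_size_per_node) * gpu_id_step
            )
            yield pp_rank, tp_rank, gpu_id

# Example: one node, tp_size=2, pp_size=2 -> four schedulers on GPUs 0..3.
print(list(scheduler_ranks_for_node(nnodes=1, node_rank=0, tp_size=2, pp_size=2)))
# [(0, 0, 0), (0, 1, 1), (1, 0, 2), (1, 1, 3)]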
python/sglang/srt/layers/dp_attention.py

@@ -43,6 +43,7 @@ def initialize_dp_attention(
     tp_rank: int,
     tp_size: int,
     dp_size: int,
+    pp_size: int,
 ):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE

@@ -53,17 +54,19 @@ def initialize_dp_attention(
     )

     if enable_dp_attention:
+        local_rank = tp_rank % (tp_size // dp_size)
         _DP_SIZE = dp_size
     else:
+        local_rank = tp_rank
         _DP_SIZE = 1

     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
         [
             list(range(head, head + _ATTN_TP_SIZE))
-            for head in range(0, tp_size, _ATTN_TP_SIZE)
+            for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE)
         ],
-        tp_group.local_rank,
+        local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,
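For intuition, this is roughly how the attention-TP group membership lists grow once pp_size is folded into the range above. A toy sketch, not the sglang GroupCoordinator API; it assumes attn_tp_size = tp_size // dp_size, as in initialize_dp_attention.

# Toy illustration of the rank lists passed to GroupCoordinator above.
def attn_tp_group_ranks(tp_size, dp_size, pp_size):
    attn_tp_size = tp_size // dp_size
    return [
        list(range(head, head + attn_tp_size))
        for head in range(0, pp_size * tp_size, attn_tp_size)
    ]

# tp_size=4, dp_size=2, pp_size=2 -> [[0, 1], [2, 3], [4, 5], [6, 7]]
print(attn_tp_group_ranks(4, 2, 2))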
python/sglang/srt/layers/utils.py (new file, 0 → 100644)

import logging
import re

import torch

logger = logging.getLogger(__name__)


def get_layer_id(weight_name):
    # example weight name: model.layers.10.self_attn.qkv_proj.weight
    match = re.search(r"layers\.(\d+)\.", weight_name)
    if match:
        return int(match.group(1))
    return None


class PPMissingLayer(torch.nn.Identity):
    # Adapted from
    # https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
    """
    A placeholder layer for missing layers in a pipeline parallel model.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.return_tuple = kwargs.get("return_tuple", False)

    def forward(self, *args, **kwargs):
        """
        Return the first arg from args or the first value from kwargs.

        Wraps the input in a tuple if `self.return_tuple` is True.
        """
        input = args[0] if args else next(iter(kwargs.values()))
        return (input,) if self.return_tuple else input
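A quick usage sketch of the two helpers defined in the new file above (assumes sglang is importable; both names come from the file itself).

# Usage sketch of get_layer_id and PPMissingLayer.
import torch
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id

assert get_layer_id("model.layers.10.self_attn.qkv_proj.weight") == 10
assert get_layer_id("lm_head.weight") is None

# On ranks that do not own the embedding or the final norm, the layer is a no-op.
hidden = torch.ones(2, 4)

embed = PPMissingLayer()
assert torch.equal(embed(hidden), hidden)

# return_tuple=True wraps the passthrough value in a 1-tuple.
norm = PPMissingLayer(return_tuple=True)
(out,) = norm(hidden)
assert torch.equal(out, hidden)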
python/sglang/srt/managers/data_parallel_controller.py

@@ -181,44 +181,62 @@ class DataParallelController:
             enable=server_args.enable_memory_saver
         )

         # Launch tensor parallel scheduler processes
         scheduler_pipe_readers = []
-        tp_size_per_node = server_args.tp_size // server_args.nnodes
+
+        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
+        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
         tp_rank_range = range(
-            tp_size_per_node * server_args.node_rank,
-            tp_size_per_node * (server_args.node_rank + 1),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
         )

+        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
+        pp_rank_range = range(
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
+        )

-        for tp_rank in tp_rank_range:
-            rank_port_args = port_args
-
-            if server_args.enable_dp_attention:
-                # dp attention has different sharding logic
-                _, _, dp_rank = compute_dp_attention_world_info(
-                    server_args.enable_dp_attention,
-                    tp_rank,
-                    server_args.tp_size,
-                    server_args.dp_size,
-                )
-                # compute zmq ports for this dp rank
-                rank_port_args = PortArgs.init_new(server_args, dp_rank)
-                # Data parallelism reuses the tensor parallelism group,
-                # so all dp ranks should use the same nccl port.
-                rank_port_args.nccl_port = port_args.nccl_port
-
-            reader, writer = mp.Pipe(duplex=False)
-            gpu_id = (
-                server_args.base_gpu_id
-                + base_gpu_id
-                + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
-            )
-            proc = mp.Process(
-                target=run_scheduler_process,
-                args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
-            )
-            with memory_saver_adapter.configure_subprocess():
-                proc.start()
-            self.scheduler_procs.append(proc)
-            scheduler_pipe_readers.append(reader)
+        for pp_rank in pp_rank_range:
+            for tp_rank in tp_rank_range:
+                rank_port_args = port_args
+
+                if server_args.enable_dp_attention:
+                    # dp attention has different sharding logic
+                    _, _, dp_rank = compute_dp_attention_world_info(
+                        server_args.enable_dp_attention,
+                        tp_rank,
+                        server_args.tp_size,
+                        server_args.dp_size,
+                    )
+                    # compute zmq ports for this dp rank
+                    rank_port_args = PortArgs.init_new(server_args, dp_rank)
+                    # Data parallelism reuses the tensor parallelism group,
+                    # so all dp ranks should use the same nccl port.
+                    rank_port_args.nccl_port = port_args.nccl_port
+
+                reader, writer = mp.Pipe(duplex=False)
+                gpu_id = (
+                    server_args.base_gpu_id
+                    + base_gpu_id
+                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
+                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
+                )
+                proc = mp.Process(
+                    target=run_scheduler_process,
+                    args=(
+                        server_args,
+                        rank_port_args,
+                        gpu_id,
+                        tp_rank,
+                        pp_rank,
+                        dp_rank,
+                        writer,
+                    ),
+                )
+                with memory_saver_adapter.configure_subprocess():
+                    proc.start()
+                self.scheduler_procs.append(proc)
+                scheduler_pipe_readers.append(reader)

         # Wait for model to finish loading
         scheduler_info = []
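As a concrete reading of the ranges above, here is a small worked example with assumed sizes (nnodes=2, tp_size=8, pp_size=2; these values are illustrative, not taken from the commit). Each node ends up launching one full TP group for one pipeline stage.

# Hypothetical 2-node layout (assumed sizes, not from the commit): tp_size=8, pp_size=2.
nnodes, tp_size, pp_size = 2, 8, 2
nnodes_per_tp_group = max(nnodes // pp_size, 1)    # 1
tp_size_per_node = tp_size // nnodes_per_tp_group  # 8
pp_size_per_node = max(pp_size // nnodes, 1)       # 1
for node_rank in range(nnodes):
    tp_ranks = range(
        tp_size_per_node * (node_rank % nnodes_per_tp_group),
        tp_size_per_node * (node_rank % nnodes_per_tp_group + 1),
    )
    pp_ranks = range(
        pp_size_per_node * (node_rank // nnodes_per_tp_group),
        pp_size_per_node * (node_rank // nnodes_per_tp_group + 1),
    )
    print(node_rank, list(pp_ranks), list(tp_ranks))
# node 0 -> pp_rank [0], tp_ranks [0..7]; node 1 -> pp_rank [1], tp_ranks [0..7]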
python/sglang/srt/managers/schedule_batch.py

@@ -66,23 +66,24 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 # Put some global args for easy access
 global_server_args_dict = {
     "attention_backend": ServerArgs.attention_backend,
     "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "deepep_mode": ServerArgs.deepep_mode,
     "device": ServerArgs.device,
     "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "enable_deepep_moe": ServerArgs.enable_deepep_moe,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
     "max_micro_batch_size": ServerArgs.max_micro_batch_size,
     "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
     "sampling_backend": ServerArgs.sampling_backend,
     "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
     "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
     "torchao_config": ServerArgs.torchao_config,
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
 }

 logger = logging.getLogger(__name__)
@@ -728,6 +729,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Events
     launch_done: Optional[threading.Event] = None

+    # For chunked prefill in PP
+    chunked_req: Optional[Req] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None

@@ -761,7 +765,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # For extend and mixed chunked prefill
     prefix_lens: List[int] = None
     extend_lens: List[int] = None
-    extend_num_tokens: int = None
+    extend_num_tokens: Optional[int] = None
     decoding_reqs: List[Req] = None
     extend_logprob_start_lens: List[int] = None
     # It comes as an empty list if logprob is not required.

@@ -803,6 +807,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
         enable_custom_logit_processor: bool,
+        chunked_req: Optional[Req] = None,
     ):
         return_logprob = any(req.return_logprob for req in reqs)

@@ -820,6 +825,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             spec_algorithm=spec_algorithm,
             enable_custom_logit_processor=enable_custom_logit_processor,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
+            chunked_req=chunked_req,
         )

     def batch_size(self):

@@ -1236,7 +1242,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def retract_decode(self, server_args: ServerArgs):
         """Retract the decoding requests when there is not enough memory."""
-        sorted_indices = [i for i in range(len(self.reqs))]
+        sorted_indices = list(range(len(self.reqs)))

         # TODO(lsyin): improve retraction policy for radix cache
         # For spec decoding, filter_batch API can only filter

@@ -1413,15 +1419,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def filter_batch(
         self,
-        chunked_req_to_exclude: Optional[Req] = None,
+        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
+            if isinstance(chunked_req_to_exclude, Req):
+                chunked_req_to_exclude = [chunked_req_to_exclude]
+            elif chunked_req_to_exclude is None:
+                chunked_req_to_exclude = []
             keep_indices = [
                 i
                 for i in range(len(self.reqs))
                 if not self.reqs[i].finished()
-                and self.reqs[i] is not chunked_req_to_exclude
+                and not self.reqs[i] in chunked_req_to_exclude
             ]

         if keep_indices is None or len(keep_indices) == 0:
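The filter_batch change above normalizes chunked_req_to_exclude to a list before the membership test, so a single chunked request and a list of them (one per pipeline micro-batch) are handled uniformly. A minimal sketch of that normalization pattern, detached from the Req/ScheduleBatch types:

# Minimal sketch of the normalization done in filter_batch (plain objects stand in for Req).
from typing import List, Optional, Union

def keep_unfinished(items, finished, exclude: Optional[Union[object, List[object]]] = None):
    if not isinstance(exclude, list):
        exclude = [] if exclude is None else [exclude]
    return [x for x in items if not finished(x) and x not in exclude]

reqs = ["a", "b", "c"]
assert keep_unfinished(reqs, lambda r: r == "c", exclude="a") == ["b"]
assert keep_unfinished(reqs, lambda r: False, exclude=["a", "b"]) == ["c"]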
python/sglang/srt/managers/scheduler.py

(diff collapsed on the page; not shown here)
python/sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -278,7 +278,7 @@ class SchedulerOutputProcessorMixin:
                 self.attn_tp_rank == 0
                 and self.forward_ct_decode % self.server_args.decode_log_interval == 0
             ):
-                self.log_decode_stats()
+                self.log_decode_stats(running_batch=batch)

     def add_input_logprob_return_values(
         self: Scheduler,
python/sglang/srt/managers/tp_worker.py

@@ -15,11 +15,12 @@
 import logging
 import threading
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch

 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (

@@ -31,7 +32,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed

@@ -47,6 +48,7 @@ class TpModelWorker:
         server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
+        pp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
         is_draft_worker: bool = False,

@@ -54,7 +56,9 @@ class TpModelWorker:
         token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
+        self.tp_size = server_args.tp_size
         self.tp_rank = tp_rank
+        self.pp_rank = pp_rank

         # Init model and tokenizer
         self.model_config = ModelConfig(

@@ -73,12 +77,15 @@ class TpModelWorker:
             quantization=server_args.quantization,
             is_draft_model=is_draft_worker,
         )

         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
+            pp_rank=pp_rank,
+            pp_size=server_args.pp_size,
             nccl_port=nccl_port,
             server_args=server_args,
             is_draft_worker=is_draft_worker,

@@ -105,6 +112,10 @@ class TpModelWorker:
         )
         self.device = self.model_runner.device

+        # Init nccl groups
+        self.pp_group = get_pp_group()
+        self.world_group = get_world_group()
+
         # Profile number of tokens
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = server_args.max_prefill_tokens

@@ -130,8 +141,9 @@ class TpModelWorker:
         # Sync random seed across TP workers
         self.random_seed = broadcast_pyobj(
             [server_args.random_seed],
-            self.tp_rank,
-            self.model_runner.tp_group.cpu_group,
+            self.tp_size * self.pp_rank + tp_rank,
+            self.world_group.cpu_group,
+            src=self.world_group.ranks[0],
         )[0]
         set_random_seed(self.random_seed)

@@ -156,11 +168,14 @@ class TpModelWorker:
     def get_pad_input_ids_func(self):
         return getattr(self.model_runner.model, "pad_input_ids", None)

-    def get_tp_cpu_group(self):
-        return self.model_runner.tp_group.cpu_group
+    def get_tp_group(self):
+        return self.model_runner.tp_group
+
+    def get_attention_tp_group(self):
+        return self.model_runner.attention_tp_group

     def get_attention_tp_cpu_group(self):
-        return self.model_runner.attention_tp_group.cpu_group
+        return getattr(self.model_runner.attention_tp_group, "cpu_group", None)

     def get_memory_pool(self):
         return (

@@ -172,19 +187,38 @@ class TpModelWorker:
         self,
         model_worker_batch: ModelWorkerBatch,
         skip_sample: bool = False,
-    ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
+    ) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor]]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
-        if model_worker_batch.launch_done is not None:
-            model_worker_batch.launch_done.set()

-        if skip_sample:
-            next_token_ids = None
-        else:
-            next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+        pp_proxy_tensors = None
+        if not self.pp_group.is_first_rank:
+            pp_proxy_tensors = PPProxyTensors(
+                self.pp_group.recv_tensor_dict(
+                    all_gather_group=self.get_attention_tp_group()
+                )
+            )

-        return logits_output, next_token_ids
+        if self.pp_group.is_last_rank:
+            logits_output = self.model_runner.forward(
+                forward_batch, pp_proxy_tensors=pp_proxy_tensors
+            )
+            if model_worker_batch.launch_done is not None:
+                model_worker_batch.launch_done.set()
+
+            if skip_sample:
+                next_token_ids = None
+            else:
+                next_token_ids = self.model_runner.sample(
+                    logits_output, model_worker_batch
+                )
+
+            return logits_output, next_token_ids
+        else:
+            pp_proxy_tensors = self.model_runner.forward(
+                forward_batch,
+                pp_proxy_tensors=pp_proxy_tensors,
+            )
+            return pp_proxy_tensors.tensors, None

     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
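The forward_batch_generation change above encodes the per-stage dataflow: a non-first stage first receives proxy tensors from the previous stage, only the last stage samples, and every other stage forwards its proxy tensors downstream. The toy sketch below mimics that control flow with plain dicts in a single process; there are no real NCCL groups or sglang types here, so the "hand-off" is just a returned dictionary.

# Schematic of the per-stage control flow in forward_batch_generation (toy stand-ins).
import torch

def stage_step(stage_idx, num_stages, run_model, incoming=None):
    # Non-first stages consume the proxy tensors produced by the previous stage.
    proxy_in = incoming if stage_idx > 0 else None
    out = run_model(proxy_in)
    if stage_idx == num_stages - 1:
        logits = out                      # last stage: real logits, ready for sampling
        next_tokens = logits.argmax(-1)
        return next_tokens, None
    return None, out                      # other stages: pass proxy tensors downstream

# Two toy "stages": stage 0 produces hidden states, stage 1 turns them into logits.
hidden = {"hidden_states": torch.randn(2, 8), "residual": torch.zeros(2, 8)}
_, proxy = stage_step(0, 2, lambda _: hidden)
tokens, _ = stage_step(1, 2, lambda p: p["hidden_states"] @ torch.randn(8, 16), incoming=proxy)
print(tokens.shape)  # torch.Size([2])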
python/sglang/srt/managers/tp_worker_overlap_thread.py

@@ -56,11 +56,14 @@ class TpModelWorkerClient:
         server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
+        pp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
     ):
         # Load the model
-        self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
+        self.worker = TpModelWorker(
+            server_args, gpu_id, tp_rank, pp_rank, dp_rank, nccl_port
+        )
         self.max_running_requests = self.worker.max_running_requests
         self.device = self.worker.device
         self.gpu_id = gpu_id

@@ -91,8 +94,11 @@ class TpModelWorkerClient:
     def get_pad_input_ids_func(self):
         return self.worker.get_pad_input_ids_func()

-    def get_tp_cpu_group(self):
-        return self.worker.get_tp_cpu_group()
+    def get_tp_group(self):
+        return self.worker.get_tp_group()
+
+    def get_attention_tp_group(self):
+        return self.worker.get_attention_tp_group()

     def get_attention_tp_cpu_group(self):
         return self.worker.get_attention_tp_cpu_group()
python/sglang/srt/mem_cache/memory_pool.py

@@ -214,6 +214,8 @@ class MHATokenToKVPool(KVCache):
         layer_num: int,
         device: str,
         enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
     ):
         self.size = size
         self.page_size = page_size

@@ -232,6 +234,8 @@ class MHATokenToKVPool(KVCache):
         self.head_dim = head_dim
         self.layer_num = layer_num
         self._create_buffers()
+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1

         self.layer_transfer_counter = None
         self.capture_mode = False

@@ -281,6 +285,8 @@ class MHATokenToKVPool(KVCache):
     # for disagg
     def get_contiguous_buf_infos(self):
-        # layer_num x [seq_len, head_num, head_dim]
+        # layer_num x [page_num, page_size, head_num, head_dim]
         kv_data_ptrs = [
             self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
         ] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]

@@ -320,24 +326,24 @@ class MHATokenToKVPool(KVCache):
         # transfer prepared data from host to device
         flat_data = flat_data.to(device=self.device, non_blocking=False)
         k_data, v_data = flat_data[0], flat_data[1]
-        self.k_buffer[layer_id][indices] = k_data
-        self.v_buffer[layer_id][indices] = v_data
+        self.k_buffer[layer_id - self.start_layer][indices] = k_data
+        self.v_buffer[layer_id - self.start_layer][indices] = v_data

     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id)
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

         if self.store_dtype != self.dtype:
-            return self.k_buffer[layer_id].view(self.dtype)
-        return self.k_buffer[layer_id]
+            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.k_buffer[layer_id - self.start_layer]

     def get_value_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id)
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

         if self.store_dtype != self.dtype:
-            return self.v_buffer[layer_id].view(self.dtype)
-        return self.v_buffer[layer_id]
+            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.v_buffer[layer_id - self.start_layer]

     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

@@ -369,12 +375,12 @@ class MHATokenToKVPool(KVCache):
             current_stream = self.device_module.current_stream()
             self.alt_stream.wait_stream(current_stream)
             with self.device_module.stream(self.alt_stream):
-                self.k_buffer[layer_id][loc] = cache_k
-                self.v_buffer[layer_id][loc] = cache_v
+                self.k_buffer[layer_id - self.start_layer][loc] = cache_k
+                self.v_buffer[layer_id - self.start_layer][loc] = cache_v
             current_stream.wait_stream(self.alt_stream)
         else:
-            self.k_buffer[layer_id][loc] = cache_k
-            self.v_buffer[layer_id][loc] = cache_v
+            self.k_buffer[layer_id - self.start_layer][loc] = cache_k
+            self.v_buffer[layer_id - self.start_layer][loc] = cache_v

     @torch.compile

@@ -484,6 +490,8 @@ class MLATokenToKVPool(KVCache):
         layer_num: int,
         device: str,
         enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
     ):
         self.size = size
         self.page_size = page_size

@@ -497,6 +505,8 @@ class MLATokenToKVPool(KVCache):
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
         self.layer_num = layer_num
+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1

         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver

@@ -540,19 +550,21 @@ class MLATokenToKVPool(KVCache):
     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id)
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

         if self.store_dtype != self.dtype:
-            return self.kv_buffer[layer_id].view(self.dtype)
-        return self.kv_buffer[layer_id]
+            return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.kv_buffer[layer_id - self.start_layer]

     def get_value_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id)
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

         if self.store_dtype != self.dtype:
-            return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
-        return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
+            return self.kv_buffer[layer_id - self.start_layer][
+                ..., : self.kv_lora_rank
+            ].view(self.dtype)
+        return self.kv_buffer[layer_id - self.start_layer][..., : self.kv_lora_rank]

     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

@@ -568,9 +580,11 @@ class MLATokenToKVPool(KVCache):
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
         if self.store_dtype != self.dtype:
-            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k.view(
+                self.store_dtype
+            )
         else:
-            self.kv_buffer[layer_id][loc] = cache_k
+            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k

     def set_mla_kv_buffer(
         self,

@@ -605,7 +619,7 @@ class MLATokenToKVPool(KVCache):
     def transfer_per_layer(self, indices, flat_data, layer_id):
         # transfer prepared data from host to device
         flat_data = flat_data.to(device=self.device, non_blocking=False)
-        self.kv_buffer[layer_id][indices] = flat_data
+        self.kv_buffer[layer_id - self.start_layer][indices] = flat_data


 class DoubleSparseTokenToKVPool(KVCache):

@@ -620,6 +634,8 @@ class DoubleSparseTokenToKVPool(KVCache):
         device: str,
         heavy_channel_num: int,
         enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
     ):
         self.size = size
         self.page_size = page_size

@@ -657,17 +673,23 @@ class DoubleSparseTokenToKVPool(KVCache):
             for _ in range(layer_num)
         ]

+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1
+
     def get_key_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id]
+        return self.k_buffer[layer_id - self.start_layer]

     def get_value_buffer(self, layer_id: int):
-        return self.v_buffer[layer_id]
+        return self.v_buffer[layer_id - self.start_layer]

     def get_label_buffer(self, layer_id: int):
-        return self.label_buffer[layer_id]
+        return self.label_buffer[layer_id - self.start_layer]

     def get_kv_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+        return (
+            self.k_buffer[layer_id - self.start_layer],
+            self.v_buffer[layer_id - self.start_layer],
+        )

     def set_kv_buffer(
         self,

@@ -679,9 +701,9 @@ class DoubleSparseTokenToKVPool(KVCache):
     ):
         # NOTE(Andy): ignore the dtype check
         layer_id = layer.layer_id
-        self.k_buffer[layer_id][loc] = cache_k
-        self.v_buffer[layer_id][loc] = cache_v
-        self.label_buffer[layer_id][loc] = cache_label
+        self.k_buffer[layer_id - self.start_layer][loc] = cache_k
+        self.v_buffer[layer_id - self.start_layer][loc] = cache_v
+        self.label_buffer[layer_id - self.start_layer][loc] = cache_label

     def get_flat_data(self, indices):
         pass

@@ -930,7 +952,7 @@ class MHATokenToKVPoolHost(HostKVCache):
         return self.kv_buffer[:, :, indices]

     def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[:, layer_id, indices]
+        return self.kv_buffer[:, layer_id - self.start_layer, indices]

     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, :, indices] = flat_data

@@ -955,12 +977,20 @@ class MHATokenToKVPoolHost(HostKVCache):
             for i in range(len(device_indices_cpu)):
                 h_index = host_indices[i * self.page_size]
                 d_index = device_indices_cpu[i]
-                device_pool.k_buffer[layer_id][d_index : d_index + self.page_size].copy_(
-                    self.kv_buffer[0, layer_id, h_index : h_index + self.page_size],
+                device_pool.k_buffer[layer_id - self.start_layer][
+                    d_index : d_index + self.page_size
+                ].copy_(
+                    self.kv_buffer[
+                        0, layer_id - self.start_layer, h_index : h_index + self.page_size
+                    ],
                     non_blocking=True,
                 )
-                device_pool.v_buffer[layer_id][d_index : d_index + self.page_size].copy_(
-                    self.kv_buffer[1, layer_id, h_index : h_index + self.page_size],
+                device_pool.v_buffer[layer_id - self.start_layer][
+                    d_index : d_index + self.page_size
+                ].copy_(
+                    self.kv_buffer[
+                        1, layer_id - self.start_layer, h_index : h_index + self.page_size
+                    ],
                     non_blocking=True,
                 )

@@ -1015,7 +1045,7 @@ class MLATokenToKVPoolHost(HostKVCache):
         return self.kv_buffer[:, indices]

     def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[layer_id, indices]
+        return self.kv_buffer[layer_id - self.start_layer, indices]

     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, indices] = flat_data

@@ -1036,7 +1066,11 @@ class MLATokenToKVPoolHost(HostKVCache):
             for i in range(len(device_indices_cpu)):
                 h_index = host_indices[i * self.page_size]
                 d_index = device_indices_cpu[i]
-                device_pool.kv_buffer[layer_id][d_index : d_index + self.page_size].copy_(
-                    self.kv_buffer[layer_id, h_index : h_index + self.page_size],
+                device_pool.kv_buffer[layer_id - self.start_layer][
+                    d_index : d_index + self.page_size
+                ].copy_(
+                    self.kv_buffer[
+                        layer_id - self.start_layer, h_index : h_index + self.page_size
+                    ],
                     non_blocking=True,
                 )
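All of the KV-pool changes above follow one pattern: each pipeline stage only allocates buffers for the layers it owns, so global layer ids must be shifted by start_layer before indexing. A stripped-down sketch of that indexing (a toy class, not the sglang pools):

# Stripped-down illustration of the layer_id - start_layer offset used throughout the pools.
import torch

class StageLocalKV:
    def __init__(self, start_layer, end_layer, size=16, head_dim=4):
        self.start_layer = start_layer
        self.end_layer = end_layer
        # One buffer per *owned* layer only.
        self.k_buffer = [torch.zeros(size, head_dim) for _ in range(start_layer, end_layer + 1)]

    def get_key_buffer(self, layer_id: int):
        # Callers pass global layer ids; shift into the stage-local list.
        return self.k_buffer[layer_id - self.start_layer]

# Stage owning global layers 16..31 of a 32-layer model.
pool = StageLocalKV(start_layer=16, end_layer=31)
assert pool.get_key_buffer(16) is pool.k_buffer[0]
assert pool.get_key_buffer(31) is pool.k_buffer[15]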
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -16,6 +16,7 @@
 from __future__ import annotations

 import bisect
+import inspect
 import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable

@@ -33,12 +34,14 @@ from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
     ForwardMode,
+    PPProxyTensors,
 )
 from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.utils import (
     get_available_gpu_memory,
     get_device_memory_capacity,
     is_hip,
+    rank0_log,
 )

 if TYPE_CHECKING:

@@ -188,10 +191,11 @@ class CudaGraphRunner:
         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
+        self.pp_size = model_runner.server_args.pp_size

         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
+        rank0_log(f"Capture cuda graph bs {self.capture_bs}")
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1

@@ -234,6 +238,19 @@ class CudaGraphRunner:
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
         self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)

+        # pipeline parallelism
+        if self.pp_size > 1:
+            self.pp_proxy_tensors = {
+                "hidden_states": torch.zeros(
+                    (self.max_bs, self.model_runner.model_config.hidden_size),
+                    dtype=torch.bfloat16,
+                ),
+                "residual": torch.zeros(
+                    (self.max_bs, self.model_runner.model_config.hidden_size),
+                    dtype=torch.bfloat16,
+                ),
+            }
+
         # Speculative_inference
         if (
             model_runner.spec_algorithm.is_eagle3()

@@ -384,6 +401,12 @@ class CudaGraphRunner:
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]

+        # pipeline parallelism
+        if self.pp_size > 1:
+            pp_proxy_tensors = PPProxyTensors(
+                {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
+            )
+
         if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(

@@ -456,8 +479,20 @@ class CudaGraphRunner:
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None

-            logits_output = forward(input_ids, forward_batch.positions, forward_batch)
-            return logits_output.next_token_logits, logits_output.hidden_states
+            kwargs = {}
+            if (
+                self.pp_size > 1
+                and "pp_proxy_tensors" in inspect.signature(forward).parameters
+            ):
+                kwargs["pp_proxy_tensors"] = pp_proxy_tensors
+
+            logits_output_or_pp_proxy_tensors = forward(
+                input_ids,
+                forward_batch.positions,
+                forward_batch,
+                **kwargs,
+            )
+            return logits_output_or_pp_proxy_tensors

         for _ in range(2):
             torch.cuda.synchronize()

@@ -490,7 +525,11 @@ class CudaGraphRunner:
             self.capture_hidden_mode = hidden_mode_from_spec_info
             self.capture()

-    def replay_prepare(self, forward_batch: ForwardBatch):
+    def replay_prepare(
+        self,
+        forward_batch: ForwardBatch,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ):
         self.recapture_if_needed(forward_batch)

         raw_bs = forward_batch.batch_size

@@ -519,6 +558,11 @@ class CudaGraphRunner:
             self.seq_lens_cpu.fill_(1)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)

+        if pp_proxy_tensors:
+            for key in self.pp_proxy_tensors.keys():
+                dim = pp_proxy_tensors[key].shape[0]
+                self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key])
+
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:

@@ -547,10 +591,13 @@ class CudaGraphRunner:
         self.bs = bs

     def replay(
-        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
-    ) -> LogitsProcessorOutput:
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
         if not skip_attn_backend_init:
-            self.replay_prepare(forward_batch)
+            self.replay_prepare(forward_batch, pp_proxy_tensors)
         else:
             # In speculative decoding, these two fields are still needed.
             self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)

@@ -558,17 +605,19 @@ class CudaGraphRunner:
         # Replay
         self.graphs[self.bs].replay()
-        next_token_logits, hidden_states = self.output_buffers[self.bs]

-        logits_output = LogitsProcessorOutput(
-            next_token_logits=next_token_logits[: self.raw_num_token],
-            hidden_states=(
-                hidden_states[: self.raw_num_token]
-                if hidden_states is not None
-                else None
-            ),
-        )
-        return logits_output
+        output = self.output_buffers[self.bs]
+        if isinstance(output, LogitsProcessorOutput):
+            return LogitsProcessorOutput(
+                next_token_logits=output.next_token_logits[: self.raw_num_token],
+                hidden_states=(
+                    output.hidden_states[: self.raw_num_token]
+                    if output.hidden_states is not None
+                    else None
+                ),
+            )
+        else:
+            assert isinstance(output, PPProxyTensors)
+            return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})

     def get_spec_info(self, num_tokens: int):
         spec_info = None
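The replay_prepare change above copies incoming proxy tensors into buffers that were allocated at capture time, because a captured CUDA graph always reads the same memory addresses. The sketch below reproduces that copy-into-static-buffer pattern with CPU tensors only (the buffer names mirror the ones above but nothing here touches CUDA graphs).

# Sketch of the static-buffer copy performed before graph replay.
import torch

max_bs, hidden_size = 8, 4
static_buffers = {
    "hidden_states": torch.zeros(max_bs, hidden_size),
    "residual": torch.zeros(max_bs, hidden_size),
}

def copy_into_static(incoming: dict):
    # Copy only the leading `dim` rows; the rest of the static buffer keeps its old content.
    for key, buf in static_buffers.items():
        dim = incoming[key].shape[0]
        buf[:dim].copy_(incoming[key])

copy_into_static(
    {"hidden_states": torch.ones(3, hidden_size), "residual": torch.full((3, hidden_size), 2.0)}
)
assert static_buffers["hidden_states"][:3].eq(1).all()
assert static_buffers["hidden_states"][3:].eq(0).all()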
python/sglang/srt/model_executor/forward_batch_info.py

@@ -31,7 +31,7 @@ from __future__ import annotations
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union

 import torch
 import triton

@@ -585,6 +585,36 @@ class ForwardBatch:
             self.prepare_chunked_kv_indices(device)


+class PPProxyTensors:
+    # adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
+    tensors: Dict[str, torch.Tensor]
+
+    def __init__(self, tensors):
+        # manually define this function, so that
+        # Dynamo knows `IntermediateTensors()` comes from this file.
+        # Otherwise, dataclass will generate this function by evaluating
+        # a string, and we will lose the information about the source file.
+        self.tensors = tensors
+
+    def __getitem__(self, key: Union[str, slice]):
+        if isinstance(key, str):
+            return self.tensors[key]
+        elif isinstance(key, slice):
+            return self.__class__({k: v[key] for k, v in self.tensors.items()})
+
+    def __setitem__(self, key: str, value: torch.Tensor):
+        self.tensors[key] = value
+
+    def __len__(self):
+        return len(self.tensors)
+
+    def __eq__(self, other: object):
+        return isinstance(other, self.__class__) and self
+
+    def __repr__(self) -> str:
+        return f"PPProxyTensors(tensors={self.tensors})"
+
+
 def compute_position_triton(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
 ):
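A short usage sketch of the PPProxyTensors container added above (assumes sglang is importable; the behavior follows the __getitem__ shown in the hunk: a string key returns one tensor, a slice returns a new container with every tensor sliced).

# Usage sketch of PPProxyTensors.
import torch
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors

proxy = PPProxyTensors(
    {"hidden_states": torch.randn(8, 16), "residual": torch.randn(8, 16)}
)
assert proxy["hidden_states"].shape == (8, 16)
assert len(proxy) == 2

head = proxy[:4]                 # slice -> new PPProxyTensors, all tensors sliced
assert head["residual"].shape == (4, 16)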
python/sglang/srt/model_executor/model_runner.py

@@ -13,8 +13,10 @@
 # ==============================================================================
 """ModelRunner runs the forward passes of the models."""
+import collections
 import datetime
 import gc
+import inspect
 import json
 import logging
 import os

@@ -59,7 +61,7 @@ from sglang.srt.mem_cache.memory_pool import (
 )
 from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import (
     DefaultModelLoader,

@@ -111,6 +113,8 @@ class ModelRunner:
         gpu_id: int,
         tp_rank: int,
         tp_size: int,
+        pp_rank: int,
+        pp_size: int,
         nccl_port: int,
         server_args: ServerArgs,
         is_draft_worker: bool = False,

@@ -124,6 +128,8 @@ class ModelRunner:
         self.gpu_id = gpu_id
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.pp_rank = pp_rank
+        self.pp_size = pp_size
         self.dist_port = nccl_port
         self.server_args = server_args
         self.is_draft_worker = is_draft_worker
@@ -149,24 +155,24 @@ class ModelRunner:
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
                 "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
                 "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
                 "deepep_mode": server_args.deepep_mode,
                 "device": server_args.device,
                 "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
                 "disable_radix_cache": server_args.disable_radix_cache,
                 "enable_deepep_moe": server_args.enable_deepep_moe,
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                 "moe_dense_tp_size": server_args.moe_dense_tp_size,
                 "n_share_experts_fusion": server_args.n_share_experts_fusion,
                 "sampling_backend": server_args.sampling_backend,
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                 "torchao_config": server_args.torchao_config,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "use_mla_backend": self.use_mla_backend,
             }
         )
@@ -184,6 +190,11 @@ class ModelRunner:
         # If it is a draft model, tp_group can be different
         self.initialize(min_per_gpu_memory)

+        # temporary cached values
+        self.support_pp = (
+            "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
+        )
+
     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(

@@ -194,6 +205,12 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()

+        self.start_layer = getattr(self.model, "start_layer", 0)
+        self.end_layer = getattr(
+            self.model, "end_layer", self.model_config.num_hidden_layers
+        )
+        self.num_effective_layers = self.end_layer - self.start_layer
+
         # Apply torchao quantization
         torchao_applied = getattr(self.model, "torchao_applied", False)
         # In layered loading, torchao may have been applied

@@ -360,18 +377,22 @@ class ModelRunner:
             # Only initialize the distributed environment on the target model worker.
             init_distributed_environment(
                 backend=backend,
-                world_size=self.tp_size,
-                rank=self.tp_rank,
+                world_size=self.tp_size * self.pp_size,
+                rank=self.tp_size * self.pp_rank + self.tp_rank,
                 local_rank=self.gpu_id,
                 distributed_init_method=dist_init_method,
                 timeout=self.server_args.dist_timeout,
             )
-            initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+            initialize_model_parallel(
+                tensor_model_parallel_size=self.tp_size,
+                pipeline_model_parallel_size=self.pp_size,
+            )
             initialize_dp_attention(
                 enable_dp_attention=self.server_args.enable_dp_attention,
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
                 dp_size=self.server_args.dp_size,
+                pp_size=self.server_args.pp_size,
             )

         min_per_gpu_memory = get_available_gpu_memory(

@@ -698,6 +719,8 @@ class ModelRunner:
             if not self.is_draft_worker
             else self.model_config.hf_config.num_nextn_predict_layers
         )
+        # FIXME: pipeline parallelism is not compatible with mla backend
+        assert self.pp_size == 1
         cell_size = (
             (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
             * num_layers

@@ -707,7 +730,7 @@ class ModelRunner:
             cell_size = (
                 self.model_config.get_num_kv_heads(get_attention_tp_size())
                 * self.model_config.head_dim
-                * self.model_config.num_hidden_layers
+                * self.num_effective_layers
                 * 2
                 * torch._utils._element_size(self.kv_cache_dtype)
             )

@@ -819,9 +842,11 @@ class ModelRunner:
                     self.model_config.num_hidden_layers
                     if not self.is_draft_worker
                     else self.model_config.hf_config.num_nextn_predict_layers
                 ),
+                # PP is not compatible with mla backend
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
             )
         elif self.server_args.enable_double_sparsity:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(

@@ -830,10 +855,12 @@ class ModelRunner:
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
-                layer_num=self.model_config.num_hidden_layers,
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 heavy_channel_num=self.server_args.ds_heavy_channel_num,
                 enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(

@@ -842,9 +869,11 @@ class ModelRunner:
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
-                layer_num=self.model_config.num_hidden_layers,
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
             )

         if self.token_to_kv_pool_allocator is None:

@@ -957,7 +986,7 @@ class ModelRunner:
         with open(self.server_args.ds_channel_config_path, "r") as f:
             channel_config = json.load(f)

-        for i in range(self.model_config.num_hidden_layers):
+        for i in range(self.start_layer, self.end_layer):
             key = "model.layers." + str(i) + ".self_attn" + selected_channel
             self.sorted_channels.append(
                 torch.tensor(channel_config[key])[

@@ -997,64 +1026,82 @@ class ModelRunner:
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
         tensor_parallel(self.model, device_mesh)

-    def forward_decode(self, forward_batch: ForwardBatch):
+    def forward_decode(
+        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
+    ) -> LogitsProcessorOutput:
         self.attn_backend.init_forward_metadata(forward_batch)
+        # FIXME: add pp_proxy_tensors arg to all models
+        kwargs = {}
+        if self.support_pp:
+            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
         return self.model.forward(
-            forward_batch.input_ids, forward_batch.positions, forward_batch
+            forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs
         )

     def forward_extend(
-        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
-    ):
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors=None,
+    ) -> LogitsProcessorOutput:
         if not skip_attn_backend_init:
             self.attn_backend.init_forward_metadata(forward_batch)

-        if self.is_generation:
-            if forward_batch.input_embeds is None:
-                return self.model.forward(
-                    forward_batch.input_ids, forward_batch.positions, forward_batch
-                )
-            else:
-                return self.model.forward(
-                    forward_batch.input_ids,
-                    forward_batch.positions,
-                    forward_batch,
-                    input_embeds=forward_batch.input_embeds.bfloat16(),
-                )
-        else:
-            # Only embedding models have get_embedding parameter
-            return self.model.forward(
-                forward_batch.input_ids,
-                forward_batch.positions,
-                forward_batch,
-                get_embedding=True,
-            )
+        kwargs = {}
+        if self.support_pp:
+            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
+        if forward_batch.input_embeds is not None:
+            kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
+        if not self.is_generation:
+            # Only embedding models have get_embedding parameter
+            kwargs["get_embedding"] = True
+        return self.model.forward(
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+            **kwargs,
+        )

-    def forward_idle(self, forward_batch: ForwardBatch):
+    def forward_idle(
+        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
+    ) -> LogitsProcessorOutput:
+        kwargs = {}
+        if self.support_pp:
+            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
         return self.model.forward(
-            forward_batch.input_ids, forward_batch.positions, forward_batch
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+            **kwargs,
         )

     def forward(
-        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
-    ) -> LogitsProcessorOutput:
-        if (
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(forward_batch)
-        ):
+        )
+        if can_run_cuda_graph:
             return self.cuda_graph_runner.replay(
-                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
             )

         if forward_batch.forward_mode.is_decode():
-            return self.forward_decode(forward_batch)
+            return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         elif forward_batch.forward_mode.is_extend():
             return self.forward_extend(
-                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
            )
         elif forward_batch.forward_mode.is_idle():
-            return self.forward_idle(forward_batch)
+            return self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
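The support_pp flag above is just a capability check: only models whose forward() declares a pp_proxy_tensors parameter receive it as a keyword argument, so models not yet ported to PP keep working. A small sketch of that pattern with generic callables (not sglang models):

# Sketch of the capability check behind self.support_pp.
import inspect

def pp_aware_forward(input_ids, positions, batch, pp_proxy_tensors=None):
    return ("pp", pp_proxy_tensors)

def legacy_forward(input_ids, positions, batch):
    return ("no-pp",)

def call_model(forward, proxy):
    kwargs = {}
    if "pp_proxy_tensors" in inspect.signature(forward).parameters:
        kwargs["pp_proxy_tensors"] = proxy
    return forward(None, None, None, **kwargs)

assert call_model(pp_aware_forward, proxy={"hidden_states": None})[0] == "pp"
assert call_model(legacy_forward, proxy=None) == ("no-pp",)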
python/sglang/srt/models/llama.py

@@ -17,13 +17,14 @@
 """Inference-only LLaMA model compatible with HuggingFace weights."""

 import logging
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union

 import torch
 from torch import nn
 from transformers import LlamaConfig

 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )

@@ -39,11 +40,12 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,

@@ -275,21 +277,31 @@ class LlamaModel(nn.Module):
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            quant_config=quant_config,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
-        self.layers = make_layers(
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: LlamaDecoderLayer(
-                config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix="model.layers",
         )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
+
         self.layers_to_capture = []

     def forward(

@@ -298,14 +310,23 @@ class LlamaModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]], PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-            hidden_states = input_embeds
-        residual = None
+            assert pp_proxy_tensors is not None
+            # FIXME(@ying): reduce the number of proxy tensors by not fusing layer norms
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+            deferred_norm = None

         aux_hidden_states = []
-        for i in range(len(self.layers)):
+        for i in range(self.start_layer, self.end_layer):
             if i in self.layers_to_capture:
                 aux_hidden_states.append(hidden_states + residual)
             layer = self.layers[i]

@@ -315,7 +336,16 @@ class LlamaModel(nn.Module):
                 forward_batch,
                 residual,
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)

         if len(aux_hidden_states) == 0:
             return hidden_states

@@ -376,6 +406,7 @@ class LlamaForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = self._init_model(config, quant_config, add_prefix("model", prefix))

@@ -419,23 +450,41 @@ class LlamaForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> LogitsProcessorOutput:
-        aux_hidden_states = None
-        if self.capture_aux_hidden_states:
-            hidden_states, aux_hidden_states = self.model(
-                input_ids, positions, forward_batch, input_embeds
-            )
-        else:
-            hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-
-        if not get_embedding:
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
-            )
-        else:
-            return self.pooler(hidden_states, forward_batch)
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer

     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens

@@ -491,6 +540,16 @@ class LlamaForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())

         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:

@@ -637,6 +696,9 @@ class LlamaForCausalLM(nn.Module):
         self.model.load_kv_cache_scales(quantization_param_path)

     def set_eagle3_layers_to_capture(self):
+        if not self.pp_group.is_last_rank:
+            return
+
         self.capture_aux_hidden_states = True
         num_layers = self.config.num_hidden_layers
         self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
python/sglang/srt/models/llama4.py
@@ -46,7 +46,7 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
 from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
...
@@ -431,6 +431,7 @@ class Llama4Model(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
...
python/sglang/srt/models/llama_eagle.py
@@ -25,13 +25,14 @@ import torch
 from torch import nn
 from transformers import LlamaConfig

+from sglang.srt.distributed import get_pp_group
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM
...
@@ -86,6 +87,7 @@ class LlamaModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
...
@@ -118,6 +120,7 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
         nn.Module.__init__(self)
         self.config = config
         self.quant_config = quant_config
+        self.pp_group = get_pp_group()
         self.model = LlamaModel(
             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
         )
...
python/sglang/srt/models/llama_eagle3.py
@@ -25,6 +25,7 @@ import torch
 from torch import nn
 from transformers import LlamaConfig

+from sglang.srt.distributed import get_pp_group
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
...
@@ -33,7 +34,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM
...
@@ -118,6 +119,7 @@ class LlamaModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
         if input_embeds is None:
             embeds = self.embed_tokens(input_ids)
...
@@ -155,6 +157,7 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
         nn.Module.__init__(self)
         self.config = config
         self.quant_config = quant_config
+        self.pp_group = get_pp_group()

         if self.config.num_hidden_layers != 1:
             raise ValueError("EAGLE3 currently only supports 1 layer")
...
python/sglang/srt/server_args.py
@@ -78,6 +78,8 @@ class ServerArgs:
     # Other runtime options
     tp_size: int = 1
+    pp_size: int = 1
+    max_micro_batch_size: Optional[int] = None
     stream_interval: int = 1
     stream_output: bool = False
     random_seed: Optional[int] = None
...
@@ -222,14 +224,18 @@ class ServerArgs:
         # Set mem fraction static, which depends on the tensor parallelism size
         if self.mem_fraction_static is None:
-            if self.tp_size >= 16:
-                self.mem_fraction_static = 0.79
-            elif self.tp_size >= 8:
-                self.mem_fraction_static = 0.81
-            elif self.tp_size >= 4:
-                self.mem_fraction_static = 0.85
-            elif self.tp_size >= 2:
-                self.mem_fraction_static = 0.87
+            parallel_size = self.tp_size * self.pp_size
+            if gpu_mem <= 81920:
+                if parallel_size >= 16:
+                    self.mem_fraction_static = 0.79
+                elif parallel_size >= 8:
+                    self.mem_fraction_static = 0.81
+                elif parallel_size >= 4:
+                    self.mem_fraction_static = 0.85
+                elif parallel_size >= 2:
+                    self.mem_fraction_static = 0.87
+                else:
+                    self.mem_fraction_static = 0.88
             else:
                 self.mem_fraction_static = 0.88
             if gpu_mem > 96 * 1024:
...
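For example, tp_size=4 with pp_size=2 gives parallel_size = 8, so the default mem_fraction_static becomes 0.81 rather than the 0.85 that a tp_size-only check would have picked for tp_size=4.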
@@ -244,6 +250,8 @@ class ServerArgs:
         if self.chunked_prefill_size is None:
             if gpu_mem is not None and gpu_mem < 25_000:
                 self.chunked_prefill_size = 2048
+            elif self.disaggregation_mode != "null":
+                self.chunked_prefill_size = 16384
             else:
                 self.chunked_prefill_size = 8192
         assert self.chunked_prefill_size % self.page_size == 0
...
@@ -643,6 +651,19 @@ class ServerArgs:
             default=ServerArgs.tp_size,
             help="The tensor parallelism size.",
         )
+        parser.add_argument(
+            "--pipeline-parallel-size",
+            "--pp-size",
+            type=int,
+            default=ServerArgs.pp_size,
+            help="The pipeline parallelism size.",
+        )
+        parser.add_argument(
+            "--max-micro-batch-size",
+            type=int,
+            default=ServerArgs.max_micro_batch_size,
+            help="The maximum micro batch size in pipeline parallelism.",
+        )
         parser.add_argument(
             "--stream-interval",
             type=int,
...
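As a usage sketch only: the new flags can be exercised either through the CLI (--pipeline-parallel-size / --pp-size, --max-micro-batch-size) or programmatically, assuming the offline Engine entry point forwards these keyword arguments into ServerArgs as it does for the existing ones. The model path below is a placeholder.

import sglang as sgl

# Assumed to be equivalent to:
#   python -m sglang.launch_server --model-path <model> --tp-size 2 --pipeline-parallel-size 2
llm = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    tp_size=2,
    pp_size=2,  # maps to --pipeline-parallel-size / --pp-size above
)
print(llm.generate("The capital of France is", {"max_new_tokens": 8}))
llm.shutdown()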
@@ -1232,6 +1253,7 @@ class ServerArgs:
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
+        args.pp_size = args.pipeline_parallel_size
         args.dp_size = args.data_parallel_size
         args.ep_size = args.expert_parallel_size
         attrs = [attr.name for attr in dataclasses.fields(cls)]
...
@@ -1245,8 +1267,19 @@ class ServerArgs:
     def check_server_args(self):
         assert (
-            self.tp_size % self.nnodes == 0
-        ), "tp_size must be divisible by number of nodes"
+            self.tp_size * self.pp_size
+        ) % self.nnodes == 0, "tp_size must be divisible by number of nodes"
+
+        # FIXME pp constraints
+        if self.pp_size > 1:
+            logger.warning(f"Turn off overlap schedule for pipeline parallelism.")
+            self.disable_overlap_schedule = True
+            assert (
+                self.disable_overlap_schedule
+                and self.speculative_algorithm is None
+                and not self.enable_mixed_chunk
+            ), "Pipeline parallelism is not compatible with overlap schedule, speculative decoding, mixed chunked prefill."
+
         assert not (
             self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
         ), "multi-node data parallel is not supported unless dp attention!"
...
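The relaxed check means the product tp_size * pp_size, rather than tp_size alone, must divide evenly across nodes: for instance tp_size=2, pp_size=2 on nnodes=4 now passes since (2 * 2) % 4 == 0, whereas the old tp_size-only check (2 % 4) would have rejected it.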
python/sglang/srt/speculative/eagle_worker.py
@@ -106,11 +106,12 @@ class EAGLEWorker(TpModelWorker):
         # Init draft worker
         with empty_context():
             super().__init__(
+                server_args=server_args,
                 gpu_id=gpu_id,
                 tp_rank=tp_rank,
-                server_args=server_args,
-                nccl_port=nccl_port,
+                pp_rank=0,  # FIXME
                 dp_rank=dp_rank,
+                nccl_port=nccl_port,
                 is_draft_worker=True,
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
...
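Note that the draft worker hard-codes pp_rank=0 (marked FIXME): speculative decoding is not yet wired up for pipeline parallelism, which is consistent with the speculative_algorithm is None assertion added to check_server_args above.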