sglang / Commits / ad20b795

Eagle speculative decoding part 3: small modifications to the general scheduler (#2709)

Commit ad20b795 (unverified) · authored Jan 02, 2025 by Lianmin Zheng, committed by GitHub on Jan 02, 2025
Co-authored-by: kavioyu <kavioyu@tencent.com>
Parent: 9183c23e
Showing 13 changed files with 224 additions and 69 deletions (+224 −69):

benchmark/deepseek_v3/README.md (+2 −2)
python/sglang/bench_one_batch.py (+2 −0)
python/sglang/srt/layers/attention/__init__.py (+1 −1)
python/sglang/srt/layers/attention/flashinfer_backend.py (+5 −3)
python/sglang/srt/layers/attention/triton_backend.py (+1 −1)
python/sglang/srt/managers/schedule_batch.py (+7 −4)
python/sglang/srt/managers/scheduler.py (+48 −6)
python/sglang/srt/managers/tp_worker.py (+7 −1)
python/sglang/srt/model_executor/cuda_graph_runner.py (+65 −21)
python/sglang/srt/model_executor/forward_batch_info.py (+22 −3)
python/sglang/srt/model_executor/model_runner.py (+47 −27)
python/sglang/srt/server_args.py (+12 −0)
python/sglang/srt/speculative/spec_info.py (+5 −0)
benchmark/deepseek_v3/README.md

@@ -61,10 +61,10 @@ For example, there are two H20 nodes, each with 8 GPUs. The first node's IP is `10.0.0.1`
 ```bash
 # node 1
-python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --nccl-init 10.0.0.1:5000 --nnodes 2 --node-rank 0 --trust-remote-code
+python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 0 --trust-remote-code

 # node 2
-python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --nccl-init 10.0.0.1:5000 --nnodes 2 --node-rank 1 --trust-remote-code
+python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 1 --trust-remote-code
 ```
 If you have two H100 nodes, the usage is similar to the aforementioned H20.
python/sglang/bench_one_batch.py

@@ -63,6 +63,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server import _set_envs_and_config
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers

@@ -214,6 +215,7 @@ def extend(reqs, model_runner):
         tree_cache=None,
         model_config=model_runner.model_config,
         enable_overlap=False,
+        spec_algorithm=SpeculativeAlgorithm.NONE,
     )
     batch.prepare_for_extend()
     model_worker_batch = batch.get_model_worker_batch()
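The benchmark change follows from the `ScheduleBatch.init_new` signature change further down, where `spec_algorithm` becomes a required argument. A minimal sketch of why every caller must now state its speculative mode (the `Batch` class here is a hypothetical stand-in, not sglang's real one):

```python
# Hypothetical stand-in classes illustrating the signature change: a field
# with no default forces callers like bench_one_batch.py to pass it.
from dataclasses import dataclass
from enum import IntEnum, auto

class SpeculativeAlgorithm(IntEnum):
    NONE = auto()
    EAGLE = auto()

@dataclass
class Batch:
    enable_overlap: bool
    spec_algorithm: SpeculativeAlgorithm  # required: every caller states its mode

Batch(enable_overlap=False, spec_algorithm=SpeculativeAlgorithm.NONE)  # ok
# Batch(enable_overlap=False)  # would raise TypeError: missing 'spec_algorithm'
```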
python/sglang/srt/layers/attention/__init__.py

@@ -26,7 +26,7 @@ class AttentionBackend(ABC):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
-        num_token: int,
+        num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -227,7 +227,7 @@ class FlashInferAttnBackend(AttentionBackend):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
-        num_token: int,
+        num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],

@@ -243,9 +243,11 @@ class FlashInferAttnBackend(AttentionBackend):
                     "NHD",
                     use_cuda_graph=True,
                     use_tensor_cores=self.decode_use_tensor_cores,
-                    paged_kv_indptr_buffer=self.kv_indptr[i][: num_token + 1],
+                    paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
                     paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                    paged_kv_last_page_len_buffer=self.kv_last_page_len[:num_token],
+                    paged_kv_last_page_len_buffer=self.kv_last_page_len[
+                        :num_tokens
+                    ],
                 )
             )

         seq_lens_sum = seq_lens.sum().item()
python/sglang/srt/layers/attention/triton_backend.py

@@ -81,7 +81,7 @@ class TritonAttnBackend(AttentionBackend):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
-        num_token: int,
+        num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
python/sglang/srt/managers/schedule_batch.py

@@ -575,8 +575,8 @@ class ScheduleBatch:
     device: str = "cuda"

     # Speculative decoding
-    spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None
+    spec_algorithm: Optional[SpeculativeAlgorithm] = None

     @classmethod
     def init_new(

@@ -587,7 +587,7 @@ class ScheduleBatch:
         tree_cache: BasePrefixCache,
         model_config: ModelConfig,
         enable_overlap: bool,
-        speculative_algorithm: Optional[SpeculativeAlgorithm] = None,
+        spec_algorithm: SpeculativeAlgorithm,
     ):
         return cls(
             reqs=reqs,

@@ -600,7 +600,7 @@ class ScheduleBatch:
             has_stream=any(req.stream for req in reqs),
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
-            spec_algorithm=speculative_algorithm,
+            spec_algorithm=spec_algorithm,
         )

     def batch_size(self):

@@ -1010,6 +1010,8 @@ class ScheduleBatch:
     def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
+        if self.spec_algorithm.is_eagle():
+            return

         self.input_ids = self.output_ids
         self.output_ids = None
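The early return in `prepare_for_decode` exists because under EAGLE the draft worker assembles decode inputs itself (several draft tokens per step), so the generic "last output becomes next input" shift must not run. A rough sketch of the guard, using a plain dict as a hypothetical stand-in for the batch object:

```python
# Sketch of the new early-return guard (illustrative stand-in, not the real
# ScheduleBatch). With EAGLE enabled, the speculative worker owns input
# preparation, so the generic shift below is skipped.

def prepare_for_decode(batch, eagle_enabled: bool):
    batch["forward_mode"] = "DECODE"
    if eagle_enabled:
        return  # the speculative worker prepares inputs itself
    # Generic path: the last step's outputs become this step's inputs.
    batch["input_ids"], batch["output_ids"] = batch["output_ids"], None

b = {"forward_mode": None, "input_ids": None, "output_ids": [42]}
prepare_for_decode(b, eagle_enabled=True)
assert b["input_ids"] is None                       # untouched under EAGLE
prepare_for_decode(b, eagle_enabled=False)
assert b["input_ids"] == [42] and b["output_ids"] is None  # normal shift
```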
@@ -1172,6 +1174,7 @@ class ScheduleBatch:
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
+            spec_algorithm=self.spec_algorithm,
         )

     def __str__(self):

@@ -1232,8 +1235,8 @@ class ModelWorkerBatch:
     input_embeds: Optional[torch.tensor] = None

     # Speculative decoding
-    spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None
+    spec_algorithm: Optional[SpeculativeAlgorithm] = None


 @triton.jit
python/sglang/srt/managers/scheduler.py

@@ -76,6 +76,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,

@@ -116,6 +117,14 @@ class Scheduler:
         self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
+        self.spec_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
+        self.decode_mem_cache_buf_multiplier = (
+            self.server_args.speculative_num_draft_tokens
+            if not self.spec_algorithm.is_none()
+            else 1
+        )

         # Init inter-process communication
         context = zmq.Context(2)
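The multiplier feeds the out-of-memory check used later in this file: a speculative decode step may append up to `speculative_num_draft_tokens` KV-cache entries per request instead of one, so the budget has to scale. A small self-contained sketch of the idea (the function here is illustrative, not sglang's real `check_decode_mem`):

```python
# Sketch: why the decode memory check needs a multiplier under speculative
# decoding. Names and numbers are illustrative.

def check_decode_mem(available_slots: int, batch_size: int, buf_multiplier: int = 1) -> bool:
    """Each decode step normally needs one new KV-cache slot per request.
    With EAGLE-style speculation, up to `buf_multiplier` draft tokens may be
    appended per request per step, so the budget scales accordingly."""
    return available_slots >= batch_size * buf_multiplier

# Normal decoding: 8 requests, 1 token each per step.
assert check_decode_mem(available_slots=10, batch_size=8)
# Speculative decoding with 4 draft tokens per step: 8 * 4 = 32 slots needed.
assert not check_decode_mem(available_slots=10, batch_size=8, buf_multiplier=4)
```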
@@ -199,6 +208,21 @@
             nccl_port=port_args.nccl_port,
         )

+        # Launch a worker for speculative decoding if needed
+        if self.spec_algorithm.is_eagle():
+            from sglang.srt.speculative.eagle_worker import EAGLEWorker
+
+            self.draft_worker = EAGLEWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
+        else:
+            self.draft_worker = None
+
         # Get token and memory info from the model worker
         (
             self.max_total_num_tokens,
@@ -855,6 +879,7 @@
             self.tree_cache,
             self.model_config,
             self.enable_overlap,
+            self.spec_algorithm,
         )
         new_batch.prepare_for_extend()

@@ -888,11 +913,15 @@
             return None

         # Check if decode out of memory
-        if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10):
+        if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
+            test_retract and batch.batch_size() > 10
+        ):
             old_ratio = self.new_token_ratio

             retracted_reqs, new_token_ratio = batch.retract_decode()
             self.new_token_ratio = new_token_ratio
+            if self.draft_worker:
+                self.draft_worker.finish_request(retracted_reqs)

             logger.info(
                 "Decode out of memory happened. "
@@ -926,11 +955,17 @@
         self.forward_ct += 1

         if self.is_generation:
-            model_worker_batch = batch.get_model_worker_batch()
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
-                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-                    model_worker_batch
-                )
+                if self.spec_algorithm.is_none():
+                    model_worker_batch = batch.get_model_worker_batch()
+                    logits_output, next_token_ids = (
+                        self.tp_worker.forward_batch_generation(model_worker_batch)
+                    )
+                else:
+                    logits_output, next_token_ids, model_worker_batch, spec_info = (
+                        self.draft_worker.forward_batch_speculative_generation(batch)
+                    )
+                    batch.spec_info = spec_info
             elif batch.forward_mode.is_idle():
                 model_worker_batch = batch.get_model_worker_batch()
                 self.tp_worker.forward_batch_idle(model_worker_batch)
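This is the central fork of the PR: the normal path builds a `ModelWorkerBatch` and runs one target-model step, while the speculative path hands the whole batch to the draft worker, which drafts, verifies, and returns possibly several accepted tokens plus `spec_info` for the next round. A rough sketch of that dispatch shape, with illustrative stand-ins for the real sglang workers:

```python
# Illustrative stand-ins, not the real sglang classes; return values mirror
# the tuple shapes in the diff above.

class TargetWorker:
    def forward_batch_generation(self, model_worker_batch):
        return "logits", [101]  # one new token per request

class DraftWorker:
    def forward_batch_speculative_generation(self, batch):
        # Draft + verify in one call; may accept several tokens per request,
        # and returns the worker batch it built plus speculative metadata.
        return "logits", [101, 102, 103], "worker_batch", {"accepted": 3}

class Batch:
    spec_info = None
    def get_model_worker_batch(self):
        return "worker_batch"

def run_batch(batch, spec_is_none):
    if spec_is_none:
        mwb = batch.get_model_worker_batch()
        return TargetWorker().forward_batch_generation(mwb)
    logits, next_ids, mwb, spec_info = (
        DraftWorker().forward_batch_speculative_generation(batch)
    )
    batch.spec_info = spec_info  # carried into the next decode step
    return logits, next_ids

assert run_batch(Batch(), spec_is_none=True)[1] == [101]
assert run_batch(Batch(), spec_is_none=False)[1] == [101, 102, 103]
```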
@@ -1077,7 +1112,10 @@
                     self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                     continue

-                req.output_ids.append(next_token_id)
+                if batch.spec_algorithm.is_none():
+                    # The speculative worker handles output_ids in speculative decoding
+                    req.output_ids.append(next_token_id)
+
                 req.check_finished()

                 if req.finished():

@@ -1252,6 +1290,9 @@
                 # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
                 or (not req.stream and len(req.output_ids) % 50 == 0)
             ):
+                if self.draft_worker and req.finished():
+                    self.draft_worker.finish_request(req)
+
                 rids.append(req.rid)
                 finished_reasons.append(
                     req.finished_reason.to_json() if req.finished_reason else None

@@ -1383,6 +1424,7 @@
             self.tree_cache,
             self.model_config,
             self.enable_overlap,
+            self.spec_algorithm,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
python/sglang/srt/managers/tp_worker.py

@@ -45,13 +45,18 @@ class TpModelWorker:
         tp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
+        is_draft_worker: bool = False,
     ):
         # Parse args
         self.tp_rank = tp_rank

         # Init model and tokenizer
         self.model_config = ModelConfig(
-            server_args.model_path,
+            (
+                server_args.model_path
+                if not is_draft_worker
+                else server_args.speculative_draft_model_path
+            ),
             trust_remote_code=server_args.trust_remote_code,
             revision=server_args.revision,
             context_length=server_args.context_length,

@@ -68,6 +73,7 @@ class TpModelWorker:
             tp_size=server_args.tp_size,
             nccl_port=nccl_port,
             server_args=server_args,
+            is_draft_worker=is_draft_worker,
         )
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
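The same `TpModelWorker` class now serves both roles: the target worker keeps the main checkpoint while a draft worker loads the (typically much smaller) EAGLE draft model. A minimal sketch of the path selection, as a standalone function with hypothetical paths:

```python
# Standalone sketch of the path-selection logic above (the real code lives
# inline in TpModelWorker.__init__). The model paths are hypothetical.

def pick_model_path(model_path: str, speculative_draft_model_path: str,
                    is_draft_worker: bool = False) -> str:
    # Draft workers load the small speculative model; target workers load
    # the main checkpoint.
    return speculative_draft_model_path if is_draft_worker else model_path

assert pick_model_path("org/DeepSeek-V3", "org/eagle-draft") == "org/DeepSeek-V3"
assert pick_model_path("org/DeepSeek-V3", "org/eagle-draft",
                       is_draft_worker=True) == "org/eagle-draft"
```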
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -33,7 +33,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
+from sglang.srt.utils import monkey_patch_vllm_all_gather

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner

@@ -106,11 +106,6 @@ def set_torch_compile_config():
     torch._dynamo.config.cache_size_limit = 1024


-@maybe_torch_compile(dynamic=True)
-def clamp_position(seq_lens):
-    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
-
-
 class CudaGraphRunner:
     """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
@@ -157,6 +152,17 @@ class CudaGraphRunner:
+        self.capture_forward_mode = ForwardMode.DECODE
+        self.num_tokens_per_bs = 1
+        if model_runner.spec_algorithm.is_eagle():
+            if self.model_runner.is_draft_worker:
+                self.num_tokens_per_bs = (
+                    self.model_runner.server_args.speculative_eagle_topk
+                )
+            else:
+                self.capture_forward_mode = ForwardMode.TARGET_VERIFY
+                self.num_tokens_per_bs = (
+                    self.model_runner.server_args.speculative_num_draft_tokens
+                )
+
         self.compile_bs = (
             [
                 bs
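With plain decoding, each batch element produces exactly one token per graph replay. Under EAGLE, the diff above makes the draft worker capture graphs sized for `speculative_eagle_topk` tokens per request, and the target worker capture `TARGET_VERIFY` graphs sized for `speculative_num_draft_tokens` per request. A sketch of that sizing arithmetic (the default values below are hypothetical):

```python
# Illustrative arithmetic mirroring num_tokens_per_bs; parameter values are
# made up for the example.

def graph_num_tokens(bs: int, mode: str, eagle_topk: int = 4,
                     num_draft_tokens: int = 8) -> int:
    tokens_per_bs = {
        "decode": 1,                  # ordinary decoding: one token per request
        "draft": eagle_topk,          # EAGLE draft worker: top-k branches per request
        "verify": num_draft_tokens,   # target model verifying the draft tokens
    }[mode]
    return bs * tokens_per_bs

assert graph_num_tokens(16, "decode") == 16
assert graph_num_tokens(16, "draft") == 64
assert graph_num_tokens(16, "verify") == 128
```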
@@ -192,6 +198,13 @@ class CudaGraphRunner:
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
         self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)

+        # Speculative inference
+        if model_runner.spec_algorithm.is_eagle():
+            self.hidden_states = torch.zeros(
+                (self.max_num_token, self.model_runner.model_config.hidden_size),
+                dtype=self.model_runner.dtype,
+            )
+
         if self.is_encoder_decoder:
             # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
             self.encoder_lens = torch.full(

@@ -234,9 +247,6 @@ class CudaGraphRunner:
             self.model_runner.model.capture_mode = False

     def can_run(self, forward_batch: ForwardBatch):
-        if not forward_batch.forward_mode.is_cuda_graph():
-            return False
-
         if self.enable_dp_attention:
             min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
                 forward_batch.global_num_tokens

@@ -291,21 +301,18 @@ class CudaGraphRunner:
     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
-        num_token = bs * self.num_tokens_per_bs
+        num_tokens = bs * self.num_tokens_per_bs

         # Common inputs
-        input_ids = self.input_ids[:num_token]
+        input_ids = self.input_ids[:num_tokens]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
-        out_cache_loc = self.out_cache_loc[:num_token]
-        positions = self.positions[:num_token]
+        out_cache_loc = self.out_cache_loc[:num_tokens]
+        positions = self.positions[:num_tokens]
         if self.is_encoder_decoder:
             encoder_lens = self.encoder_lens[:bs]
         else:
             encoder_lens = None
-
-        seq_lens_sum = seq_lens.sum().item()
         mrope_positions = self.mrope_positions[:, :bs]

         if self.enable_dp_attention:

@@ -325,20 +332,22 @@ class CudaGraphRunner:
                 token_to_kv_pool=self.model_runner.token_to_kv_pool,
                 attn_backend=self.model_runner.attn_backend,
                 out_cache_loc=out_cache_loc,
-                seq_lens_sum=seq_lens_sum,
+                seq_lens_sum=seq_lens.sum(),
                 encoder_lens=encoder_lens,
                 return_logprob=False,
-                top_logprobs_nums=[0] * num_token,
+                top_logprobs_nums=[0] * bs,
                 positions=positions,
                 global_num_tokens=global_num_tokens,
                 mrope_positions=mrope_positions,
                 gathered_buffer=gathered_buffer,
+                spec_algorithm=self.model_runner.spec_algorithm,
+                spec_info=self.get_spec_info(num_tokens, positions),
             )

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
-            num_token,
+            num_tokens,
             req_pool_indices,
             seq_lens,
             encoder_lens,

@@ -394,14 +403,16 @@ class CudaGraphRunner:
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
-        positions = clamp_position(forward_batch.seq_lens)
-        self.positions[:raw_num_token].copy_(positions)
+        self.positions[:raw_num_token].copy_(forward_batch.positions)
+
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+        if hasattr(forward_batch.spec_info, "hidden_states"):
+            self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,

@@ -424,3 +435,36 @@ class CudaGraphRunner:
             ),
         )
         return logits_output
+
+    def get_spec_info(self, num_tokens: int, positions: torch.Tensor):
+        spec_info = None
+        if self.model_runner.spec_algorithm.is_eagle():
+            from sglang.srt.speculative.eagle_utils import (
+                EAGLEDraftInput,
+                EagleVerifyInput,
+            )
+
+            if self.model_runner.is_draft_worker:
+                spec_info = EAGLEDraftInput()
+                spec_info.hidden_states = self.hidden_states[:num_tokens]
+                spec_info.positions = positions
+                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+                spec_info.init(self.model_runner.server_args)
+            else:
+                spec_info = EagleVerifyInput(
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    self.model_runner.server_args.speculative_num_draft_tokens,
+                )
+                spec_info.custom_mask = torch.zeros(
+                    (num_tokens * self.model_runner.model_config.context_len),
+                    dtype=torch.bool,
+                    device="cuda",
+                )
+                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+
+        return spec_info
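The pattern in `get_spec_info` and the new `hidden_states` buffer is the standard CUDA-graph trick: allocate tensors once at the maximum size, hand each captured graph a slice of that storage, and have replay write into the same memory. A minimal, CPU-runnable sketch of the aliasing behavior this relies on (shapes are illustrative):

```python
# Sketch of the capture-time buffer pattern: a slice of a pre-allocated
# tensor aliases the parent storage, so in-place writes during replay are
# visible to the captured graph without any reallocation.
import torch

max_num_token, hidden_size = 256, 64
hidden_states = torch.zeros((max_num_token, hidden_size))

def get_capture_view(num_tokens: int) -> torch.Tensor:
    # The slice is a view, not a copy; it shares storage with hidden_states.
    return hidden_states[:num_tokens]

view = get_capture_view(32)
view.fill_(1.0)  # stands in for "copy this step's hidden states into place"
assert hidden_states[:32].eq(1.0).all() and hidden_states[32:].eq(0.0).all()
```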
python/sglang/srt/model_executor/forward_batch_info.py

@@ -38,6 +38,7 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
+from sglang.srt.utils import maybe_torch_compile

 if TYPE_CHECKING:
     from sglang.srt.layers.attention import AttentionBackend

@@ -276,10 +277,21 @@ class ForwardBatch:
         )

         if ret.forward_mode.is_idle():
+            ret.positions = torch.empty((0,), device=device)
             return ret

+        # Override the positions with spec_info
+        if (
+            ret.spec_info is not None
+            and getattr(ret.spec_info, "positions", None) is not None
+        ):
+            ret.positions = ret.spec_info.positions
+
         # Init position information
-        if not ret.forward_mode.is_decode():
+        if ret.forward_mode.is_decode():
+            if ret.positions is None:
+                ret.positions = clamp_position(batch.seq_lens)
+        else:
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
@@ -288,13 +300,15 @@ class ForwardBatch:
             ).to(device, non_blocking=True)
             if model_runner.server_args.attention_backend != "torch_native":
                 ret.extend_num_tokens = batch.extend_num_tokens
-                ret.positions, ret.extend_start_loc = compute_position_triton(
+                positions, ret.extend_start_loc = compute_position_triton(
                     ret.extend_prefix_lens,
                     ret.extend_seq_lens,
                     ret.extend_num_tokens,
                 )
             else:
-                ret.positions, ret.extend_start_loc = compute_position_torch(
+                positions, ret.extend_start_loc = compute_position_torch(
                     ret.extend_prefix_lens, ret.extend_seq_lens
                 )
+            if ret.positions is None:
+                ret.positions = positions
             ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
             ret.extend_seq_lens_cpu = batch.extend_seq_lens
             ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens

@@ -383,6 +397,11 @@ def compute_position_torch(
     return positions.to(torch.int64), extend_start_loc


+@maybe_torch_compile(dynamic=True)
+def clamp_position(seq_lens):
+    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
+
+
 class CaptureHiddenMode(IntEnum):
     NULL = auto()
     FULL = auto()
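`clamp_position` (moved here from cuda_graph_runner.py) maps each sequence length to the position index of the token currently being decoded, i.e. `len - 1`, clamped so that an empty sequence yields position 0 instead of -1. A quick, runnable check of the function exactly as it appears in the diff:

```python
import torch

def clamp_position(seq_lens):
    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)

# Lengths 1, 5, 0 decode at positions 0, 4, and (clamped) 0 respectively.
assert clamp_position(torch.tensor([1, 5, 0])).tolist() == [0, 4, 0]
```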
python/sglang/srt/model_executor/model_runner.py

@@ -49,6 +49,7 @@ from sglang.srt.mem_cache.memory_pool import (
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,

@@ -74,6 +75,7 @@ class ModelRunner:
         tp_size: int,
         nccl_port: int,
         server_args: ServerArgs,
+        is_draft_worker: bool = False,
     ):
         # Parse args
         self.model_config = model_config

@@ -84,8 +86,12 @@ class ModelRunner:
         self.tp_size = tp_size
         self.dist_port = nccl_port
         self.server_args = server_args
+        self.is_draft_worker = is_draft_worker
         self.is_generation = model_config.is_generation
         self.is_multimodal = model_config.is_multimodal
+        self.spec_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )

         # Model-specific adjustment
         if (

@@ -205,14 +211,18 @@ class ModelRunner:
         else:
             dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
-        init_distributed_environment(
-            backend=backend,
-            world_size=self.tp_size,
-            rank=self.tp_rank,
-            local_rank=self.gpu_id,
-            distributed_init_method=dist_init_method,
-        )
-        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+
+        if not self.is_draft_worker:
+            # Only initialize the distributed environment on the target model worker.
+            init_distributed_environment(
+                backend=backend,
+                world_size=self.tp_size,
+                rank=self.tp_rank,
+                local_rank=self.gpu_id,
+                distributed_init_method=dist_init_method,
+            )
+            initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
@@ -407,7 +417,6 @@ class ModelRunner:
         target_dtype = (
             dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
         )
-        current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype

         assert (
             self._model_update_group is not None

@@ -506,6 +515,28 @@ class ModelRunner:
         )
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

+        if max_num_reqs is None:
+            max_num_reqs = min(
+                max(
+                    int(
+                        self.max_total_num_tokens / self.model_config.context_len * 512
+                    ),
+                    2048,
+                ),
+                4096,
+            )
+
+        if not self.spec_algorithm.is_none():
+            if self.is_draft_worker:
+                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+            else:
+                self.server_args.draft_runner_cache_size = (
+                    self.max_total_num_tokens
+                    + max_num_reqs * self.server_args.speculative_num_steps
+                    + 100
+                )
+
         if max_total_tokens is not None:
             if max_total_tokens > self.max_total_num_tokens:
                 logging.warning(
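A worked example of the draft-runner cache sizing above, with hypothetical numbers: the target worker profiles its own KV budget first, then reserves a slightly larger cache for the draft worker so every in-flight request can hold `speculative_num_steps` extra draft tokens, plus a small slack term.

```python
# Hypothetical numbers plugged into the formula from the diff above.
max_total_num_tokens = 100_000   # target worker's profiled KV capacity
max_num_reqs = 2_048             # maximum concurrent requests
speculative_num_steps = 5        # draft steps per decode round

draft_runner_cache_size = (
    max_total_num_tokens + max_num_reqs * speculative_num_steps + 100
)
assert draft_runner_cache_size == 110_340
```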
@@ -520,17 +551,6 @@ class ModelRunner:
                 "Not enough memory. Please try to increase --mem-fraction-static."
             )

-        if max_num_reqs is None:
-            max_num_reqs = min(
-                max(
-                    int(
-                        self.max_total_num_tokens / self.model_config.context_len * 512
-                    ),
-                    2048,
-                ),
-                4096,
-            )
-
         self.req_to_token_pool = ReqToTokenPool(
             size=max_num_reqs + 1,
             max_context_len=self.model_config.context_len + 4,

@@ -650,10 +670,6 @@ class ModelRunner:
             tensor_parallel(self.model, device_mesh)

     def forward_decode(self, forward_batch: ForwardBatch):
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
-            return self.cuda_graph_runner.replay(forward_batch)
-
-        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
         self.attn_backend.init_forward_metadata(forward_batch)
         return self.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch

@@ -683,14 +699,18 @@ class ModelRunner:
         )

     def forward_idle(self, forward_batch: ForwardBatch):
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
-            return self.cuda_graph_runner.replay(forward_batch)
-
         return self.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )

     def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
+        if (
+            forward_batch.forward_mode.is_cuda_graph()
+            and self.cuda_graph_runner
+            and self.cuda_graph_runner.can_run(forward_batch)
+        ):
+            return self.cuda_graph_runner.replay(forward_batch)
+
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():
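This refactor centralizes the CUDA-graph check: `forward_decode` and `forward_idle` each used to test `can_run` themselves (and `can_run` tested the forward mode), whereas now `forward()` performs one gate before dispatching. A rough sketch of the resulting control flow, with stand-in types:

```python
# Illustrative stand-ins, not the real ModelRunner/CudaGraphRunner.

class GraphRunner:
    def can_run(self, batch):
        return False  # e.g. this batch shape was never captured
    def replay(self, batch):
        return "replayed"

class Runner:
    def __init__(self, cuda_graph_runner=None):
        self.cuda_graph_runner = cuda_graph_runner

    def forward(self, batch):
        # Single, centralized gate: the mode must be graph-capturable, a
        # graph runner must exist, and the captured graph must fit the batch.
        if (
            batch["mode_is_cuda_graph"]
            and self.cuda_graph_runner
            and self.cuda_graph_runner.can_run(batch)
        ):
            return self.cuda_graph_runner.replay(batch)
        return self.forward_eager(batch)

    def forward_eager(self, batch):
        return "eager"

assert Runner().forward({"mode_is_cuda_graph": True}) == "eager"
assert Runner(GraphRunner()).forward({"mode_is_cuda_graph": False}) == "eager"
```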
python/sglang/srt/server_args.py

@@ -23,6 +23,7 @@ from typing import List, Optional
 import torch

 from sglang.srt.hf_transformers_utils import check_gguf_file
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_hpu_memory_capacity,

@@ -247,6 +248,17 @@ class ServerArgs:
                 "Overlap scheduler is disabled."
             )

+        # Speculative Decoding
+        if self.speculative_algorithm == "EAGLE":
+            self.prefill_only_one_req = True
+            self.disable_cuda_graph_padding = True
+            self.disable_radix_cache = True
+            self.disable_overlap_schedule = True
+            self.chunked_prefill_size = -1
+            logger.info(
+                "The radix cache, chunked prefill, and overlap scheduler are disabled because EAGLE speculative decoding is used."
+            )
+
         # GGUF
         if (
             self.load_format == "auto" or self.load_format == "gguf"
python/sglang/srt/speculative/spec_info.py

@@ -2,8 +2,12 @@ from enum import IntEnum, auto
 class SpeculativeAlgorithm(IntEnum):
     NONE = auto()
     EAGLE = auto()

+    def is_none(self):
+        return self == SpeculativeAlgorithm.NONE
+
     def is_eagle(self):
         return self == SpeculativeAlgorithm.EAGLE

@@ -11,6 +15,7 @@ class SpeculativeAlgorithm(IntEnum):
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
+            None: SpeculativeAlgorithm.NONE,
         }
         return name_map[name]
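The `None` key means a server launched without any speculative flag resolves cleanly to `NONE`, so every call site can branch with `is_none()`/`is_eagle()` instead of checking for `None` first. A short usage sketch (assumes sglang at or after this commit is installed):

```python
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

# Flag not set at all: maps straight to NONE rather than raising or
# returning a bare None.
assert SpeculativeAlgorithm.from_string(None).is_none()

algo = SpeculativeAlgorithm.from_string("EAGLE")
assert algo.is_eagle() and not algo.is_none()
```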