Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
b0524c37
Unverified
Commit
b0524c37
authored
Dec 31, 2024
by
Lianmin Zheng
Committed by
GitHub
Dec 31, 2024
Browse files
Eagle speculative decoding part 2: Fix cuda graph + DP attention hanging (#2684)
Co-authored-by:
yukavio
<
kavioyu@gmail.com
>
parent
6c42fa22
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
131 additions
and
58 deletions
+131
-58
.github/workflows/pr-test.yml
.github/workflows/pr-test.yml
+1
-1
python/sglang/srt/layers/logits_processor.py
python/sglang/srt/layers/logits_processor.py
+4
-1
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+23
-2
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+7
-1
python/sglang/srt/model_executor/cuda_graph_runner.py
python/sglang/srt/model_executor/cuda_graph_runner.py
+1
-3
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+11
-5
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+84
-45
No files found.
.github/workflows/pr-test.yml
View file @
b0524c37
...
...
@@ -92,7 +92,7 @@ jobs:
python3 test_data_parallelism.py
-
name
:
Evaluate MLA accuracy (TP=2)
timeout-minutes
:
2
0
timeout-minutes
:
1
0
run
:
|
cd test/srt
python3 test_mla.py
...
...
python/sglang/srt/layers/logits_processor.py
View file @
b0524c37
...
...
@@ -146,7 +146,10 @@ class LogitsProcessor(nn.Module):
# Compute logits
last_logits
=
self
.
_get_logits
(
last_hidden
,
lm_head
)
if
not
logits_metadata
.
extend_return_logprob
:
if
(
not
logits_metadata
.
extend_return_logprob
or
logits_metadata
.
capture_hidden_mode
.
need_capture
()
):
# Decode mode or extend mode without return_logprob.
return
LogitsProcessorOutput
(
next_token_logits
=
last_logits
,
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
b0524c37
from
__future__
import
annotations
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -29,7 +31,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
import
dataclasses
import
logging
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
numpy
as
np
import
torch
...
...
@@ -47,6 +49,10 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.server_args
import
ServerArgs
if
TYPE_CHECKING
:
from
sglang.srt.speculative.spec_info
import
SpecInfo
,
SpeculativeAlgorithm
INIT_INCREMENTAL_DETOKENIZATION_OFFSET
=
5
# Put some global args for easy access
...
...
@@ -565,9 +571,13 @@ class ScheduleBatch:
# Has grammar
has_grammar
:
bool
=
False
#
d
evice
#
D
evice
device
:
str
=
"cuda"
# Speculative decoding
spec_info
:
Optional
[
SpecInfo
]
=
None
spec_algorithm
:
Optional
[
SpeculativeAlgorithm
]
=
None
@
classmethod
def
init_new
(
cls
,
...
...
@@ -577,6 +587,7 @@ class ScheduleBatch:
tree_cache
:
BasePrefixCache
,
model_config
:
ModelConfig
,
enable_overlap
:
bool
,
speculative_algorithm
:
Optional
[
SpeculativeAlgorithm
]
=
None
,
):
return
cls
(
reqs
=
reqs
,
...
...
@@ -589,6 +600,7 @@ class ScheduleBatch:
has_stream
=
any
(
req
.
stream
for
req
in
reqs
),
has_grammar
=
any
(
req
.
grammar
for
req
in
reqs
),
device
=
req_to_token_pool
.
device
,
spec_algorithm
=
speculative_algorithm
,
)
def
batch_size
(
self
):
...
...
@@ -1103,6 +1115,9 @@ class ScheduleBatch:
self
.
has_stream
|=
other
.
has_stream
self
.
has_grammar
|=
other
.
has_grammar
if
self
.
spec_info
:
self
.
spec_info
.
merge_batch
(
other
.
spec_info
)
def
get_model_worker_batch
(
self
):
if
self
.
forward_mode
.
is_decode
()
or
self
.
forward_mode
.
is_idle
():
extend_seq_lens
=
extend_prefix_lens
=
extend_logprob_start_lens
=
None
...
...
@@ -1144,6 +1159,8 @@ class ScheduleBatch:
lora_paths
=
[
req
.
lora_path
for
req
in
self
.
reqs
],
sampling_info
=
self
.
sampling_info
,
input_embeds
=
self
.
input_embeds
,
spec_algorithm
=
self
.
spec_algorithm
,
spec_info
=
self
.
spec_info
,
)
def
copy
(
self
):
...
...
@@ -1214,6 +1231,10 @@ class ModelWorkerBatch:
# The input Embeds
input_embeds
:
Optional
[
torch
.
tensor
]
=
None
# Speculative decoding
spec_info
:
Optional
[
SpecInfo
]
=
None
spec_algorithm
:
Optional
[
SpeculativeAlgorithm
]
=
None
@
triton
.
jit
def
write_req_to_token_pool_triton
(
...
...
python/sglang/srt/managers/tp_worker.py
View file @
b0524c37
...
...
@@ -150,12 +150,18 @@ class TpModelWorker:
self
,
model_worker_batch
:
ModelWorkerBatch
,
launch_done
:
Optional
[
threading
.
Event
]
=
None
,
skip_sample
:
bool
=
False
,
):
forward_batch
=
ForwardBatch
.
init_new
(
model_worker_batch
,
self
.
model_runner
)
logits_output
=
self
.
model_runner
.
forward
(
forward_batch
)
if
launch_done
:
launch_done
.
set
()
next_token_ids
=
self
.
model_runner
.
sample
(
logits_output
,
model_worker_batch
)
if
skip_sample
:
next_token_ids
=
None
else
:
next_token_ids
=
self
.
model_runner
.
sample
(
logits_output
,
model_worker_batch
)
return
logits_output
,
next_token_ids
def
forward_batch_embedding
(
self
,
model_worker_batch
:
ModelWorkerBatch
):
...
...
python/sglang/srt/model_executor/cuda_graph_runner.py
View file @
b0524c37
...
...
@@ -375,9 +375,7 @@ class CudaGraphRunner:
def
replay
(
self
,
forward_batch
:
ForwardBatch
):
assert
forward_batch
.
out_cache_loc
is
not
None
raw_bs
=
forward_batch
.
batch_size
# In normal decoding case, raw_bs == raw_num_token
# But in speculative decoding, raw_num_token is raw_bs * self.num_tokens_per_bs
raw_num_token
=
forward_batch
.
input_ids
.
numel
()
raw_num_token
=
raw_bs
*
self
.
num_tokens_per_bs
# Pad
if
self
.
enable_dp_attention
:
...
...
python/sglang/srt/model_executor/forward_batch_info.py
View file @
b0524c37
...
...
@@ -96,7 +96,11 @@ class ForwardMode(IntEnum):
return
self
==
ForwardMode
.
DRAFT_EXTEND
def
is_cuda_graph
(
self
):
return
self
==
ForwardMode
.
DECODE
or
self
==
ForwardMode
.
TARGET_VERIFY
return
(
self
==
ForwardMode
.
DECODE
or
self
==
ForwardMode
.
TARGET_VERIFY
or
self
==
ForwardMode
.
IDLE
)
def
is_dummy_first
(
self
):
return
self
==
ForwardMode
.
DUMMY_FIRST
...
...
@@ -161,15 +165,15 @@ class ForwardBatch:
token_to_kv_pool
:
BaseTokenToKVPool
=
None
attn_backend
:
AttentionBackend
=
None
# Speculative decoding
spec_info
:
SpecInfo
=
None
spec_algorithm
:
SpeculativeAlgorithm
=
None
# For DP attention
global_num_tokens
:
Optional
[
List
[
int
]]
=
None
gathered_buffer
:
Optional
[
torch
.
Tensor
]
=
None
can_run_dp_cuda_graph
:
bool
=
False
# Speculative decoding
spec_info
:
SpecInfo
=
None
spec_algorithm
:
SpeculativeAlgorithm
=
None
# For Qwen2-VL
mrope_positions
:
torch
.
Tensor
=
None
...
...
@@ -258,6 +262,8 @@ class ForwardBatch:
can_run_dp_cuda_graph
=
batch
.
can_run_dp_cuda_graph
,
lora_paths
=
batch
.
lora_paths
,
sampling_info
=
batch
.
sampling_info
,
spec_algorithm
=
batch
.
spec_algorithm
,
spec_info
=
batch
.
spec_info
,
input_embeds
=
batch
.
input_embeds
,
)
...
...
python/sglang/srt/server_args.py
View file @
b0524c37
...
...
@@ -108,14 +108,6 @@ class ServerArgs:
# Model override args in JSON
json_model_override_args
:
str
=
"{}"
# Double Sparsity
enable_double_sparsity
:
bool
=
False
ds_channel_config_path
:
str
=
None
ds_heavy_channel_num
:
int
=
32
ds_heavy_token_num
:
int
=
256
ds_heavy_channel_type
:
str
=
"qk"
ds_sparse_decode_threshold
:
int
=
4096
# LoRA
lora_paths
:
Optional
[
List
[
str
]]
=
None
max_loras_per_batch
:
int
=
8
...
...
@@ -125,6 +117,21 @@ class ServerArgs:
sampling_backend
:
Optional
[
str
]
=
None
grammar_backend
:
Optional
[
str
]
=
"outlines"
# Speculative decoding
speculative_draft_model_path
:
Optional
[
str
]
=
None
speculative_algorithm
:
Optional
[
str
]
=
None
speculative_num_steps
:
int
=
5
speculative_num_draft_tokens
:
int
=
64
speculative_eagle_topk
:
int
=
8
# Double Sparsity
enable_double_sparsity
:
bool
=
False
ds_channel_config_path
:
str
=
None
ds_heavy_channel_num
:
int
=
32
ds_heavy_token_num
:
int
=
256
ds_heavy_channel_type
:
str
=
"qk"
ds_sparse_decode_threshold
:
int
=
4096
# Optimization/debug options
disable_radix_cache
:
bool
=
False
disable_jump_forward
:
bool
=
False
...
...
@@ -602,43 +609,6 @@ class ServerArgs:
default
=
ServerArgs
.
json_model_override_args
,
)
# Double Sparsity
parser
.
add_argument
(
"--enable-double-sparsity"
,
action
=
"store_true"
,
help
=
"Enable double sparsity attention"
,
)
parser
.
add_argument
(
"--ds-channel-config-path"
,
type
=
str
,
default
=
ServerArgs
.
ds_channel_config_path
,
help
=
"The path of the double sparsity channel config"
,
)
parser
.
add_argument
(
"--ds-heavy-channel-num"
,
type
=
int
,
default
=
ServerArgs
.
ds_heavy_channel_num
,
help
=
"The number of heavy channels in double sparsity attention"
,
)
parser
.
add_argument
(
"--ds-heavy-token-num"
,
type
=
int
,
default
=
ServerArgs
.
ds_heavy_token_num
,
help
=
"The number of heavy tokens in double sparsity attention"
,
)
parser
.
add_argument
(
"--ds-heavy-channel-type"
,
type
=
str
,
default
=
ServerArgs
.
ds_heavy_channel_type
,
help
=
"The type of heavy channels in double sparsity attention"
,
)
parser
.
add_argument
(
"--ds-sparse-decode-threshold"
,
type
=
int
,
default
=
ServerArgs
.
ds_sparse_decode_threshold
,
help
=
"The type of heavy channels in double sparsity attention"
,
)
# LoRA
parser
.
add_argument
(
"--lora-paths"
,
...
...
@@ -678,6 +648,75 @@ class ServerArgs:
help
=
"Choose the backend for grammar-guided decoding."
,
)
# Speculative decoding
parser
.
add_argument
(
"--speculative-algorithm"
,
type
=
str
,
choices
=
[
"EAGLE"
],
help
=
"Speculative algorithm."
,
)
parser
.
add_argument
(
"--speculative-draft-model-path"
,
type
=
str
,
help
=
"The path of the draft model weights. This can be a local folder or a Hugging Face repo ID."
,
)
parser
.
add_argument
(
"--speculative-num-steps"
,
type
=
int
,
help
=
"The number of steps sampled from draft model in Speculative Decoding."
,
default
=
ServerArgs
.
speculative_num_steps
,
)
parser
.
add_argument
(
"--speculative-num-draft-tokens"
,
type
=
int
,
help
=
"The number of token sampled from draft model in Speculative Decoding."
,
default
=
ServerArgs
.
speculative_num_draft_tokens
,
)
parser
.
add_argument
(
"--speculative-eagle-topk"
,
type
=
int
,
help
=
"The number of token sampled from draft model in eagle2 each step."
,
choices
=
[
1
,
2
,
4
,
8
],
default
=
ServerArgs
.
speculative_eagle_topk
,
)
# Double Sparsity
parser
.
add_argument
(
"--enable-double-sparsity"
,
action
=
"store_true"
,
help
=
"Enable double sparsity attention"
,
)
parser
.
add_argument
(
"--ds-channel-config-path"
,
type
=
str
,
default
=
ServerArgs
.
ds_channel_config_path
,
help
=
"The path of the double sparsity channel config"
,
)
parser
.
add_argument
(
"--ds-heavy-channel-num"
,
type
=
int
,
default
=
ServerArgs
.
ds_heavy_channel_num
,
help
=
"The number of heavy channels in double sparsity attention"
,
)
parser
.
add_argument
(
"--ds-heavy-token-num"
,
type
=
int
,
default
=
ServerArgs
.
ds_heavy_token_num
,
help
=
"The number of heavy tokens in double sparsity attention"
,
)
parser
.
add_argument
(
"--ds-heavy-channel-type"
,
type
=
str
,
default
=
ServerArgs
.
ds_heavy_channel_type
,
help
=
"The type of heavy channels in double sparsity attention"
,
)
parser
.
add_argument
(
"--ds-sparse-decode-threshold"
,
type
=
int
,
default
=
ServerArgs
.
ds_sparse_decode_threshold
,
help
=
"The type of heavy channels in double sparsity attention"
,
)
# Optimization/debug options
parser
.
add_argument
(
"--disable-radix-cache"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment