sglang, commit b0524c37 (unverified; parent 6c42fa22)

Eagle speculative decoding part 2: Fix cuda graph + DP attention hanging (#2684)

Authored Dec 31, 2024 by Lianmin Zheng; committed via GitHub on Dec 31, 2024
Co-authored-by: yukavio <kavioyu@gmail.com>

Showing 7 changed files with 131 additions and 58 deletions:

  .github/workflows/pr-test.yml                            +1   -1
  python/sglang/srt/layers/logits_processor.py             +4   -1
  python/sglang/srt/managers/schedule_batch.py             +23  -2
  python/sglang/srt/managers/tp_worker.py                  +7   -1
  python/sglang/srt/model_executor/cuda_graph_runner.py    +1   -3
  python/sglang/srt/model_executor/forward_batch_info.py   +11  -5
  python/sglang/srt/server_args.py                         +84  -45
.github/workflows/pr-test.yml

@@ -92,7 +92,7 @@ jobs:
           python3 test_data_parallelism.py
 
       - name: Evaluate MLA accuracy (TP=2)
-        timeout-minutes: 20
+        timeout-minutes: 10
        run: |
          cd test/srt
          python3 test_mla.py
python/sglang/srt/layers/logits_processor.py

@@ -146,7 +146,10 @@ class LogitsProcessor(nn.Module):
         # Compute logits
         last_logits = self._get_logits(last_hidden, lm_head)
 
-        if not logits_metadata.extend_return_logprob:
+        if (
+            not logits_metadata.extend_return_logprob
+            or logits_metadata.capture_hidden_mode.need_capture()
+        ):
             # Decode mode or extend mode without return_logprob.
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
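The new `or` branch keys off `capture_hidden_mode.need_capture()`: when a speculative draft pass needs the hidden states, the processor takes the early-return path so they stay attached to the output instead of being discarded. The diff only shows the call site; below is a minimal sketch of what such a mode flag could look like, with the enum values assumed rather than taken from this commit.

# Sketch only: the enum values are assumptions inferred from how
# need_capture() is used above, not code from this commit.
from enum import IntEnum, auto

class CaptureHiddenMode(IntEnum):
    NULL = auto()  # nothing to capture: plain decoding
    FULL = auto()  # keep hidden states for every position
    LAST = auto()  # keep only the last position (enough to feed a draft model)

    def need_capture(self) -> bool:
        # Any non-NULL mode forces the early-return branch in the diff,
        # so hidden states survive for the draft step to read.
        return self != CaptureHiddenMode.NULL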
python/sglang/srt/managers/schedule_batch.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -29,7 +31,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 import dataclasses
 import logging
-from typing import List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch

@@ -47,6 +49,10 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
 
+if TYPE_CHECKING:
+    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
+
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
 # Put some global args for easy access

@@ -565,9 +571,13 @@ class ScheduleBatch:
     # Has grammar
    has_grammar: bool = False
 
-    # device
+    # Device
    device: str = "cuda"
 
+    # Speculative decoding
+    spec_info: Optional[SpecInfo] = None
+    spec_algorithm: Optional[SpeculativeAlgorithm] = None
+
     @classmethod
     def init_new(
         cls,

@@ -577,6 +587,7 @@ class ScheduleBatch:
         tree_cache: BasePrefixCache,
         model_config: ModelConfig,
         enable_overlap: bool,
+        speculative_algorithm: Optional[SpeculativeAlgorithm] = None,
     ):
         return cls(
             reqs=reqs,

@@ -589,6 +600,7 @@ class ScheduleBatch:
             has_stream=any(req.stream for req in reqs),
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
+            spec_algorithm=speculative_algorithm,
         )
 
     def batch_size(self):

@@ -1103,6 +1115,9 @@ class ScheduleBatch:
         self.has_stream |= other.has_stream
         self.has_grammar |= other.has_grammar
 
+        if self.spec_info:
+            self.spec_info.merge_batch(other.spec_info)
+
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode() or self.forward_mode.is_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None

@@ -1144,6 +1159,8 @@ class ScheduleBatch:
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
             input_embeds=self.input_embeds,
+            spec_algorithm=self.spec_algorithm,
+            spec_info=self.spec_info,
         )
 
     def copy(self):

@@ -1214,6 +1231,10 @@ class ModelWorkerBatch:
     # The input Embeds
     input_embeds: Optional[torch.tensor] = None
 
+    # Speculative decoding
+    spec_info: Optional[SpecInfo] = None
+    spec_algorithm: Optional[SpeculativeAlgorithm] = None
+
 
 @triton.jit
 def write_req_to_token_pool_triton(
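The import changes above are the standard recipe for type-only imports: `from __future__ import annotations` turns every annotation into a lazy string, so `SpecInfo` and `SpeculativeAlgorithm` only have to resolve for static type checkers, and no runtime import cycle opens up between the scheduler and the speculative-decoding module. A self-contained sketch of the idiom:

# Minimal sketch of the idiom used in schedule_batch.py above.
from __future__ import annotations  # annotations become lazy strings

import dataclasses
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Evaluated only by type checkers (mypy/pyright), never at runtime,
    # so it cannot introduce an import cycle.
    from sglang.srt.speculative.spec_info import SpecInfo

@dataclasses.dataclass
class BatchSketch:
    # At runtime this annotation stays the string "Optional[SpecInfo]".
    spec_info: Optional[SpecInfo] = None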
python/sglang/srt/managers/tp_worker.py

@@ -150,12 +150,18 @@ class TpModelWorker:
         self,
         model_worker_batch: ModelWorkerBatch,
         launch_done: Optional[threading.Event] = None,
+        skip_sample: bool = False,
     ):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
 
         if launch_done:
             launch_done.set()
-        next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+
+        if skip_sample:
+            next_token_ids = None
+        else:
+            next_token_ids = self.model_runner.sample(
+                logits_output, model_worker_batch
+            )
         return logits_output, next_token_ids
 
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
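`skip_sample` lets a caller obtain the target model's logits without committing to the worker's standard sampling step, which is what a speculative verify pass wants: it applies its own accept/reject rule to the draft tokens. A hypothetical caller, purely illustrative (`verify_fn` is not sglang API):

# Hypothetical usage sketch, not sglang code: verify_fn stands in for
# whatever accept/reject rule the speculative algorithm applies.
def run_verify_step(worker, model_worker_batch, verify_fn):
    logits_output, next_token_ids = worker.forward_batch_generation(
        model_worker_batch, skip_sample=True
    )
    assert next_token_ids is None  # sampling was deliberately skipped
    return verify_fn(logits_output, model_worker_batch)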
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -375,9 +375,7 @@ class CudaGraphRunner:
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
         raw_bs = forward_batch.batch_size
-        # In normal decoding case, raw_bs == raw_num_token
-        raw_num_token = raw_bs * self.num_tokens_per_bs
-        # But in speculative decoding, raw_num_token is raw_bs * self.num_tokens_per_bs
+        raw_num_token = forward_batch.input_ids.numel()
 
         # Pad
         if self.enable_dp_attention:
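Reading the token count off `input_ids` makes `replay` agnostic to the forward mode: in plain decode each request contributes one token, while a target-verify batch carries several draft tokens per request, so any fixed tokens-per-request formula fits only one of the two shapes. A small worked example with illustrative numbers (3 requests, 4 draft tokens):

import torch

raw_bs = 3  # requests in the batch (illustrative)

decode_input_ids = torch.zeros(raw_bs * 1, dtype=torch.long)  # 1 token/request
verify_input_ids = torch.zeros(raw_bs * 4, dtype=torch.long)  # 4 draft tokens/request

# numel() yields the right raw_num_token for either mode, without the
# runner having to know which mode it is replaying.
assert decode_input_ids.numel() == 3
assert verify_input_ids.numel() == 12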
python/sglang/srt/model_executor/forward_batch_info.py

@@ -96,7 +96,11 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.DRAFT_EXTEND
 
     def is_cuda_graph(self):
-        return self == ForwardMode.DECODE or self == ForwardMode.TARGET_VERIFY
+        return (
+            self == ForwardMode.DECODE
+            or self == ForwardMode.TARGET_VERIFY
+            or self == ForwardMode.IDLE
+        )
 
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST

@@ -161,15 +165,15 @@ class ForwardBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     attn_backend: AttentionBackend = None
 
-    # Speculative decoding
-    spec_info: SpecInfo = None
-    spec_algorithm: SpeculativeAlgorithm = None
-
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     gathered_buffer: Optional[torch.Tensor] = None
     can_run_dp_cuda_graph: bool = False
 
+    # Speculative decoding
+    spec_info: SpecInfo = None
+    spec_algorithm: SpeculativeAlgorithm = None
+
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None

@@ -258,6 +262,8 @@ class ForwardBatch:
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            spec_algorithm=batch.spec_algorithm,
+            spec_info=batch.spec_info,
             input_embeds=batch.input_embeds,
         )
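Letting IDLE batches take the cuda-graph path matters under DP attention: the graphs contain collective ops, so every rank has to walk the same replay path, and a rank that falls back to eager execution while its peers replay graphs leaves those collectives unmatched, which plausibly is the hang named in the commit title. A reduced sketch of the invariant (the real `ForwardMode` has more members than shown here):

# Reduced sketch: only the modes relevant to this diff, plus an illustrative
# helper; neither is sglang code.
from enum import IntEnum, auto

class ForwardMode(IntEnum):
    DECODE = auto()
    TARGET_VERIFY = auto()
    IDLE = auto()  # a DP rank with no requests this step

    def is_cuda_graph(self):
        return self in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY, ForwardMode.IDLE)

def ranks_agree_on_replay(modes):
    # All DP ranks must either replay a graph together or all run eagerly.
    return all(m.is_cuda_graph() for m in modes) or not any(
        m.is_cuda_graph() for m in modes
    )

assert ranks_agree_on_replay([ForwardMode.DECODE, ForwardMode.IDLE])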
python/sglang/srt/server_args.py

@@ -108,14 +108,6 @@ class ServerArgs:
     # Model override args in JSON
     json_model_override_args: str = "{}"
 
-    # Double Sparsity
-    enable_double_sparsity: bool = False
-    ds_channel_config_path: str = None
-    ds_heavy_channel_num: int = 32
-    ds_heavy_token_num: int = 256
-    ds_heavy_channel_type: str = "qk"
-    ds_sparse_decode_threshold: int = 4096
-
     # LoRA
     lora_paths: Optional[List[str]] = None
     max_loras_per_batch: int = 8

@@ -125,6 +117,21 @@ class ServerArgs:
     sampling_backend: Optional[str] = None
     grammar_backend: Optional[str] = "outlines"
 
+    # Speculative decoding
+    speculative_draft_model_path: Optional[str] = None
+    speculative_algorithm: Optional[str] = None
+    speculative_num_steps: int = 5
+    speculative_num_draft_tokens: int = 64
+    speculative_eagle_topk: int = 8
+
+    # Double Sparsity
+    enable_double_sparsity: bool = False
+    ds_channel_config_path: str = None
+    ds_heavy_channel_num: int = 32
+    ds_heavy_token_num: int = 256
+    ds_heavy_channel_type: str = "qk"
+    ds_sparse_decode_threshold: int = 4096
+
     # Optimization/debug options
     disable_radix_cache: bool = False
     disable_jump_forward: bool = False

@@ -602,43 +609,6 @@ class ServerArgs:
             default=ServerArgs.json_model_override_args,
         )
 
-        # Double Sparsity
-        parser.add_argument(
-            "--enable-double-sparsity",
-            action="store_true",
-            help="Enable double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-channel-config-path",
-            type=str,
-            default=ServerArgs.ds_channel_config_path,
-            help="The path of the double sparsity channel config",
-        )
-        parser.add_argument(
-            "--ds-heavy-channel-num",
-            type=int,
-            default=ServerArgs.ds_heavy_channel_num,
-            help="The number of heavy channels in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-heavy-token-num",
-            type=int,
-            default=ServerArgs.ds_heavy_token_num,
-            help="The number of heavy tokens in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-heavy-channel-type",
-            type=str,
-            default=ServerArgs.ds_heavy_channel_type,
-            help="The type of heavy channels in double sparsity attention",
-        )
-        parser.add_argument(
-            "--ds-sparse-decode-threshold",
-            type=int,
-            default=ServerArgs.ds_sparse_decode_threshold,
-            help="The type of heavy channels in double sparsity attention",
-        )
-
         # LoRA
         parser.add_argument(
             "--lora-paths",

@@ -678,6 +648,75 @@ class ServerArgs:
             help="Choose the backend for grammar-guided decoding.",
         )
 
+        # Speculative decoding
+        parser.add_argument(
+            "--speculative-algorithm",
+            type=str,
+            choices=["EAGLE"],
+            help="Speculative algorithm.",
+        )
+        parser.add_argument(
+            "--speculative-draft-model-path",
+            type=str,
+            help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
+        )
+        parser.add_argument(
+            "--speculative-num-steps",
+            type=int,
+            help="The number of steps sampled from draft model in Speculative Decoding.",
+            default=ServerArgs.speculative_num_steps,
+        )
+        parser.add_argument(
+            "--speculative-num-draft-tokens",
+            type=int,
+            help="The number of token sampled from draft model in Speculative Decoding.",
+            default=ServerArgs.speculative_num_draft_tokens,
+        )
+        parser.add_argument(
+            "--speculative-eagle-topk",
+            type=int,
+            help="The number of token sampled from draft model in eagle2 each step.",
+            choices=[1, 2, 4, 8],
+            default=ServerArgs.speculative_eagle_topk,
+        )
+
+        # Double Sparsity
+        parser.add_argument(
+            "--enable-double-sparsity",
+            action="store_true",
+            help="Enable double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-channel-config-path",
+            type=str,
+            default=ServerArgs.ds_channel_config_path,
+            help="The path of the double sparsity channel config",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-num",
+            type=int,
+            default=ServerArgs.ds_heavy_channel_num,
+            help="The number of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-token-num",
+            type=int,
+            default=ServerArgs.ds_heavy_token_num,
+            help="The number of heavy tokens in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-type",
+            type=str,
+            default=ServerArgs.ds_heavy_channel_type,
+            help="The type of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-sparse-decode-threshold",
+            type=int,
+            default=ServerArgs.ds_sparse_decode_threshold,
+            help="The type of heavy channels in double sparsity attention",
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--disable-radix-cache",
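The new flags are plain argparse options, so their behavior can be checked in isolation; the sketch below rebuilds just those five arguments with the same types, choices, and defaults as the diff (the argv values are illustrative placeholders):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--speculative-algorithm", type=str, choices=["EAGLE"])
parser.add_argument("--speculative-draft-model-path", type=str)
parser.add_argument("--speculative-num-steps", type=int, default=5)
parser.add_argument("--speculative-num-draft-tokens", type=int, default=64)
parser.add_argument(
    "--speculative-eagle-topk", type=int, choices=[1, 2, 4, 8], default=8
)

args = parser.parse_args(
    ["--speculative-algorithm", "EAGLE",
     "--speculative-draft-model-path", "placeholder/draft-model"]
)
assert args.speculative_num_steps == 5  # defaults match the dataclass fields
assert args.speculative_eagle_topk in (1, 2, 4, 8)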