Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
b8574f69
Unverified
Commit
b8574f69
authored
Jan 06, 2025
by
Lianmin Zheng
Committed by
GitHub
Jan 06, 2025
Browse files
Clean up eagle code (#2756)
parent
2855caa4
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
138 additions
and
128 deletions
+138
-128
python/sglang/srt/layers/logits_processor.py
python/sglang/srt/layers/logits_processor.py
+1
-6
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+7
-1
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+7
-4
python/sglang/srt/model_executor/cuda_graph_runner.py
python/sglang/srt/model_executor/cuda_graph_runner.py
+7
-2
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+17
-15
python/sglang/srt/speculative/eagle_utils.py
python/sglang/srt/speculative/eagle_utils.py
+54
-67
python/sglang/srt/speculative/eagle_worker.py
python/sglang/srt/speculative/eagle_worker.py
+45
-33
No files found.
python/sglang/srt/layers/logits_processor.py
View file @
b8574f69
...
@@ -74,11 +74,6 @@ class LogitsMetadata:
...
@@ -74,11 +74,6 @@ class LogitsMetadata:
@
classmethod
@
classmethod
def
from_forward_batch
(
cls
,
forward_batch
:
ForwardBatch
):
def
from_forward_batch
(
cls
,
forward_batch
:
ForwardBatch
):
if
forward_batch
.
spec_info
:
capture_hidden_mode
=
forward_batch
.
spec_info
.
capture_hidden_mode
else
:
capture_hidden_mode
=
CaptureHiddenMode
.
NULL
if
forward_batch
.
forward_mode
.
is_extend
()
and
forward_batch
.
return_logprob
:
if
forward_batch
.
forward_mode
.
is_extend
()
and
forward_batch
.
return_logprob
:
extend_return_logprob
=
True
extend_return_logprob
=
True
extend_return_top_logprob
=
any
(
extend_return_top_logprob
=
any
(
...
@@ -98,7 +93,7 @@ class LogitsMetadata:
...
@@ -98,7 +93,7 @@ class LogitsMetadata:
return
cls
(
return
cls
(
forward_mode
=
forward_batch
.
forward_mode
,
forward_mode
=
forward_batch
.
forward_mode
,
capture_hidden_mode
=
capture_hidden_mode
,
capture_hidden_mode
=
forward_batch
.
capture_hidden_mode
,
extend_return_logprob
=
extend_return_logprob
,
extend_return_logprob
=
extend_return_logprob
,
extend_return_top_logprob
=
extend_return_top_logprob
,
extend_return_top_logprob
=
extend_return_top_logprob
,
extend_seq_lens
=
forward_batch
.
extend_seq_lens
,
extend_seq_lens
=
forward_batch
.
extend_seq_lens
,
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
b8574f69
...
@@ -44,7 +44,7 @@ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
...
@@ -44,7 +44,7 @@ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
from
sglang.srt.mem_cache.memory_pool
import
BaseTokenToKVPool
,
ReqToTokenPool
from
sglang.srt.mem_cache.memory_pool
import
BaseTokenToKVPool
,
ReqToTokenPool
from
sglang.srt.model_executor.forward_batch_info
import
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
,
ForwardMode
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.server_args
import
ServerArgs
...
@@ -1163,6 +1163,11 @@ class ScheduleBatch:
...
@@ -1163,6 +1163,11 @@ class ScheduleBatch:
input_embeds
=
self
.
input_embeds
,
input_embeds
=
self
.
input_embeds
,
spec_algorithm
=
self
.
spec_algorithm
,
spec_algorithm
=
self
.
spec_algorithm
,
spec_info
=
self
.
spec_info
,
spec_info
=
self
.
spec_info
,
capture_hidden_mode
=
(
getattr
(
self
.
spec_info
,
"capture_hidden_mode"
,
CaptureHiddenMode
.
NULL
)
if
self
.
spec_info
else
CaptureHiddenMode
.
NULL
),
)
)
def
copy
(
self
):
def
copy
(
self
):
...
@@ -1237,6 +1242,7 @@ class ModelWorkerBatch:
...
@@ -1237,6 +1242,7 @@ class ModelWorkerBatch:
# Speculative decoding
# Speculative decoding
spec_algorithm
:
SpeculativeAlgorithm
=
None
spec_algorithm
:
SpeculativeAlgorithm
=
None
spec_info
:
Optional
[
SpecInfo
]
=
None
spec_info
:
Optional
[
SpecInfo
]
=
None
capture_hidden_mode
:
CaptureHiddenMode
=
None
@
triton
.
jit
@
triton
.
jit
...
...
python/sglang/srt/managers/scheduler.py
View file @
b8574f69
...
@@ -962,10 +962,13 @@ class Scheduler:
...
@@ -962,10 +962,13 @@ class Scheduler:
self
.
tp_worker
.
forward_batch_generation
(
model_worker_batch
)
self
.
tp_worker
.
forward_batch_generation
(
model_worker_batch
)
)
)
else
:
else
:
logits_output
,
next_token_ids
,
model_worker_batch
,
spec_info
=
(
(
self
.
draft_worker
.
forward_batch_speculative_generation
(
batch
)
logits_output
,
)
next_token_ids
,
batch
.
spec_info
=
spec_info
model_worker_batch
,
num_accepted_tokens
,
)
=
self
.
draft_worker
.
forward_batch_speculative_generation
(
batch
)
self
.
num_generated_tokens
+=
num_accepted_tokens
elif
batch
.
forward_mode
.
is_idle
():
elif
batch
.
forward_mode
.
is_idle
():
model_worker_batch
=
batch
.
get_model_worker_batch
()
model_worker_batch
=
batch
.
get_model_worker_batch
()
self
.
tp_worker
.
forward_batch_idle
(
model_worker_batch
)
self
.
tp_worker
.
forward_batch_idle
(
model_worker_batch
)
...
...
python/sglang/srt/model_executor/cuda_graph_runner.py
View file @
b8574f69
...
@@ -322,6 +322,8 @@ class CudaGraphRunner:
...
@@ -322,6 +322,8 @@ class CudaGraphRunner:
global_num_tokens
=
None
global_num_tokens
=
None
gathered_buffer
=
None
gathered_buffer
=
None
spec_info
=
self
.
get_spec_info
(
num_tokens
,
positions
)
forward_batch
=
ForwardBatch
(
forward_batch
=
ForwardBatch
(
forward_mode
=
self
.
capture_forward_mode
,
forward_mode
=
self
.
capture_forward_mode
,
batch_size
=
bs
,
batch_size
=
bs
,
...
@@ -341,7 +343,10 @@ class CudaGraphRunner:
...
@@ -341,7 +343,10 @@ class CudaGraphRunner:
mrope_positions
=
mrope_positions
,
mrope_positions
=
mrope_positions
,
gathered_buffer
=
gathered_buffer
,
gathered_buffer
=
gathered_buffer
,
spec_algorithm
=
self
.
model_runner
.
spec_algorithm
,
spec_algorithm
=
self
.
model_runner
.
spec_algorithm
,
spec_info
=
self
.
get_spec_info
(
num_tokens
,
positions
),
spec_info
=
spec_info
,
capture_hidden_mode
=
(
spec_info
.
capture_hidden_mode
if
spec_info
else
CaptureHiddenMode
.
NULL
),
)
)
# Attention backend
# Attention backend
...
@@ -446,10 +451,10 @@ class CudaGraphRunner:
...
@@ -446,10 +451,10 @@ class CudaGraphRunner:
if
self
.
model_runner
.
is_draft_worker
:
if
self
.
model_runner
.
is_draft_worker
:
spec_info
=
EAGLEDraftInput
()
spec_info
=
EAGLEDraftInput
()
spec_info
.
load_server_args
(
self
.
model_runner
.
server_args
)
spec_info
.
hidden_states
=
self
.
hidden_states
[:
num_tokens
]
spec_info
.
hidden_states
=
self
.
hidden_states
[:
num_tokens
]
spec_info
.
positions
=
positions
spec_info
.
positions
=
positions
spec_info
.
capture_hidden_mode
=
CaptureHiddenMode
.
FULL
spec_info
.
capture_hidden_mode
=
CaptureHiddenMode
.
FULL
spec_info
.
init
(
self
.
model_runner
.
server_args
)
else
:
else
:
spec_info
=
EagleVerifyInput
(
spec_info
=
EagleVerifyInput
(
None
,
None
,
...
...
python/sglang/srt/model_executor/forward_batch_info.py
View file @
b8574f69
...
@@ -107,6 +107,21 @@ class ForwardMode(IntEnum):
...
@@ -107,6 +107,21 @@ class ForwardMode(IntEnum):
return
self
==
ForwardMode
.
DUMMY_FIRST
return
self
==
ForwardMode
.
DUMMY_FIRST
class
CaptureHiddenMode
(
IntEnum
):
NULL
=
auto
()
FULL
=
auto
()
LAST
=
auto
()
def
need_capture
(
self
):
return
self
!=
CaptureHiddenMode
.
NULL
def
is_full
(
self
):
return
self
==
CaptureHiddenMode
.
FULL
def
is_last
(
self
):
return
self
==
CaptureHiddenMode
.
LAST
@
dataclass
@
dataclass
class
ForwardBatch
:
class
ForwardBatch
:
"""Store all inputs of a forward pass."""
"""Store all inputs of a forward pass."""
...
@@ -174,6 +189,7 @@ class ForwardBatch:
...
@@ -174,6 +189,7 @@ class ForwardBatch:
# Speculative decoding
# Speculative decoding
spec_info
:
SpecInfo
=
None
spec_info
:
SpecInfo
=
None
spec_algorithm
:
SpeculativeAlgorithm
=
None
spec_algorithm
:
SpeculativeAlgorithm
=
None
capture_hidden_mode
:
CaptureHiddenMode
=
None
# For Qwen2-VL
# For Qwen2-VL
mrope_positions
:
torch
.
Tensor
=
None
mrope_positions
:
torch
.
Tensor
=
None
...
@@ -265,6 +281,7 @@ class ForwardBatch:
...
@@ -265,6 +281,7 @@ class ForwardBatch:
sampling_info
=
batch
.
sampling_info
,
sampling_info
=
batch
.
sampling_info
,
spec_algorithm
=
batch
.
spec_algorithm
,
spec_algorithm
=
batch
.
spec_algorithm
,
spec_info
=
batch
.
spec_info
,
spec_info
=
batch
.
spec_info
,
capture_hidden_mode
=
batch
.
capture_hidden_mode
,
input_embeds
=
batch
.
input_embeds
,
input_embeds
=
batch
.
input_embeds
,
)
)
...
@@ -400,18 +417,3 @@ def compute_position_torch(
...
@@ -400,18 +417,3 @@ def compute_position_torch(
@
maybe_torch_compile
(
dynamic
=
True
)
@
maybe_torch_compile
(
dynamic
=
True
)
def
clamp_position
(
seq_lens
):
def
clamp_position
(
seq_lens
):
return
torch
.
clamp
((
seq_lens
-
1
),
min
=
0
).
to
(
torch
.
int64
)
return
torch
.
clamp
((
seq_lens
-
1
),
min
=
0
).
to
(
torch
.
int64
)
class
CaptureHiddenMode
(
IntEnum
):
NULL
=
auto
()
FULL
=
auto
()
LAST
=
auto
()
def
need_capture
(
self
):
return
self
!=
CaptureHiddenMode
.
NULL
def
is_full
(
self
):
return
self
==
CaptureHiddenMode
.
FULL
def
is_last
(
self
):
return
self
==
CaptureHiddenMode
.
LAST
python/sglang/srt/speculative/eagle_utils.py
View file @
b8574f69
...
@@ -9,12 +9,11 @@ import triton.language as tl
...
@@ -9,12 +9,11 @@ import triton.language as tl
from
sglang.srt.layers.attention.flashinfer_backend
import
(
from
sglang.srt.layers.attention.flashinfer_backend
import
(
create_flashinfer_kv_indices_triton
,
create_flashinfer_kv_indices_triton
,
)
)
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
ForwardMode
from
sglang.srt.speculative.build_eagle_tree
import
build_tree_kernel
from
sglang.srt.speculative.build_eagle_tree
import
build_tree_kernel
from
sglang.srt.speculative.spec_info
import
SpecInfo
from
sglang.srt.speculative.spec_info
import
SpecInfo
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
python.sglang.srt.layers.sampler
import
SampleOutput
from
python.sglang.srt.managers.schedule_batch
import
ScheduleBatch
from
python.sglang.srt.managers.schedule_batch
import
ScheduleBatch
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.server_args
import
ServerArgs
...
@@ -179,19 +178,9 @@ def generate_draft_decode_kv_indices(
...
@@ -179,19 +178,9 @@ def generate_draft_decode_kv_indices(
class
EAGLEDraftInput
(
SpecInfo
):
class
EAGLEDraftInput
(
SpecInfo
):
hidden_states
:
torch
.
Tensor
=
None
def
__init__
(
self
):
verified_id
:
torch
.
Tensor
=
None
positions
:
torch
.
Tensor
=
None
accept_length
:
torch
.
Tensor
=
None
has_finished
:
bool
=
False
unfinished_index
:
List
[
int
]
=
None
def
init
(
self
,
server_args
:
ServerArgs
):
self
.
prev_mode
=
ForwardMode
.
DECODE
self
.
prev_mode
=
ForwardMode
.
DECODE
self
.
sample_output
=
None
self
.
sample_output
=
None
self
.
topk
:
int
=
server_args
.
speculative_eagle_topk
self
.
num_verify_token
:
int
=
server_args
.
speculative_num_draft_tokens
self
.
spec_steps
=
server_args
.
speculative_num_steps
self
.
scores
:
torch
.
Tensor
=
None
self
.
scores
:
torch
.
Tensor
=
None
self
.
score_list
:
List
[
torch
.
Tensor
]
=
[]
self
.
score_list
:
List
[
torch
.
Tensor
]
=
[]
...
@@ -200,11 +189,20 @@ class EAGLEDraftInput(SpecInfo):
...
@@ -200,11 +189,20 @@ class EAGLEDraftInput(SpecInfo):
self
.
parents_list
:
List
[
torch
.
Tensor
]
=
[]
self
.
parents_list
:
List
[
torch
.
Tensor
]
=
[]
self
.
cache_list
:
List
[
torch
.
Tenor
]
=
[]
self
.
cache_list
:
List
[
torch
.
Tenor
]
=
[]
self
.
iter
=
0
self
.
iter
=
0
self
.
root_token
:
int
=
None
assert
self
.
topk
<=
10
,
"topk should <= 10"
self
.
hidden_states
:
torch
.
Tensor
=
None
self
.
verified_id
:
torch
.
Tensor
=
None
self
.
positions
:
torch
.
Tensor
=
None
self
.
accept_length
:
torch
.
Tensor
=
None
self
.
has_finished
:
bool
=
False
self
.
unfinished_index
:
List
[
int
]
=
None
def
load_server_args
(
self
,
server_args
:
ServerArgs
):
self
.
topk
:
int
=
server_args
.
speculative_eagle_topk
self
.
num_verify_token
:
int
=
server_args
.
speculative_num_draft_tokens
self
.
spec_steps
=
server_args
.
speculative_num_steps
def
prepare_for_extend
(
self
,
batch
:
Forward
Batch
):
def
prepare_for_extend
(
self
,
batch
:
Schedule
Batch
):
req_pool_indices
=
batch
.
alloc_req_slots
(
len
(
batch
.
reqs
))
req_pool_indices
=
batch
.
alloc_req_slots
(
len
(
batch
.
reqs
))
out_cache_loc
=
batch
.
alloc_token_slots
(
batch
.
input_ids
.
numel
())
out_cache_loc
=
batch
.
alloc_token_slots
(
batch
.
input_ids
.
numel
())
batch
.
out_cache_loc
=
out_cache_loc
batch
.
out_cache_loc
=
out_cache_loc
...
@@ -226,81 +224,72 @@ class EAGLEDraftInput(SpecInfo):
...
@@ -226,81 +224,72 @@ class EAGLEDraftInput(SpecInfo):
pt
+=
req
.
extend_input_len
pt
+=
req
.
extend_input_len
seq_lens
=
[
0
]
+
batch
.
extend_lens
# TODO: support batching inputs
input_ids
=
batch
.
input_ids
.
tolist
()
assert
len
(
batch
.
extend_lens
)
==
1
verified_id
=
batch
.
spec_info
.
verified_id
.
tolist
()
batch
.
input_ids
=
torch
.
concat
((
batch
.
input_ids
[
1
:],
self
.
verified_id
))
model_input_ids
=
[]
for
i
in
range
(
len
(
seq_lens
)
-
1
):
model_input_ids
.
extend
(
input_ids
[
seq_lens
[
i
]
+
1
:
seq_lens
[
i
+
1
]]
+
[
verified_id
[
i
]]
)
batch
.
input_ids
=
torch
.
tensor
(
model_input_ids
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
def
capture_for_decode
(
self
,
sample_output
:
SampleOutput
,
hidden_states
:
torch
.
Tensor
,
prev_mode
:
ForwardMode
,
):
self
.
sample_output
=
sample_output
self
.
prev_mode
=
prev_mode
self
.
hidden_states
=
hidden_states
def
prepare_for_decode
(
self
,
batch
:
ScheduleBatch
):
def
prepare_for_decode
(
self
,
batch
:
ScheduleBatch
):
prob
=
self
.
sample_output
#
b * (1/topk)
, vocab
prob
=
self
.
sample_output
#
shape: (b * top_k, vocab) or (b
, vocab
)
top
=
torch
.
topk
(
prob
,
self
.
topk
,
dim
=-
1
)
top
=
torch
.
topk
(
prob
,
self
.
topk
,
dim
=-
1
)
topk_index
,
topk_p
=
top
.
indices
,
top
.
values
# b * (1/topk), topk
topk_index
,
topk_p
=
(
if
self
.
prev_mode
==
ForwardMode
.
DECODE
:
top
.
indices
,
top
.
values
,
)
# shape: (b * top_k, top_k) or (b, top_k)
if
self
.
prev_mode
.
is_decode
():
scores
=
torch
.
mul
(
scores
=
torch
.
mul
(
self
.
scores
.
unsqueeze
(
2
),
topk_p
.
reshape
(
-
1
,
self
.
topk
,
self
.
topk
)
self
.
scores
.
unsqueeze
(
2
),
topk_p
.
reshape
(
-
1
,
self
.
topk
,
self
.
topk
)
)
# (b, topk
) mul
(b
*
topk ,topk) -> b, topk, topk
)
# (b, topk
, 1) x
(b
,
topk ,topk) ->
(
b, topk, topk
)
topk_cs
=
torch
.
topk
(
topk_cs
=
torch
.
topk
(
scores
.
flatten
(
start_dim
=
1
),
self
.
topk
,
dim
=-
1
scores
.
flatten
(
start_dim
=
1
),
self
.
topk
,
dim
=-
1
)
# (b, topk)
)
# (b, topk)
topk_cs_index
,
topk_cs_p
=
topk_cs
.
indices
,
topk_cs
.
values
topk_cs_index
,
topk_cs_p
=
topk_cs
.
indices
,
topk_cs
.
values
self
.
scores
=
topk_cs_p
selected_input_index
=
topk_cs_index
.
flatten
()
//
self
.
topk
# b* topk
selected_input_index
=
(
topk_cs_index
.
flatten
()
//
self
.
topk
)
# shape: (b * topk)
batch
.
spec_info
.
hidden_states
=
batch
.
spec_info
.
hidden_states
[
batch
.
spec_info
.
hidden_states
=
batch
.
spec_info
.
hidden_states
[
selected_input_index
,
:
selected_input_index
,
:
]
]
topk_index
=
topk_index
.
reshape
(
-
1
,
self
.
topk
**
2
)
topk_index
=
topk_index
.
reshape
(
-
1
,
self
.
topk
**
2
)
batch
.
input_ids
=
torch
.
gather
(
batch
.
input_ids
=
torch
.
gather
(
topk_index
,
index
=
topk_cs_index
,
dim
=
1
topk_index
,
index
=
topk_cs_index
,
dim
=
1
).
flatten
()
).
flatten
()
batch
.
out_cache_loc
=
batch
.
alloc_token_slots
(
batch
.
input_ids
.
numel
())
batch
.
out_cache_loc
=
batch
.
alloc_token_slots
(
len
(
batch
.
input_ids
))
self
.
score_list
.
append
(
scores
)
# b, topk, topk
self
.
token_list
.
append
(
topk_index
)
# b, topk*topk
self
.
scores
=
topk_cs_p
self
.
score_list
.
append
(
scores
)
# (b, topk, topk)
self
.
token_list
.
append
(
topk_index
)
# (b, topk * topk)
self
.
origin_score_list
.
append
(
topk_p
.
reshape
(
topk_index
.
shape
))
self
.
origin_score_list
.
append
(
topk_p
.
reshape
(
topk_index
.
shape
))
self
.
parents_list
.
append
(
self
.
parents_list
.
append
(
topk_cs_index
+
(
self
.
topk
**
2
*
(
self
.
iter
-
1
)
+
self
.
topk
)
topk_cs_index
+
(
self
.
topk
**
2
*
(
self
.
iter
-
1
)
+
self
.
topk
)
)
# b, topk
)
# shape: (b, topk)
else
:
elif
self
.
prev_mode
in
(
ForwardMode
.
EXTEND
,
ForwardMode
.
DRAFT_EXTEND
):
# ForwardMode.EXTEND or ForwardMode.DRAFT_EXTEND
self
.
scores
=
topk_p
# b, top_k
self
.
score_list
.
append
(
topk_p
.
unsqueeze
(
1
))
self
.
token_list
.
append
(
topk_index
)
self
.
origin_score_list
.
append
(
topk_p
)
batch
.
spec_info
.
hidden_states
=
(
batch
.
spec_info
.
hidden_states
=
(
batch
.
spec_info
.
hidden_states
.
repeat_interleave
(
self
.
topk
,
0
)
batch
.
spec_info
.
hidden_states
.
repeat_interleave
(
self
.
topk
,
dim
=
0
)
)
)
batch
.
input_ids
=
topk_index
.
flatten
()
batch
.
input_ids
=
topk_index
.
flatten
()
batch
.
out_cache_loc
=
batch
.
alloc_token_slots
(
topk_index
.
numel
())
batch
.
out_cache_loc
=
batch
.
alloc_token_slots
(
topk_index
.
numel
())
self
.
scores
=
topk_p
# shape: (b, topk)
self
.
score_list
.
append
(
topk_p
.
unsqueeze
(
1
))
# shape: (b, 1, topk)
self
.
token_list
.
append
(
topk_index
)
# shape: (b, topk)
self
.
origin_score_list
.
append
(
topk_p
)
self
.
parents_list
.
append
(
self
.
parents_list
.
append
(
torch
.
arange
(
-
1
,
self
.
topk
,
dtype
=
torch
.
long
,
device
=
"cuda"
)
torch
.
arange
(
-
1
,
self
.
topk
,
dtype
=
torch
.
long
,
device
=
"cuda"
)
.
unsqueeze
(
0
)
.
unsqueeze
(
0
)
.
repeat
(
self
.
scores
.
shape
[
0
],
1
)
.
repeat
(
self
.
scores
.
shape
[
0
],
1
)
)
# b, topk
+1
)
#
shape: (
b, topk
+ 1)
self
.
cache_list
.
append
(
batch
.
out_cache_loc
)
self
.
cache_list
.
append
(
batch
.
out_cache_loc
)
self
.
positions
=
(
self
.
positions
=
(
batch
.
seq_lens
[:,
None
]
batch
.
seq_lens
[:,
None
]
+
torch
.
ones
([
1
,
self
.
topk
],
device
=
"cuda"
,
dtype
=
torch
.
long
)
*
self
.
iter
+
torch
.
ones
([
1
,
self
.
topk
],
device
=
"cuda"
,
dtype
=
torch
.
long
)
*
self
.
iter
).
flatten
()
).
flatten
()
bs
=
batch
.
seq_lens
.
numel
(
)
bs
=
len
(
batch
.
seq_lens
)
assign_req_to_token_pool
[(
bs
,)](
assign_req_to_token_pool
[(
bs
,)](
batch
.
req_pool_indices
,
batch
.
req_pool_indices
,
batch
.
req_to_token_pool
.
req_to_token
,
batch
.
req_to_token_pool
.
req_to_token
,
...
@@ -419,11 +408,6 @@ class EAGLEDraftInput(SpecInfo):
...
@@ -419,11 +408,6 @@ class EAGLEDraftInput(SpecInfo):
)
)
return
bs
,
kv_indices
,
cum_kv_seq_len
return
bs
,
kv_indices
,
cum_kv_seq_len
def
clear
(
self
):
self
.
iter
=
0
self
.
score_list
.
clear
()
self
.
positions
=
None
def
clear_draft_cache
(
self
,
batch
):
def
clear_draft_cache
(
self
,
batch
):
draft_cache
=
torch
.
cat
(
self
.
cache_list
,
dim
=
0
)
draft_cache
=
torch
.
cat
(
self
.
cache_list
,
dim
=
0
)
batch
.
token_to_kv_pool
.
free
(
draft_cache
)
batch
.
token_to_kv_pool
.
free
(
draft_cache
)
...
@@ -460,7 +444,6 @@ class EAGLEDraftInput(SpecInfo):
...
@@ -460,7 +444,6 @@ class EAGLEDraftInput(SpecInfo):
[
self
.
hidden_states
,
spec_info
.
hidden_states
],
axis
=
0
[
self
.
hidden_states
,
spec_info
.
hidden_states
],
axis
=
0
)
)
self
.
verified_id
=
torch
.
cat
([
self
.
verified_id
,
spec_info
.
verified_id
],
axis
=
0
)
self
.
verified_id
=
torch
.
cat
([
self
.
verified_id
,
spec_info
.
verified_id
],
axis
=
0
)
# self.positions = torch.cat([self.positions, spec_info.positions], axis=0)
self
.
sample_output
=
torch
.
cat
([
self
.
sample_output
,
spec_info
.
sample_output
])
self
.
sample_output
=
torch
.
cat
([
self
.
sample_output
,
spec_info
.
sample_output
])
...
@@ -568,9 +551,6 @@ class EagleVerifyInput(SpecInfo):
...
@@ -568,9 +551,6 @@ class EagleVerifyInput(SpecInfo):
)
)
accept_index
=
accept_index
[
accept_index
!=
-
1
]
accept_index
=
accept_index
[
accept_index
!=
-
1
]
# extract_index = extract_index[extract_index != 0]
draft_input
=
EAGLEDraftInput
()
accept_length_cpu
=
accept_length
.
tolist
()
accept_length_cpu
=
accept_length
.
tolist
()
verified_id
=
predict
[
accept_index
]
verified_id
=
predict
[
accept_index
]
...
@@ -596,6 +576,7 @@ class EagleVerifyInput(SpecInfo):
...
@@ -596,6 +576,7 @@ class EagleVerifyInput(SpecInfo):
# retracted_reqs, new_token_ratio = batch.retract_decode()
# retracted_reqs, new_token_ratio = batch.retract_decode()
low
=
0
low
=
0
draft_input
=
EAGLEDraftInput
()
for
i
,
(
req
,
verified_len
)
in
enumerate
(
zip
(
batch
.
reqs
,
accept_length_cpu
)):
for
i
,
(
req
,
verified_len
)
in
enumerate
(
zip
(
batch
.
reqs
,
accept_length_cpu
)):
req
.
output_ids
.
extend
(
verified_id_cpu
[
low
:
low
+
verified_len
+
1
])
req
.
output_ids
.
extend
(
verified_id_cpu
[
low
:
low
+
verified_len
+
1
])
req
.
check_finished
()
req
.
check_finished
()
...
@@ -615,4 +596,10 @@ class EagleVerifyInput(SpecInfo):
...
@@ -615,4 +596,10 @@ class EagleVerifyInput(SpecInfo):
draft_input
.
unfinished_index
=
unfinished_index
draft_input
.
unfinished_index
=
unfinished_index
logits_output
.
next_token_logits
=
logits_output
.
next_token_logits
[
accept_index
]
logits_output
.
next_token_logits
=
logits_output
.
next_token_logits
[
accept_index
]
return
draft_input
,
logits_output
,
verified_id
,
finished_extend_len
return
(
draft_input
,
logits_output
,
verified_id
,
finished_extend_len
,
accept_length_cpu
,
)
python/sglang/srt/speculative/eagle_worker.py
View file @
b8574f69
...
@@ -51,63 +51,72 @@ class EAGLEWorker(TpModelWorker):
...
@@ -51,63 +51,72 @@ class EAGLEWorker(TpModelWorker):
batch
.
spec_info
.
prepare_for_decode
(
batch
)
batch
.
spec_info
.
prepare_for_decode
(
batch
)
model_worker_batch
=
batch
.
get_model_worker_batch
()
model_worker_batch
=
batch
.
get_model_worker_batch
()
forward_batch
=
ForwardBatch
.
init_new
(
model_worker_batch
,
self
.
model_runner
)
forward_batch
=
ForwardBatch
.
init_new
(
model_worker_batch
,
self
.
model_runner
)
forward_batch
.
spec_info
.
capture_hidden_mode
=
CaptureHiddenMode
.
LAST
forward_batch
.
capture_hidden_mode
=
CaptureHiddenMode
.
LAST
logits_output
=
self
.
model_runner
.
forward
(
forward_batch
)
logits_output
=
self
.
model_runner
.
forward
(
forward_batch
)
self
.
capture_for_decode
(
logits_output
,
forward_batch
)
self
.
capture_for_decode
(
logits_output
,
forward_batch
)
def
forward_draft_extend
(
self
,
batch
:
ScheduleBatch
):
def
forward_draft_extend
(
self
,
batch
:
ScheduleBatch
):
self
.
_s
wap
_mem_pool
(
batch
,
self
.
model_runner
)
self
.
_s
et
_mem_pool
(
batch
,
self
.
model_runner
)
batch
.
spec_info
.
prepare_for_extend
(
batch
)
batch
.
spec_info
.
prepare_for_extend
(
batch
)
model_worker_batch
=
batch
.
get_model_worker_batch
()
model_worker_batch
=
batch
.
get_model_worker_batch
()
forward_batch
=
ForwardBatch
.
init_new
(
model_worker_batch
,
self
.
model_runner
)
forward_batch
=
ForwardBatch
.
init_new
(
model_worker_batch
,
self
.
model_runner
)
forward_batch
.
spec_info
.
capture_hidden_mode
=
CaptureHiddenMode
.
LAST
forward_batch
.
capture_hidden_mode
=
CaptureHiddenMode
.
LAST
logits_output
=
self
.
model_runner
.
forward
(
forward_batch
)
logits_output
=
self
.
model_runner
.
forward
(
forward_batch
)
self
.
capture_for_decode
(
logits_output
,
forward_batch
)
self
.
capture_for_decode
(
logits_output
,
forward_batch
)
self
.
_s
wap
_mem_pool
(
batch
,
self
.
target_worker
.
model_runner
)
self
.
_s
et
_mem_pool
(
batch
,
self
.
target_worker
.
model_runner
)
def
forward_batch_speculative_generation
(
self
,
batch
:
ScheduleBatch
):
def
forward_batch_speculative_generation
(
self
,
batch
:
ScheduleBatch
):
if
batch
.
forward_mode
.
is_decode
():
if
batch
.
forward_mode
.
is_decode
():
prev_spec_info
=
batch
.
spec_info
# Draft
self
.
_s
wap
_mem_pool
(
batch
,
self
.
model_runner
)
self
.
_s
et
_mem_pool
(
batch
,
self
.
model_runner
)
for
i
in
range
(
self
.
server_args
.
speculative_num_steps
):
for
i
in
range
(
self
.
server_args
.
speculative_num_steps
):
self
.
forward_draft_decode
(
batch
)
self
.
forward_draft_decode
(
batch
)
batch
.
spec_info
.
clear_draft_cache
(
batch
)
batch
.
spec_info
.
clear_draft_cache
(
batch
)
self
.
_swap_mem_pool
(
batch
,
self
.
target_worker
.
model_runner
)
self
.
_set_mem_pool
(
batch
,
self
.
target_worker
.
model_runner
)
# Verify
(
(
next_draft_input
,
next_draft_input
,
logits_output
,
logits_output
,
verified_id
,
verified_id
,
self
.
finish_extend_len
,
self
.
finish_extend_len
,
accept_length_cpu
,
model_worker_batch
,
model_worker_batch
,
)
=
self
.
verify
(
batch
)
)
=
self
.
verify
(
batch
)
next_draft_input
.
init
(
self
.
server_args
)
next_draft_input
.
load_server_args
(
self
.
server_args
)
batch
.
spec_info
=
next_draft_input
batch
.
spec_info
=
next_draft_input
# if it is None, means all requsets are finished
# if it is None, means all requsets are finished
if
batch
.
spec_info
.
verified_id
is
not
None
:
if
batch
.
spec_info
.
verified_id
is
not
None
:
self
.
forward_extend_after_decode
(
batch
)
self
.
forward_draft_extend_after_decode
(
batch
)
batch
.
spec_info
=
prev_spec_info
return
(
return
logits_output
,
verified_id
,
model_worker_batch
,
next_draft_input
logits_output
,
verified_id
,
model_worker_batch
,
sum
(
accept_length_cpu
),
)
else
:
else
:
spec_info
=
EAGLEDraftInput
()
# Forward with the target model and get hidden states.
spec_info
.
init
(
self
.
server_args
)
# We need the full hidden states to prefill the KV cache of the draft model.
model_worker_batch
=
batch
.
get_model_worker_batch
()
model_worker_batch
=
batch
.
get_model_worker_batch
()
model_worker_batch
.
spec_info
=
spec_info
model_worker_batch
.
capture_hidden_mode
=
CaptureHiddenMode
.
FULL
spec_info
.
capture_hidden_mode
=
CaptureHiddenMode
.
FULL
logits_output
,
next_token_ids
=
self
.
target_worker
.
forward_batch_generation
(
logits_output
,
next_token_ids
=
self
.
target_worker
.
forward_batch_generation
(
model_worker_batch
model_worker_batch
)
)
model_worker_batch
.
spec_info
.
verified_id
=
next_token_ids
model_worker_batch
.
spec_info
.
hidden_states
=
logits_output
.
hidden_states
# Forward with the draft model.
spec_info
=
EAGLEDraftInput
()
spec_info
.
load_server_args
(
self
.
server_args
)
spec_info
.
hidden_states
=
logits_output
.
hidden_states
spec_info
.
verified_id
=
next_token_ids
batch
.
spec_info
=
spec_info
batch
.
spec_info
=
spec_info
self
.
forward_draft_extend
(
batch
)
self
.
forward_draft_extend
(
batch
)
batch
.
spec_info
=
None
return
logits_output
,
next_token_ids
,
model_worker_batch
,
0
return
logits_output
,
next_token_ids
,
model_worker_batch
,
spec_info
def
verify
(
self
,
batch
:
ScheduleBatch
):
def
verify
(
self
,
batch
:
ScheduleBatch
):
verify_input
=
batch
.
spec_info
.
prepare_for_verify
(
batch
)
verify_input
=
batch
.
spec_info
.
prepare_for_verify
(
batch
)
batch
.
forward_mode
=
ForwardMode
.
TARGET_VERIFY
verify_input
.
prepare_for_verify
(
batch
)
verify_input
.
prepare_for_verify
(
batch
)
batch
.
forward_mode
=
ForwardMode
.
TARGET_VERIFY
batch
.
spec_info
=
verify_input
batch
.
spec_info
=
verify_input
batch
.
spec_info
.
capture_hidden_mode
=
CaptureHiddenMode
.
FULL
batch
.
spec_info
.
capture_hidden_mode
=
CaptureHiddenMode
.
FULL
model_worker_batch
=
batch
.
get_model_worker_batch
()
model_worker_batch
=
batch
.
get_model_worker_batch
()
...
@@ -119,38 +128,41 @@ class EAGLEWorker(TpModelWorker):
...
@@ -119,38 +128,41 @@ class EAGLEWorker(TpModelWorker):
batch
.
forward_mode
=
ForwardMode
.
DECODE
batch
.
forward_mode
=
ForwardMode
.
DECODE
return
res
+
(
model_worker_batch
,)
return
res
+
(
model_worker_batch
,)
def
_s
wap
_mem_pool
(
self
,
batch
:
ScheduleBatch
,
runner
:
ModelRunner
):
def
_s
et
_mem_pool
(
self
,
batch
:
ScheduleBatch
,
runner
:
ModelRunner
):
batch
.
token_to_kv_pool
=
runner
.
token_to_kv_pool
batch
.
token_to_kv_pool
=
runner
.
token_to_kv_pool
batch
.
req_to_token_pool
=
runner
.
req_to_token_pool
batch
.
req_to_token_pool
=
runner
.
req_to_token_pool
def
forward_extend_after_decode
(
self
,
batch
:
ScheduleBatch
):
def
forward_
draft_
extend_after_decode
(
self
,
batch
:
ScheduleBatch
):
self
.
_s
wap
_mem_pool
(
batch
,
self
.
model_runner
)
self
.
_s
et
_mem_pool
(
batch
,
self
.
model_runner
)
batch
.
forward_mode
=
ForwardMode
.
DRAFT_EXTEND
batch
.
forward_mode
=
ForwardMode
.
DRAFT_EXTEND
if
batch
.
spec_info
.
has_finished
:
if
batch
.
spec_info
.
has_finished
:
index
=
batch
.
spec_info
.
unfinished_index
index
=
batch
.
spec_info
.
unfinished_index
seq_lens
=
batch
.
seq_lens
seq_lens
=
batch
.
seq_lens
batch
.
seq_lens
=
batch
.
seq_lens
[
index
]
batch
.
seq_lens
=
batch
.
seq_lens
[
index
]
batch
.
spec_info
.
prepare_extend_after_decode
(
batch
)
batch
.
spec_info
.
prepare_extend_after_decode
(
batch
)
model_worker_batch
=
batch
.
get_model_worker_batch
()
model_worker_batch
=
batch
.
get_model_worker_batch
()
forward_batch
=
ForwardBatch
.
init_new
(
model_worker_batch
,
self
.
model_runner
)
forward_batch
=
ForwardBatch
.
init_new
(
model_worker_batch
,
self
.
model_runner
)
forward_batch
.
spec_info
.
capture_hidden_mode
=
CaptureHiddenMode
.
LAST
forward_batch
.
capture_hidden_mode
=
CaptureHiddenMode
.
LAST
logits_output
=
self
.
model_runner
.
forward
(
forward_batch
)
logits_output
=
self
.
model_runner
.
forward
(
forward_batch
)
batch
.
spec_info
.
hidden_states
=
logits_output
.
hidden_states
batch
.
spec_info
.
hidden_states
=
logits_output
.
hidden_states
self
.
capture_for_decode
(
logits_output
,
forward_batch
)
self
.
capture_for_decode
(
logits_output
,
forward_batch
)
batch
.
forward_mode
=
ForwardMode
.
DECODE
batch
.
forward_mode
=
ForwardMode
.
DECODE
if
batch
.
spec_info
.
has_finished
:
if
batch
.
spec_info
.
has_finished
:
batch
.
seq_lens
=
seq_lens
batch
.
seq_lens
=
seq_lens
self
.
_s
wap
_mem_pool
(
batch
,
self
.
target_worker
.
model_runner
)
self
.
_s
et
_mem_pool
(
batch
,
self
.
target_worker
.
model_runner
)
def
capture_for_decode
(
self
,
logits_output
,
forward_batch
):
def
capture_for_decode
(
if
isinstance
(
logits_output
,
LogitsProcessorOutput
):
self
,
logits_output
:
LogitsProcessorOutput
,
forward_batch
:
ForwardBatch
logits
=
logits_output
.
next_token_logits
):
sample_output
=
torch
.
softmax
(
sample_output
=
torch
.
softmax
(
logits
,
dim
=-
1
logits_output
.
next_token_logits
,
dim
=-
1
)
# TODO: Support more sampling method @kavioyu
)
# TODO(kavioyu): Support more sampling methods
forward_batch
.
spec_info
.
capture_for_decode
(
spec_info
=
forward_batch
.
spec_info
sample_output
,
logits_output
.
hidden_states
,
forward_batch
.
forward_mode
spec_info
.
sample_output
=
sample_output
)
spec_info
.
hidden_states
=
logits_output
.
hidden_states
spec_info
.
prev_mode
=
forward_batch
.
forward_mode
# Don't support prefix share now.
# Don't support prefix share now.
def
finish_request
(
self
,
reqs
:
Union
[
Req
,
List
[
Req
]]):
def
finish_request
(
self
,
reqs
:
Union
[
Req
,
List
[
Req
]]):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment