change / sglang · Commits · b8574f69

Commit b8574f69 (Unverified)
Authored Jan 06, 2025 by Lianmin Zheng, committed by GitHub on Jan 06, 2025

Clean up eagle code (#2756)

Parent: 2855caa4

Showing 7 changed files with 138 additions and 128 deletions.
python/sglang/srt/layers/logits_processor.py             +1   -6
python/sglang/srt/managers/schedule_batch.py             +7   -1
python/sglang/srt/managers/scheduler.py                  +7   -4
python/sglang/srt/model_executor/cuda_graph_runner.py    +7   -2
python/sglang/srt/model_executor/forward_batch_info.py   +17  -15
python/sglang/srt/speculative/eagle_utils.py             +54  -67
python/sglang/srt/speculative/eagle_worker.py            +45  -33
python/sglang/srt/layers/logits_processor.py

@@ -74,11 +74,6 @@ class LogitsMetadata:
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
-        if forward_batch.spec_info:
-            capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
-        else:
-            capture_hidden_mode = CaptureHiddenMode.NULL
-
         if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
             extend_return_logprob = True
             extend_return_top_logprob = any(
@@ -98,7 +93,7 @@ class LogitsMetadata:
         return cls(
             forward_mode=forward_batch.forward_mode,
-            capture_hidden_mode=capture_hidden_mode,
+            capture_hidden_mode=forward_batch.capture_hidden_mode,
             extend_return_logprob=extend_return_logprob,
             extend_return_top_logprob=extend_return_top_logprob,
             extend_seq_lens=forward_batch.extend_seq_lens,
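With this hunk, LogitsMetadata.from_forward_batch no longer branches on the optional spec_info to derive capture_hidden_mode; ForwardBatch now carries the field directly (see the forward_batch_info.py hunks below). A minimal, self-contained sketch of the pattern, using hypothetical Batch/SpecInfo stand-ins rather than the real sglang classes:

    from dataclasses import dataclass
    from enum import IntEnum, auto
    from typing import Optional


    class CaptureHiddenMode(IntEnum):
        NULL = auto()
        FULL = auto()
        LAST = auto()


    @dataclass
    class SpecInfo:
        capture_hidden_mode: CaptureHiddenMode


    @dataclass
    class Batch:
        spec_info: Optional[SpecInfo]
        # Derived once at construction; consumers read it without branching.
        capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL

        @classmethod
        def build(cls, spec_info: Optional[SpecInfo]) -> "Batch":
            mode = spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
            return cls(spec_info=spec_info, capture_hidden_mode=mode)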
python/sglang/srt/managers/schedule_batch.py

@@ -44,7 +44,7 @@ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
@@ -1163,6 +1163,11 @@ class ScheduleBatch:
             input_embeds=self.input_embeds,
             spec_algorithm=self.spec_algorithm,
             spec_info=self.spec_info,
+            capture_hidden_mode=(
+                getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
+                if self.spec_info
+                else CaptureHiddenMode.NULL
+            ),
         )

     def copy(self):
@@ -1237,6 +1242,7 @@ class ModelWorkerBatch:
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None
+    capture_hidden_mode: CaptureHiddenMode = None


 @triton.jit
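get_model_worker_batch() now computes the batch-level capture_hidden_mode, falling back to CaptureHiddenMode.NULL via getattr because a spec_info object may never have had the attribute set. A small runnable sketch of that defaulting pattern (DraftInput is a hypothetical stand-in):

    from enum import IntEnum, auto


    class CaptureHiddenMode(IntEnum):
        NULL = auto()
        FULL = auto()
        LAST = auto()


    class DraftInput:
        pass  # capture_hidden_mode may be set later, or never


    spec_info = DraftInput()
    mode = (
        getattr(spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
        if spec_info
        else CaptureHiddenMode.NULL
    )
    assert mode is CaptureHiddenMode.NULL  # attribute absent -> default

    spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
    mode = getattr(spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
    assert mode is CaptureHiddenMode.FULL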
python/sglang/srt/managers/scheduler.py

@@ -962,10 +962,13 @@ class Scheduler:
                     self.tp_worker.forward_batch_generation(model_worker_batch)
                 )
             else:
-                logits_output, next_token_ids, model_worker_batch, spec_info = (
-                    self.draft_worker.forward_batch_speculative_generation(batch)
-                )
-                batch.spec_info = spec_info
+                (
+                    logits_output,
+                    next_token_ids,
+                    model_worker_batch,
+                    num_accepted_tokens,
+                ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                self.num_generated_tokens += num_accepted_tokens
         elif batch.forward_mode.is_idle():
             model_worker_batch = batch.get_model_worker_batch()
             self.tp_worker.forward_batch_idle(model_worker_batch)
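The speculative branch now unpacks num_accepted_tokens (instead of the next spec_info, which the worker manages internally) and folds it into the scheduler's token counter. Illustrative accounting only, with made-up values; accept_length_cpu comes from the verify step in eagle_utils.py below and holds the number of draft tokens the target model accepted per request:

    accept_length_cpu = [3, 0, 2]          # per-request accepted draft tokens
    num_accepted_tokens = sum(accept_length_cpu)

    num_generated_tokens = 0               # scheduler-side counter
    num_generated_tokens += num_accepted_tokens
    assert num_generated_tokens == 5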
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -322,6 +322,8 @@ class CudaGraphRunner:
             global_num_tokens = None
             gathered_buffer = None

+        spec_info = self.get_spec_info(num_tokens, positions)
+
         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
             batch_size=bs,
@@ -341,7 +343,10 @@ class CudaGraphRunner:
             mrope_positions=mrope_positions,
             gathered_buffer=gathered_buffer,
             spec_algorithm=self.model_runner.spec_algorithm,
-            spec_info=self.get_spec_info(num_tokens, positions),
+            spec_info=spec_info,
+            capture_hidden_mode=(
+                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
+            ),
         )

         # Attention backend
@@ -446,10 +451,10 @@ class CudaGraphRunner:
         if self.model_runner.is_draft_worker:
             spec_info = EAGLEDraftInput()
+            spec_info.load_server_args(self.model_runner.server_args)
             spec_info.hidden_states = self.hidden_states[:num_tokens]
             spec_info.positions = positions
             spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
-            spec_info.init(self.model_runner.server_args)
         else:
             spec_info = EagleVerifyInput(
                 None,
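The first two hunks hoist get_spec_info out of the ForwardBatch constructor call so the same object can feed both the spec_info argument and the new capture_hidden_mode argument; calling the factory inline a second time would have produced a distinct object. A toy illustration of that hazard (names are illustrative, not the CudaGraphRunner API):

    class SpecInfo:
        capture_hidden_mode = "FULL"


    def get_spec_info(num_tokens, positions):
        return SpecInfo()  # fresh instance on every call


    spec_info = get_spec_info(8, list(range(8)))  # hoisted: built exactly once
    capture_hidden_mode = spec_info.capture_hidden_mode if spec_info else "NULL"

    # A second call would not be the object stored on the batch:
    assert get_spec_info(8, list(range(8))) is not spec_info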
python/sglang/srt/model_executor/forward_batch_info.py

@@ -107,6 +107,21 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.DUMMY_FIRST


+class CaptureHiddenMode(IntEnum):
+    NULL = auto()
+    FULL = auto()
+    LAST = auto()
+
+    def need_capture(self):
+        return self != CaptureHiddenMode.NULL
+
+    def is_full(self):
+        return self == CaptureHiddenMode.FULL
+
+    def is_last(self):
+        return self == CaptureHiddenMode.LAST
+
+
 @dataclass
 class ForwardBatch:
     """Store all inputs of a forward pass."""
@@ -174,6 +189,7 @@ class ForwardBatch:
     # Speculative decoding
     spec_info: SpecInfo = None
     spec_algorithm: SpeculativeAlgorithm = None
+    capture_hidden_mode: CaptureHiddenMode = None

     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
@@ -265,6 +281,7 @@ class ForwardBatch:
             sampling_info=batch.sampling_info,
             spec_algorithm=batch.spec_algorithm,
             spec_info=batch.spec_info,
+            capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
         )
@@ -400,18 +417,3 @@ def compute_position_torch(
 @maybe_torch_compile(dynamic=True)
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
-
-
-class CaptureHiddenMode(IntEnum):
-    NULL = auto()
-    FULL = auto()
-    LAST = auto()
-
-    def need_capture(self):
-        return self != CaptureHiddenMode.NULL
-
-    def is_full(self):
-        return self == CaptureHiddenMode.FULL
-
-    def is_last(self):
-        return self == CaptureHiddenMode.LAST
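The CaptureHiddenMode enum moves, unchanged, from the bottom of the file to just above ForwardBatch, so its definition precedes the new dataclass field that references it. A quick usage sketch (assuming an environment with this commit applied): NULL skips hidden-state capture, FULL keeps hidden states for every position (the eagle_worker.py hunks below use this to prefill the draft model's KV cache), and LAST keeps only the final position.

    from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode

    mode = CaptureHiddenMode.LAST
    assert mode.need_capture() and mode.is_last() and not mode.is_full()
    assert not CaptureHiddenMode.NULL.need_capture()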
python/sglang/srt/speculative/eagle_utils.py

@@ -9,12 +9,11 @@ import triton.language as tl
 from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
 from sglang.srt.speculative.spec_info import SpecInfo

 if TYPE_CHECKING:
-    from python.sglang.srt.layers.sampler import SampleOutput
     from python.sglang.srt.managers.schedule_batch import ScheduleBatch
     from sglang.srt.server_args import ServerArgs
@@ -179,19 +178,9 @@ def generate_draft_decode_kv_indices(
 class EAGLEDraftInput(SpecInfo):
-    hidden_states: torch.Tensor = None
-    verified_id: torch.Tensor = None
-    positions: torch.Tensor = None
-    accept_length: torch.Tensor = None
-    has_finished: bool = False
-    unfinished_index: List[int] = None
-
-    def init(self, server_args: ServerArgs):
+    def __init__(self):
         self.prev_mode = ForwardMode.DECODE
         self.sample_output = None
-        self.topk: int = server_args.speculative_eagle_topk
-        self.num_verify_token: int = server_args.speculative_num_draft_tokens
-        self.spec_steps = server_args.speculative_num_steps

         self.scores: torch.Tensor = None
         self.score_list: List[torch.Tensor] = []
@@ -200,11 +189,20 @@ class EAGLEDraftInput(SpecInfo):
         self.parents_list: List[torch.Tensor] = []
         self.cache_list: List[torch.Tenor] = []
         self.iter = 0

-        self.root_token: int = None
-        assert self.topk <= 10, "topk should <= 10"
+        self.hidden_states: torch.Tensor = None
+        self.verified_id: torch.Tensor = None
+        self.positions: torch.Tensor = None
+        self.accept_length: torch.Tensor = None
+        self.has_finished: bool = False
+        self.unfinished_index: List[int] = None
+
+    def load_server_args(self, server_args: ServerArgs):
+        self.topk: int = server_args.speculative_eagle_topk
+        self.num_verify_token: int = server_args.speculative_num_draft_tokens
+        self.spec_steps = server_args.speculative_num_steps

-    def prepare_for_extend(self, batch: ForwardBatch):
+    def prepare_for_extend(self, batch: ScheduleBatch):
         req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
         out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
         batch.out_cache_loc = out_cache_loc
@@ -226,81 +224,72 @@ class EAGLEDraftInput(SpecInfo):
             pt += req.extend_input_len

-        seq_lens = [0] + batch.extend_lens
-        input_ids = batch.input_ids.tolist()
-        verified_id = batch.spec_info.verified_id.tolist()
-        model_input_ids = []
-        for i in range(len(seq_lens) - 1):
-            model_input_ids.extend(
-                input_ids[seq_lens[i] + 1 : seq_lens[i + 1]] + [verified_id[i]]
-            )
-        batch.input_ids = torch.tensor(
-            model_input_ids, dtype=torch.int32, device="cuda"
-        )
-
-    def capture_for_decode(
-        self,
-        sample_output: SampleOutput,
-        hidden_states: torch.Tensor,
-        prev_mode: ForwardMode,
-    ):
-        self.sample_output = sample_output
-        self.prev_mode = prev_mode
-        self.hidden_states = hidden_states
+        # TODO: support batching inputs
+        assert len(batch.extend_lens) == 1
+        batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))

     def prepare_for_decode(self, batch: ScheduleBatch):
-        prob = self.sample_output  # b * (1/topk), vocab
+        prob = self.sample_output  # shape: (b * top_k, vocab) or (b, vocab)
         top = torch.topk(prob, self.topk, dim=-1)
-        topk_index, topk_p = top.indices, top.values  # b * (1/topk), topk
-        if self.prev_mode == ForwardMode.DECODE:
+        topk_index, topk_p = (
+            top.indices,
+            top.values,
+        )  # shape: (b * top_k, top_k) or (b, top_k)
+
+        if self.prev_mode.is_decode():
             scores = torch.mul(
                 self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk)
-            )  # (b, topk) mul (b * topk ,topk) -> b, topk, topk
+            )  # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
             topk_cs = torch.topk(
                 scores.flatten(start_dim=1), self.topk, dim=-1
             )  # (b, topk)
             topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values
-            self.scores = topk_cs_p

-            selected_input_index = topk_cs_index.flatten() // self.topk  # b* topk
+            selected_input_index = (
+                topk_cs_index.flatten() // self.topk
+            )  # shape: (b * topk)

             batch.spec_info.hidden_states = batch.spec_info.hidden_states[
                 selected_input_index, :
             ]

             topk_index = topk_index.reshape(-1, self.topk**2)
             batch.input_ids = torch.gather(
                 topk_index, index=topk_cs_index, dim=1
             ).flatten()
-            batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
-            self.score_list.append(scores)  # b, topk, topk
-            self.token_list.append(topk_index)  # b, topk*topk
+            batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
+
+            self.scores = topk_cs_p
+            self.score_list.append(scores)  # (b, topk, topk)
+            self.token_list.append(topk_index)  # (b, topk * topk)
             self.origin_score_list.append(topk_p.reshape(topk_index.shape))
             self.parents_list.append(
                 topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk)
-            )  # b, topk
-        elif self.prev_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND):
-            self.scores = topk_p  # b, top_k
-            self.score_list.append(topk_p.unsqueeze(1))
-            self.token_list.append(topk_index)
-            self.origin_score_list.append(topk_p)
+            )  # shape: (b, topk)
+        else:
+            # ForwardMode.EXTEND or ForwardMode.DRAFT_EXTEND
             batch.spec_info.hidden_states = (
-                batch.spec_info.hidden_states.repeat_interleave(self.topk, 0)
+                batch.spec_info.hidden_states.repeat_interleave(self.topk, dim=0)
             )
+
             batch.input_ids = topk_index.flatten()
             batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel())
+
+            self.scores = topk_p  # shape: (b, topk)
+            self.score_list.append(topk_p.unsqueeze(1))  # shape: (b, 1, topk)
+            self.token_list.append(topk_index)  # shape: (b, topk)
+            self.origin_score_list.append(topk_p)
             self.parents_list.append(
                 torch.arange(-1, self.topk, dtype=torch.long, device="cuda")
                 .unsqueeze(0)
                 .repeat(self.scores.shape[0], 1)
-            )  # b, topk +1
+            )  # shape: (b, topk + 1)

         self.cache_list.append(batch.out_cache_loc)
         self.positions = (
             batch.seq_lens[:, None]
             + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
         ).flatten()

-        bs = batch.seq_lens.numel()
+        bs = len(batch.seq_lens)
         assign_req_to_token_pool[(bs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
@@ -419,11 +408,6 @@ class EAGLEDraftInput(SpecInfo):
         )
         return bs, kv_indices, cum_kv_seq_len

-    def clear(self):
-        self.iter = 0
-        self.score_list.clear()
-        self.positions = None
-
     def clear_draft_cache(self, batch):
         draft_cache = torch.cat(self.cache_list, dim=0)
         batch.token_to_kv_pool.free(draft_cache)
@@ -460,7 +444,6 @@ class EAGLEDraftInput(SpecInfo):
             [self.hidden_states, spec_info.hidden_states], axis=0
         )
         self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
-        # self.positions = torch.cat([self.positions, spec_info.positions], axis=0)
         self.sample_output = torch.cat([self.sample_output, spec_info.sample_output])
@@ -568,9 +551,6 @@ class EagleVerifyInput(SpecInfo):
         )
         accept_index = accept_index[accept_index != -1]
-        # extract_index = extract_index[extract_index != 0]
-
-        draft_input = EAGLEDraftInput()
         accept_length_cpu = accept_length.tolist()
         verified_id = predict[accept_index]
@@ -596,6 +576,7 @@ class EagleVerifyInput(SpecInfo):
             # retracted_reqs, new_token_ratio = batch.retract_decode()

         low = 0
+        draft_input = EAGLEDraftInput()
         for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)):
             req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1])
             req.check_finished()
@@ -615,4 +596,10 @@ class EagleVerifyInput(SpecInfo):
                 draft_input.unfinished_index = unfinished_index

         logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
-        return draft_input, logits_output, verified_id, finished_extend_len
+        return (
+            draft_input,
+            logits_output,
+            verified_id,
+            finished_extend_len,
+            accept_length_cpu,
+        )
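The heart of prepare_for_decode is the cumulative top-k expansion of the draft tree: each step keeps topk beams per request, scores children by parent score times child probability, prunes the flattened topk * topk candidates back to topk, and recovers each survivor's parent beam via integer division. A self-contained sketch with toy sizes, not the sglang API:

    import torch

    topk, vocab = 2, 5
    parent_scores = torch.tensor([[0.6, 0.4]])           # (b=1, topk)
    child_p = torch.rand(1 * topk, vocab).softmax(-1)    # draft probs per beam

    top = torch.topk(child_p, topk, dim=-1)
    topk_index, topk_p = top.indices, top.values         # (b * topk, topk)

    # Joint score of each (parent, child) pair:
    scores = parent_scores.unsqueeze(2) * topk_p.reshape(-1, topk, topk)
    # (b, topk, 1) x (b, topk, topk) -> (b, topk, topk)

    # Prune the topk * topk candidates back to topk survivors:
    topk_cs = torch.topk(scores.flatten(start_dim=1), topk, dim=-1)  # (b, topk)
    selected_parent = topk_cs.indices.flatten() // topk  # beam each survivor extends
    next_input_ids = torch.gather(
        topk_index.reshape(-1, topk**2), index=topk_cs.indices, dim=1
    ).flatten()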
python/sglang/srt/speculative/eagle_worker.py

@@ -51,63 +51,72 @@ class EAGLEWorker(TpModelWorker):
         batch.spec_info.prepare_for_decode(batch)
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
         self.capture_for_decode(logits_output, forward_batch)

     def forward_draft_extend(self, batch: ScheduleBatch):
-        self._swap_mem_pool(batch, self.model_runner)
+        self._set_mem_pool(batch, self.model_runner)
         batch.spec_info.prepare_for_extend(batch)
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
         self.capture_for_decode(logits_output, forward_batch)
-        self._swap_mem_pool(batch, self.target_worker.model_runner)
+        self._set_mem_pool(batch, self.target_worker.model_runner)

     def forward_batch_speculative_generation(self, batch: ScheduleBatch):
         if batch.forward_mode.is_decode():
-            prev_spec_info = batch.spec_info
-            self._swap_mem_pool(batch, self.model_runner)
+            # Draft
+            self._set_mem_pool(batch, self.model_runner)
             for i in range(self.server_args.speculative_num_steps):
                 self.forward_draft_decode(batch)
             batch.spec_info.clear_draft_cache(batch)
-            self._swap_mem_pool(batch, self.target_worker.model_runner)
+            self._set_mem_pool(batch, self.target_worker.model_runner)

+            # Verify
             (
                 next_draft_input,
                 logits_output,
                 verified_id,
                 self.finish_extend_len,
+                accept_length_cpu,
                 model_worker_batch,
             ) = self.verify(batch)
-            next_draft_input.init(self.server_args)
+            next_draft_input.load_server_args(self.server_args)
             batch.spec_info = next_draft_input
             # if it is None, means all requsets are finished
             if batch.spec_info.verified_id is not None:
-                self.forward_extend_after_decode(batch)
-            batch.spec_info = prev_spec_info
-            return logits_output, verified_id, model_worker_batch, next_draft_input
+                self.forward_draft_extend_after_decode(batch)
+            return (
+                logits_output,
+                verified_id,
+                model_worker_batch,
+                sum(accept_length_cpu),
+            )
         else:
-            spec_info = EAGLEDraftInput()
-            spec_info.init(self.server_args)
+            # Forward with the target model and get hidden states.
+            # We need the full hidden states to prefill the KV cache of the draft model.
             model_worker_batch = batch.get_model_worker_batch()
-            model_worker_batch.spec_info = spec_info
-            spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+            model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
             logits_output, next_token_ids = self.target_worker.forward_batch_generation(
                 model_worker_batch
             )
-            model_worker_batch.spec_info.verified_id = next_token_ids
-            model_worker_batch.spec_info.hidden_states = logits_output.hidden_states
+
+            # Forward with the draft model.
+            spec_info = EAGLEDraftInput()
+            spec_info.load_server_args(self.server_args)
+            spec_info.hidden_states = logits_output.hidden_states
+            spec_info.verified_id = next_token_ids
+            batch.spec_info = spec_info
             self.forward_draft_extend(batch)
             batch.spec_info = None
-            return logits_output, next_token_ids, model_worker_batch, spec_info
+            return logits_output, next_token_ids, model_worker_batch, 0

     def verify(self, batch: ScheduleBatch):
-        verify_input = batch.spec_info.prepare_for_verify(batch)
-        batch.forward_mode = ForwardMode.TARGET_VERIFY
+        verify_input.prepare_for_verify(batch)
+        batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = verify_input
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
         model_worker_batch = batch.get_model_worker_batch()
@@ -119,38 +128,41 @@ class EAGLEWorker(TpModelWorker):
         batch.forward_mode = ForwardMode.DECODE
         return res + (model_worker_batch,)

-    def _swap_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
+    def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
         batch.token_to_kv_pool = runner.token_to_kv_pool
         batch.req_to_token_pool = runner.req_to_token_pool

-    def forward_extend_after_decode(self, batch: ScheduleBatch):
-        self._swap_mem_pool(batch, self.model_runner)
+    def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        self._set_mem_pool(batch, self.model_runner)
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
         if batch.spec_info.has_finished:
             index = batch.spec_info.unfinished_index
             seq_lens = batch.seq_lens
             batch.seq_lens = batch.seq_lens[index]
+
         batch.spec_info.prepare_extend_after_decode(batch)
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
+
         batch.spec_info.hidden_states = logits_output.hidden_states
         self.capture_for_decode(logits_output, forward_batch)
         batch.forward_mode = ForwardMode.DECODE
         if batch.spec_info.has_finished:
             batch.seq_lens = seq_lens
-        self._swap_mem_pool(batch, self.target_worker.model_runner)
+        self._set_mem_pool(batch, self.target_worker.model_runner)

-    def capture_for_decode(self, logits_output, forward_batch):
-        if isinstance(logits_output, LogitsProcessorOutput):
-            logits = logits_output.next_token_logits
-        sample_output = torch.softmax(
-            logits, dim=-1
-        )  # TODO: Support more sampling method @kavioyu
-        forward_batch.spec_info.capture_for_decode(
-            sample_output, logits_output.hidden_states, forward_batch.forward_mode
-        )
+    def capture_for_decode(
+        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
+    ):
+        sample_output = torch.softmax(
+            logits_output.next_token_logits, dim=-1
+        )  # TODO(kavioyu): Support more sampling methods
+        spec_info = forward_batch.spec_info
+        spec_info.sample_output = sample_output
+        spec_info.hidden_states = logits_output.hidden_states
+        spec_info.prev_mode = forward_batch.forward_mode

     # Don't support prefix share now.
     def finish_request(self, reqs: Union[Req, List[Req]]):
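The reworked capture_for_decode inlines what EAGLEDraftInput.capture_for_decode used to do: it stores the softmax of the last-position logits as sample_output (the distribution the next draft step draws its top-k candidates from), along with the hidden states and the forward mode that produced them. A hedged sketch of just the distribution step, with illustrative shapes:

    import torch

    batch_size, vocab = 2, 11
    next_token_logits = torch.randn(batch_size, vocab)

    # Top-k drafting consumes probabilities, not sampled token ids.
    sample_output = torch.softmax(next_token_logits, dim=-1)
    assert torch.allclose(sample_output.sum(-1), torch.ones(batch_size))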