Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8a413453
Commit
8a413453
authored
Dec 30, 2025
by
jujl1
Browse files
feat: 兼容MTP零消耗和主模型+MTP零消耗(VLLM_ZERO_OVERHEAD_ENHANCE=1)开启
parent
5208b291
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
118 additions
and
47 deletions
+118
-47
vllm/envs.py
vllm/envs.py
+4
-0
vllm/zero_overhead/v1/core.py
vllm/zero_overhead/v1/core.py
+22
-7
vllm/zero_overhead/v1/gpu_model_runner.py
vllm/zero_overhead/v1/gpu_model_runner.py
+92
-40
No files found.
vllm/envs.py
View file @
8a413453
...
@@ -200,6 +200,7 @@ if TYPE_CHECKING:
...
@@ -200,6 +200,7 @@ if TYPE_CHECKING:
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER
:
bool
=
False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER
:
bool
=
False
VLLM_USE_FUSED_FILL_RMS_CAT
:
bool
=
False
VLLM_USE_FUSED_FILL_RMS_CAT
:
bool
=
False
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
:
bool
=
True
VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
:
bool
=
True
VLLM_ZERO_OVERHEAD_ENHANCE
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
return
os
.
getenv
(
return
os
.
getenv
(
...
@@ -1298,6 +1299,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1298,6 +1299,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM"
:
"VLLM_ENABLE_DEEPEP_HT_DEEPGEMM"
:
lambda
:
(
os
.
getenv
(
'VLLM_ENABLE_DEEPEP_HT_DEEPGEMM'
,
'1'
).
lower
()
in
lambda
:
(
os
.
getenv
(
'VLLM_ENABLE_DEEPEP_HT_DEEPGEMM'
,
'1'
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
"VLLM_ZERO_OVERHEAD_ENHANCE"
:
lambda
:
(
os
.
getenv
(
'VLLM_ZERO_OVERHEAD_ENHANCE'
,
'0'
).
lower
()
in
(
"true"
,
"1"
)),
}
}
# --8<-- [end:env-vars-definition]
# --8<-- [end:env-vars-definition]
...
...
vllm/zero_overhead/v1/core.py
View file @
8a413453
...
@@ -8,7 +8,7 @@ from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
...
@@ -8,7 +8,7 @@ from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
from
vllm.zero_overhead.v1.outputs
import
ZeroV1ModelRunnerOutput
from
vllm.zero_overhead.v1.outputs
import
ZeroV1ModelRunnerOutput
from
vllm
import
envs
requsets_valid_token_len
=
{}
requsets_valid_token_len
=
{}
def
check_stop
(
request
:
Request
,
def
check_stop
(
request
:
Request
,
...
@@ -83,10 +83,22 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
...
@@ -83,10 +83,22 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
request
.
_all_token_ids
[
fix_offset
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
]
=
generated_token_ids
requsets_valid_token_len
[
req_id
]
+=
1
requsets_valid_token_len
[
req_id
]
+=
1
generated_token_ids
=
[
generated_token_ids
]
generated_token_ids
=
[
generated_token_ids
]
el
se
:
el
if
envs
.
VLLM_ZERO_OVERHEAD_ENHANCE
:
requsets_valid_token_len
[
req_id
]
+=
len
(
generated_token_ids
)
requsets_valid_token_len
[
req_id
]
+=
len
(
generated_token_ids
)
request
.
_output_token_ids
[
fix_offset
:
]
=
generated_token_ids
request
.
_output_token_ids
[
fix_offset
:
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
:
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
:
]
=
generated_token_ids
num_tokens_scheduled
=
num_scheduled_tokens
.
get
(
req_id
,
0
)
if
num_tokens_scheduled
==
0
:
request
.
num_computed_tokens
=
request
.
num_tokens
-
1
else
:
valid_output_end
=
valid_output_len
+
len
(
generated_token_ids
)
-
request
.
num_output_tokens
if
valid_output_end
==
0
:
request
.
_output_token_ids
[
fix_offset
:
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
:
]
=
generated_token_ids
else
:
request
.
_output_token_ids
[
fix_offset
:
valid_output_end
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
:
valid_output_end
]
=
generated_token_ids
requsets_valid_token_len
[
req_id
]
+=
len
(
generated_token_ids
)
stopped
=
False
stopped
=
False
...
@@ -189,8 +201,9 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
...
@@ -189,8 +201,9 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
req_index
=
model_runner_output
.
req_id_to_index
[
req_id
]
req_index
=
model_runner_output
.
req_id_to_index
[
req_id
]
generated_token_ids
=
sampled_token_ids
[
generated_token_ids
=
sampled_token_ids
[
req_index
]
if
sampled_token_ids
else
[]
req_index
]
if
sampled_token_ids
else
[]
if
request
.
num_computed_tokens
==
request
.
num_prompt_tokens
:
if
(
envs
.
VLLM_ZERO_OVERHEAD_ENHANCE
and
request
.
num_computed_tokens
==
request
.
num_prompt_tokens
):
generated_token_ids
=
generated_token_ids
[:
1
]
generated_token_ids
=
generated_token_ids
[:
1
]
scheduled_spec_token_ids
=
(
scheduled_spec_token_ids
=
(
...
@@ -203,8 +216,10 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
...
@@ -203,8 +216,10 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# tokens, where is given by:
# tokens, where is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
num_new
=
len
(
generated_token_ids
)
num_new
=
len
(
generated_token_ids
)
if
(
model_runner_output
.
fix_req_ids
and
req_id
in
model_runner_output
.
fix_req_ids
if
(
envs
.
VLLM_ZERO_OVERHEAD_ENHANCE
and
and
request
.
num_computed_tokens
>
request
.
num_prompt_tokens
+
num_new
):
model_runner_output
.
fix_req_ids
and
req_id
in
model_runner_output
.
fix_req_ids
and
request
.
num_computed_tokens
>
request
.
num_prompt_tokens
+
num_new
):
req_idx
=
model_runner_output
.
fix_req_ids
.
index
(
req_id
)
req_idx
=
model_runner_output
.
fix_req_ids
.
index
(
req_id
)
num_new
=
len
(
model_runner_output
.
fix_sampled_token_ids
[
req_idx
])
num_new
=
len
(
model_runner_output
.
fix_sampled_token_ids
[
req_idx
])
num_tokens_rejected
=
(
len
(
scheduled_spec_token_ids
)
+
1
-
num_new
)
num_tokens_rejected
=
(
len
(
scheduled_spec_token_ids
)
+
1
-
num_new
)
...
@@ -213,7 +228,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
...
@@ -213,7 +228,7 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
spec_decoding_stats
=
scheduler
.
make_spec_decoding_stats
(
spec_decoding_stats
=
scheduler
.
make_spec_decoding_stats
(
spec_decoding_stats
,
spec_decoding_stats
,
num_draft_tokens
=
len
(
scheduled_spec_token_ids
),
num_draft_tokens
=
len
(
scheduled_spec_token_ids
),
num_accepted_tokens
=
(
num_new
-
1
)
if
generated_token_ids
else
0
)
num_accepted_tokens
=
num_new
-
1
)
# NOTE(woosuk): This has to be executed after updating
# NOTE(woosuk): This has to be executed after updating
# `request.num_computed_tokens`.
# `request.num_computed_tokens`.
...
...
vllm/zero_overhead/v1/gpu_model_runner.py
View file @
8a413453
...
@@ -84,7 +84,8 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -84,7 +84,8 @@ class V1ZeroModelRunner(GPUModelRunner):
num_scheduled_tokens
)
num_scheduled_tokens
)
if
self
.
speculative_config
and
self
.
last_sampler_host_tokens
!=
None
:
if
(
envs
.
VLLM_ZERO_OVERHEAD_ENHANCE
and
self
.
speculative_config
and
self
.
last_sampler_host_tokens
!=
None
):
self
.
fix_req_ids
=
self
.
last_sampled_req_ids
self
.
fix_req_ids
=
self
.
last_sampled_req_ids
self
.
last_sampler_event
.
synchronize
()
# 等上一轮主模型结束
self
.
last_sampler_event
.
synchronize
()
# 等上一轮主模型结束
num_gen_tokens
=
self
.
last_sampler_host_tokens
.
shape
[
-
1
]
num_gen_tokens
=
self
.
last_sampler_host_tokens
.
shape
[
-
1
]
...
@@ -106,8 +107,8 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -106,8 +107,8 @@ class V1ZeroModelRunner(GPUModelRunner):
# # 更新token统计数据
# # 更新token统计数据
self
.
input_batch
.
num_tokens_no_spec
[
new_req_idx
]
=
new_end_idx
self
.
input_batch
.
num_tokens_no_spec
[
new_req_idx
]
=
new_end_idx
self
.
input_batch
.
num_tokens
[
new_req_idx
]
=
new_end_idx
self
.
input_batch
.
num_tokens
[
new_req_idx
]
=
new_end_idx
self
.
input_batch
.
token_ids_cpu
[
new_req_idx
,
start_idx
:
new_end_idx
]
=
self
.
fix_sampled_token_ids
[
self
.
input_batch
.
token_ids_cpu
[
new_req_idx
,
start_idx
:
new_end_idx
]
=
(
req_idx
]
self
.
fix_sampled_token_ids
)[
req_idx
]
self
.
input_batch
.
num_computed_tokens_cpu
[
new_req_idx
]
-=
(
end_idx
-
new_end_idx
)
self
.
input_batch
.
num_computed_tokens_cpu
[
new_req_idx
]
-=
(
end_idx
-
new_end_idx
)
if
req_id
in
self
.
requests
:
if
req_id
in
self
.
requests
:
req_state
=
self
.
requests
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
...
@@ -299,13 +300,15 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -299,13 +300,15 @@ class V1ZeroModelRunner(GPUModelRunner):
True
)
True
)
last_draft_token_ids
=
self
.
last_draft_token_ids
.
flatten
().
to
(
torch
.
int
)
last_draft_token_ids
=
self
.
last_draft_token_ids
.
flatten
().
to
(
torch
.
int
)
input_ids
[
input_ids_indices_tensor
]
=
last_draft_token_ids
[
update_req_indices_tensor
]
input_ids
[
input_ids_indices_tensor
]
=
last_draft_token_ids
[
update_req_indices_tensor
]
update_req_indices
=
[]
update_req_indices
=
[]
input_ids_indices
=
[]
input_ids_indices
=
[]
token_idx
=
0
token_idx
=
0
if
self
.
last_sampled_token_ids
is
not
None
:
if
self
.
last_sampled_token_ids
is
not
None
:
sampled_tokens_num
=
1
if
self
.
speculative_config
else
self
.
last_sampled_token_ids
.
shape
[
1
]
for
req_id
in
req_ids
:
for
req_id
in
req_ids
:
if
req_id
in
self
.
last_sampled_req_ids
:
if
req_id
in
self
.
last_sampled_req_ids
:
req_idx
=
self
.
last_sampled_req_ids
.
index
(
req_id
)
req_idx
=
self
.
last_sampled_req_ids
.
index
(
req_id
)
*
sampled_tokens_num
update_req_indices
.
append
(
req_idx
)
update_req_indices
.
append
(
req_idx
)
input_ids_indices
.
append
(
token_idx
)
input_ids_indices
.
append
(
token_idx
)
token_idx
+=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
token_idx
+=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
...
@@ -316,12 +319,14 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -316,12 +319,14 @@ class V1ZeroModelRunner(GPUModelRunner):
input_ids_indices_tensor
=
async_tensor_h2d
(
input_ids_indices
,
torch
.
int32
,
input_ids_indices_tensor
=
async_tensor_h2d
(
input_ids_indices
,
torch
.
int32
,
self
.
device
,
self
.
device
,
True
)
True
)
if
self
.
speculative_config
:
if
envs
.
VLLM_ZERO_OVERHEAD_ENHANCE
and
self
.
speculative_config
:
fused_update_input_ids_impl
(
self
.
last_sampled_token_ids
,
input_ids
,
fused_update_input_ids_impl
(
self
.
last_sampled_token_ids
,
input_ids
,
update_req_indices_tensor
,
input_ids_indices_tensor
)
update_req_indices_tensor
,
input_ids_indices_tensor
)
else
:
else
:
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
flatten
()
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
flatten
()
input_ids
[
input_ids_indices_tensor
]
=
last_sampled_token_ids
[
update_req_indices_tensor
]
for
i
in
range
(
sampled_tokens_num
):
input_ids
[
input_ids_indices_tensor
+
i
]
=
(
last_sampled_token_ids
)[
update_req_indices_tensor
+
i
]
def
propose_draft_token_ids
(
def
propose_draft_token_ids
(
self
,
self
,
...
@@ -698,43 +703,44 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -698,43 +703,44 @@ class V1ZeroModelRunner(GPUModelRunner):
is_output_valid
=
False
is_output_valid
=
False
# Get the valid generated tokens.
# Get the valid generated tokens.
sampled_token_ids
=
sampler_output
.
sampled_token_ids
sampled_token_ids
=
sampler_output
.
sampled_token_ids
if
not
self
.
speculative_config
:
over_head_enhance
=
envs
.
VLLM_ZERO_OVERHEAD_ENHANCE
and
self
.
speculative_config
self
.
fix_req_ids
=
self
.
last_sampled_req_ids
if
over_head_enhance
:
if
self
.
last_sampler_host_tokens
is
not
None
:
# if not self.speculative_config:
self
.
last_sampler_event
.
synchronize
()
# self.fix_req_ids = self.last_sampled_req_ids
self
.
fix_sampled_token_ids
=
self
.
last_sampler_host_tokens
.
tolist
()
# if self.last_sampler_host_tokens is not None:
for
req_idx
,
start_idx
,
end_idx
in
self
.
token_ids_cpu_fix_record
:
# self.last_sampler_event.synchronize()
if
start_idx
==
-
1
:
# self.fix_sampled_token_ids = self.last_sampler_host_tokens.tolist()
self
.
fix_sampled_token_ids
[
req_idx
].
clear
()
# for req_idx, start_idx, end_idx in self.token_ids_cpu_fix_record:
continue
# if start_idx == -1:
req_id
=
self
.
fix_req_ids
[
req_idx
]
# self.fix_sampled_token_ids[req_idx].clear()
if
req_id
in
self
.
input_batch
.
req_ids
:
# continue
new_req_idx
=
self
.
input_batch
.
req_ids
.
index
(
req_id
)
# req_id = self.fix_req_ids[req_idx]
self
.
input_batch
.
token_ids_cpu
[
new_req_idx
,
start_idx
:
end_idx
]
=
self
.
fix_sampled_token_ids
[
req_idx
]
# if req_id in self.input_batch.req_ids:
for
req_idx
,
req_id
in
enumerate
(
self
.
fix_req_ids
):
# new_req_idx = self.input_batch.req_ids.index(req_id)
if
req_id
in
self
.
requests
:
# self.input_batch.token_ids_cpu[new_req_idx, start_idx:end_idx] = self.fix_sampled_token_ids[req_idx]
req_state
=
self
.
requests
[
req_id
]
# for req_idx, req_id in enumerate(self.fix_req_ids):
token_idx
=
self
.
last_sampled_token_lens
[
req_idx
]
# if req_id in self.requests:
if
token_idx
==
-
1
:
# req_state = self.requests[req_id]
continue
# token_idx = self.last_sampled_token_lens[req_idx]
fix_len
=
len
(
self
.
fix_sampled_token_ids
[
req_idx
])
# if token_idx == -1:
req_state
.
output_token_ids
[
token_idx
:
token_idx
+
fix_len
]
=
self
.
fix_sampled_token_ids
[
req_idx
]
# continue
self
.
last_sampler_host_tokens
=
None
# fix_len = len(self.fix_sampled_token_ids[req_idx])
self
.
last_sampled_token_ids
=
None
# req_state.output_token_ids[token_idx:token_idx + fix_len] = self.fix_sampled_token_ids[req_idx]
self
.
last_sampler_host_tokens
=
sampled_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
last_sampler_host_tokens
=
None
self
.
last_sampler_event
.
record
()
self
.
last_sampled_token_ids
=
None
self
.
last_sampled_token_ids
=
sampled_token_ids
self
.
last_sampler_host_tokens
=
sampled_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
valid_sampled_token_ids
=
np
.
ones
(
sampled_token_ids
.
shape
,
dtype
=
int
).
tolist
()
self
.
last_sampler_event
.
record
()
self
.
last_sampled_token_ids
=
sampled_token_ids
# Mask out the sampled tokens that should not be sampled.
valid_sampled_token_ids
=
np
.
ones
(
sampled_token_ids
.
shape
,
dtype
=
int
).
tolist
()
for
i
in
discard_sampled_tokens_req_indices
:
valid_sampled_token_ids
[
i
].
clear
()
if
not
self
.
speculative_config
:
if
not
self
.
speculative_config
:
# Speculative decoding is not enabled.
# Speculative decoding is not enabled.
spec_token_ids
=
None
spec_token_ids
=
None
fix_draft_req_ids
=
None
fix_draft_req_ids
=
None
else
:
else
:
if
not
over_head_enhance
:
sampled_token_ids_cpu
=
sampled_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
spec_sampler_event
.
record
()
if
self
.
last_draft_host_tokens
is
not
None
:
if
self
.
last_draft_host_tokens
is
not
None
:
self
.
last_draft_event
.
synchronize
()
self
.
last_draft_event
.
synchronize
()
fix_draft_token_ids
=
self
.
last_draft_host_tokens
.
tolist
()
fix_draft_token_ids
=
self
.
last_draft_host_tokens
.
tolist
()
...
@@ -755,6 +761,51 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -755,6 +761,51 @@ class V1ZeroModelRunner(GPUModelRunner):
attn_metadata
,
attn_metadata
,
)
)
if
not
over_head_enhance
:
if
self
.
speculative_config
:
self
.
spec_sampler_event
.
synchronize
()
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
if
max_gen_len
==
1
:
valid_sampled_token_ids
=
sampled_token_ids_cpu
.
tolist
()
else
:
# Includes spec decode tokens.
valid_sampled_token_ids
=
self
.
rejection_sampler
.
parse_output
(
sampled_token_ids_cpu
,
self
.
input_batch
.
vocab_size
,
)
self
.
last_sampler_host_tokens
=
None
self
.
last_sampled_token_ids
=
None
is_output_valid
=
True
else
:
# No spec decode tokens.
self
.
fix_req_ids
=
self
.
last_sampled_req_ids
if
self
.
last_sampler_host_tokens
!=
None
:
self
.
last_sampler_event
.
synchronize
()
self
.
fix_sampled_token_ids
=
self
.
last_sampler_host_tokens
.
tolist
()
for
req_idx
,
start_idx
,
end_idx
in
self
.
token_ids_cpu_fix_record
:
if
start_idx
==
-
1
:
continue
req_id
=
self
.
fix_req_ids
[
req_idx
]
if
req_id
in
self
.
input_batch
.
req_ids
:
new_req_idx
=
self
.
input_batch
.
req_ids
.
index
(
req_id
)
self
.
input_batch
.
token_ids_cpu
[
new_req_idx
,
start_idx
:
end_idx
]
=
self
.
fix_sampled_token_ids
[
req_idx
]
for
req_idx
,
req_id
in
enumerate
(
self
.
fix_req_ids
):
if
req_id
in
self
.
requests
:
req_state
=
self
.
requests
[
req_id
]
token_idx
=
self
.
last_sampled_token_lens
[
req_idx
]
if
token_idx
==
-
1
:
continue
fix_len
=
len
(
self
.
fix_sampled_token_ids
[
req_idx
])
req_state
.
output_token_ids
[
token_idx
:
token_idx
+
fix_len
]
=
self
.
fix_sampled_token_ids
[
req_idx
]
self
.
last_sampler_host_tokens
=
sampled_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
last_sampler_event
.
record
()
self
.
last_sampled_token_ids
=
sampled_token_ids
valid_sampled_token_ids
=
np
.
ones
(
sampled_token_ids
.
shape
,
dtype
=
int
).
tolist
()
# Mask out the sampled tokens that should not be sampled.
for
i
in
discard_sampled_tokens_req_indices
:
valid_sampled_token_ids
[
i
].
clear
()
# Cache the sampled tokens in the model runner, so that the scheduler
# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
# doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
...
@@ -765,12 +816,13 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -765,12 +816,13 @@ class V1ZeroModelRunner(GPUModelRunner):
self
.
last_sampled_token_lens
=
[]
self
.
last_sampled_token_lens
=
[]
for
req_idx
,
sampled_ids
in
enumerate
(
valid_sampled_token_ids
):
for
req_idx
,
sampled_ids
in
enumerate
(
valid_sampled_token_ids
):
req_id
=
self
.
input_batch
.
req_ids
[
req_idx
]
req_id
=
self
.
input_batch
.
req_ids
[
req_idx
]
cache_output_len
=
-
1
self
.
last_sampled_req_ids
.
append
(
req_id
)
self
.
last_sampled_req_ids
.
append
(
req_id
)
cache_output_len
=
-
1
if
not
sampled_ids
:
if
not
sampled_ids
:
self
.
last_sampled_token_lens
.
append
(
-
1
)
self
.
last_sampled_token_lens
.
append
(
-
1
)
self
.
token_ids_cpu_fix_record
.
append
([
req_idx
,
-
1
,
-
1
])
self
.
token_ids_cpu_fix_record
.
append
([
req_idx
,
-
1
,
-
1
])
continue
continue
start_idx
=
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
start_idx
=
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
end_idx
=
start_idx
+
len
(
sampled_ids
)
end_idx
=
start_idx
+
len
(
sampled_ids
)
assert
end_idx
<=
self
.
max_model_len
,
(
assert
end_idx
<=
self
.
max_model_len
,
(
...
@@ -783,11 +835,11 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -783,11 +835,11 @@ class V1ZeroModelRunner(GPUModelRunner):
self
.
token_ids_cpu_fix_record
.
append
([
req_idx
,
start_idx
,
end_idx
])
self
.
token_ids_cpu_fix_record
.
append
([
req_idx
,
start_idx
,
end_idx
])
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens
[
req_idx
]
=
end_idx
if
not
self
.
speculative_config
and
req_id
in
self
.
requests
:
if
not
over_head_enhance
and
req_id
in
self
.
requests
:
req_state
=
self
.
requests
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
cache_output_len
=
len
(
req_state
.
output_token_ids
)
cache_output_len
=
len
(
req_state
.
output_token_ids
)
req_state
.
output_token_ids
.
extend
(
sampled_ids
)
req_state
.
output_token_ids
.
extend
(
sampled_ids
)
self
.
last_sampled_token_lens
.
append
(
cache_output_len
)
self
.
last_sampled_token_lens
.
append
(
cache_output_len
)
# Clear KVConnector state after all KVs are generated.
# Clear KVConnector state after all KVs are generated.
if
has_kv_transfer_group
():
if
has_kv_transfer_group
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment