Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a495fc3b
Commit
a495fc3b
authored
Jul 10, 2025
by
zhuwenwen
Browse files
fix zero overhead to support chunk prefill
parent
fe1c4016
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
88 additions
and
14 deletions
+88
-14
vllm/zero_overhead/llm_engine.py
vllm/zero_overhead/llm_engine.py
+8
-4
vllm/zero_overhead/model_runner.py
vllm/zero_overhead/model_runner.py
+5
-3
vllm/zero_overhead/sampler.py
vllm/zero_overhead/sampler.py
+75
-7
No files found.
vllm/zero_overhead/llm_engine.py
View file @
a495fc3b
...
@@ -299,7 +299,10 @@ class ZeroOverheadEngine(LLMEngine):
...
@@ -299,7 +299,10 @@ class ZeroOverheadEngine(LLMEngine):
last_sampler
=
self
.
last_record
[
1
]
last_sampler
=
self
.
last_record
[
1
]
spec_step
=
get_spec_step
()
spec_step
=
get_spec_step
()
if
spec_step
==
SpecStepKind
.
KIND_DEFAULT
:
if
spec_step
==
SpecStepKind
.
KIND_DEFAULT
:
self
.
async_d2h
=
last_sampler
.
sampled_token_ids_tensor
.
to
(
'cpu'
,
non_blocking
=
True
)
if
last_sampler
.
sampled_token_ids_tensor
is
not
None
:
self
.
async_d2h
=
last_sampler
.
sampled_token_ids_tensor
.
to
(
'cpu'
,
non_blocking
=
True
)
else
:
self
.
async_d2h
=
None
elif
spec_step
==
SpecStepKind
.
SCORE_DECODE
:
elif
spec_step
==
SpecStepKind
.
SCORE_DECODE
:
self
.
async_d2h
=
last_sampler
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
async_d2h
=
last_sampler
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
async_event
.
record
()
self
.
async_event
.
record
()
...
@@ -367,9 +370,10 @@ class ZeroOverheadEngine(LLMEngine):
...
@@ -367,9 +370,10 @@ class ZeroOverheadEngine(LLMEngine):
ctx
.
scheduler_outputs
=
scheduler_outputs
ctx
.
scheduler_outputs
=
scheduler_outputs
if
spec_step
==
SpecStepKind
.
KIND_DEFAULT
:
if
spec_step
==
SpecStepKind
.
KIND_DEFAULT
:
self
.
async_event
.
synchronize
()
self
.
async_event
.
synchronize
()
self
.
_fix_last_step
(
if
self
.
async_d2h
is
not
None
:
outputs
,
last_sampler
,
seq_group_metadata_list
,
self
.
_fix_last_step
(
scheduler_outputs
.
scheduled_seq_groups
)
outputs
,
last_sampler
,
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
elif
spec_step
==
SpecStepKind
.
SCORE_DECODE
:
elif
spec_step
==
SpecStepKind
.
SCORE_DECODE
:
self
.
async_event
.
synchronize
()
self
.
async_event
.
synchronize
()
self
.
_fix_spec_decode_steps
(
self
.
_fix_spec_decode_steps
(
...
...
vllm/zero_overhead/model_runner.py
View file @
a495fc3b
...
@@ -99,13 +99,15 @@ class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
...
@@ -99,13 +99,15 @@ class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
if
spec_step
==
SpecStepKind
.
KIND_DEFAULT
:
if
spec_step
==
SpecStepKind
.
KIND_DEFAULT
:
update_indices
=
[]
update_indices
=
[]
select_indices
=
[]
select_indices
=
[]
query_idx
=
0
for
i
,
seq_id
in
enumerate
(
self
.
req_ids
):
for
i
,
seq_id
in
enumerate
(
self
.
req_ids
):
for
j
,
seq_id_
in
enumerate
(
last_sampler
.
seq_ids
):
for
j
,
seq_id_
in
enumerate
(
last_sampler
.
seq_ids
):
if
seq_id
==
seq_id_
:
if
seq_id
==
seq_id_
:
select_indices
.
append
(
j
)
select_indices
.
append
(
j
)
update_indices
.
append
(
i
)
update_indices
.
append
(
query_idx
)
break
break
if
len
(
select_indices
)
>
0
:
query_idx
+=
model_input
.
query_lens
[
i
]
if
len
(
select_indices
)
>
0
and
last_sampler
.
sampled_token_ids_tensor
is
not
None
:
select_indices
=
async_tensor_h2d
(
select_indices
,
torch
.
long
,
select_indices
=
async_tensor_h2d
(
select_indices
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
self
.
runner
.
pin_memory
)
...
...
vllm/zero_overhead/sampler.py
View file @
a495fc3b
from
importlib.util
import
find_spec
from
importlib.util
import
find_spec
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
from
vllm
import
envs
from
vllm
import
envs
from
vllm.model_executor.layers.sampler
import
MultinomialSamplesType
,
SampleMetadataType
,
\
from
vllm.distributed.parallel_state
import
get_tp_group
from
vllm.model_executor.layers.sampler
import
MaybeDeferredSampleResultType
,
MultinomialSamplesType
,
SampleMetadataType
,
\
SampleResultArgsType
,
SampleResultType
,
SampleResultsDictType
,
SampleReturnType
,
Sampler
,
\
SampleResultArgsType
,
SampleResultType
,
SampleResultsDictType
,
SampleReturnType
,
Sampler
,
\
SamplerOutput
,
_apply_min_p
,
_apply_min_tokens_penalty
,
_apply_top_k_top_p
,
_build_sampler_output
,
\
SamplerOutput
,
_apply_min_p
,
_apply_min_tokens_penalty
,
_apply_top_k_top_p
,
\
_modify_greedy_probs_inplace
,
_top_k_top_p_multinomial_with_flashinfer
,
get_logprobs
,
_multinomial
_modify_greedy_probs_inplace
,
_top_k_top_p_multinomial_with_flashinfer
,
get_logprobs
,
_multinomial
from
vllm.model_executor.layers.utils
import
apply_penalties
from
vllm.model_executor.layers.utils
import
apply_penalties
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
,
SamplingTensors
,
SequenceGroupToSample
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
,
SamplingTensors
,
SequenceGroupToSample
from
vllm.sampling_params
import
SamplingType
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
VLLM_INVALID_TOKEN_ID
from
vllm.sequence
import
VLLM_INVALID_TOKEN_ID
,
CompletionSequenceGroupOutput
,
PromptLogprobs
,
SampleLogprobs
,
SequenceOutput
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
find_spec
(
"flashinfer"
):
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
find_spec
(
"flashinfer"
):
import
flashinfer.sampling
import
flashinfer.sampling
# yapf: disable
# yapf: disable
...
@@ -275,10 +276,8 @@ def _sample_with_torch(
...
@@ -275,10 +276,8 @@ def _sample_with_torch(
t
:
[]
t
:
[]
for
t
in
SamplingType
for
t
in
SamplingType
}
}
last_sampler
.
seq_ids
=
[]
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
last_sampler
.
seq_ids
.
append
(
seq_group
.
seq_ids
[
0
])
sampling_params
=
seq_group
.
sampling_params
sampling_params
=
seq_group
.
sampling_params
sampling_type
=
sampling_params
.
sampling_type
sampling_type
=
sampling_params
.
sampling_type
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
...
@@ -429,4 +428,73 @@ def get_pythonized_sample_results(
...
@@ -429,4 +428,73 @@ def get_pythonized_sample_results(
return
[
return
[
sample_results_dict
.
get
(
i
,
([],
[]))
sample_results_dict
.
get
(
i
,
([],
[]))
for
i
in
range
(
len
(
sampling_metadata
.
seq_groups
))
for
i
in
range
(
len
(
sampling_metadata
.
seq_groups
))
]
]
\ No newline at end of file
def
_build_sampler_output
(
maybe_deferred_sample_results
:
MaybeDeferredSampleResultType
,
sampling_metadata
:
SamplingMetadata
,
prompt_logprobs
:
Optional
[
List
[
Optional
[
PromptLogprobs
]]],
sample_logprobs
:
Optional
[
List
[
SampleLogprobs
]],
on_device_tensors
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]],
skip_sampler_cpu_output
:
bool
=
False
,
logits
:
Optional
[
torch
.
Tensor
]
=
None
)
->
SamplerOutput
:
"""Construct Python objects with the output of sampling.
Args:
on_device_tensors: Tuple containing on-device tensors with the
probabilities used in sampling and the sampled token ids. This
allows post-processing without copies to CPU/serialization, e.g. in
speculative decoding rejection sampling.
"""
sampler_output
:
List
[
CompletionSequenceGroupOutput
]
=
[]
last_sampler
.
seq_ids
=
[]
if
skip_sampler_cpu_output
:
assert
isinstance
(
maybe_deferred_sample_results
,
SampleResultArgsType
)
deferred_sample_results_args
=
maybe_deferred_sample_results
else
:
assert
prompt_logprobs
is
not
None
assert
sample_logprobs
is
not
None
assert
not
isinstance
(
maybe_deferred_sample_results
,
SampleResultArgsType
)
assert
len
(
sampling_metadata
.
seq_groups
)
\
==
len
(
maybe_deferred_sample_results
)
\
==
len
(
prompt_logprobs
)
\
==
len
(
sample_logprobs
)
deferred_sample_results_args
=
None
for
(
seq_group
,
sample_result
,
group_prompt_logprobs
,
group_sample_logprobs
)
in
zip
(
sampling_metadata
.
seq_groups
,
maybe_deferred_sample_results
,
prompt_logprobs
,
sample_logprobs
):
seq_ids
=
seq_group
.
seq_ids
next_token_ids
,
parent_ids
=
sample_result
seq_outputs
:
List
[
SequenceOutput
]
=
[]
for
parent_id
,
next_token_id
,
logprobs
in
zip
(
parent_ids
,
next_token_ids
,
group_sample_logprobs
):
seq_outputs
.
append
(
SequenceOutput
(
seq_ids
[
parent_id
],
next_token_id
,
logprobs
))
sampler_output
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
group_prompt_logprobs
))
if
len
(
seq_outputs
)
>
0
:
last_sampler
.
seq_ids
.
append
(
seq_outputs
[
0
].
parent_seq_id
)
# If not specified, store None values in SamplerOutput.
if
on_device_tensors
is
not
None
:
(
sampled_token_probs
,
logprobs_tensor
,
sampled_token_ids
)
=
on_device_tensors
else
:
sampled_token_probs
,
logprobs_tensor
,
sampled_token_ids
=
(
None
,
None
,
None
)
return
SamplerOutput
(
outputs
=
sampler_output
,
sampled_token_probs
=
sampled_token_probs
,
sampled_token_ids
=
sampled_token_ids
,
logprobs
=
logprobs_tensor
,
deferred_sample_results_args
=
deferred_sample_results_args
,
logits
=
logits
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment