Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
54294854
"vscode:/vscode.git/clone" did not exist on "4570535ec41e9e6f808d4cd3a9a06c6928652dea"
Commit
54294854
authored
Apr 11, 2025
by
lizhigong
Browse files
add v0 zero overhead
parent
a0c212c0
Changes
13
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
1350 additions
and
3 deletions
+1350
-3
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+6
-1
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+8
-2
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+4
-0
vllm/profiler/prof.py
vllm/profiler/prof.py
+73
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+4
-0
vllm/zero_overhead/v0/llm_engine.py
vllm/zero_overhead/v0/llm_engine.py
+517
-0
vllm/zero_overhead/v0/model_runner.py
vllm/zero_overhead/v0/model_runner.py
+46
-0
vllm/zero_overhead/v0/sampler.py
vllm/zero_overhead/v0/sampler.py
+435
-0
vllm/zero_overhead/v0/sequence.py
vllm/zero_overhead/v0/sequence.py
+60
-0
vllm/zero_overhead/v0/stop_check.py
vllm/zero_overhead/v0/stop_check.py
+77
-0
vllm/zero_overhead/v0/tokenizer.py
vllm/zero_overhead/v0/tokenizer.py
+84
-0
vllm/zero_overhead/v0/update_input.py
vllm/zero_overhead/v0/update_input.py
+28
-0
vllm/zero_overhead/v0/utils.py
vllm/zero_overhead/v0/utils.py
+8
-0
No files found.
vllm/engine/multiprocessing/engine.py
View file @
54294854
...
...
@@ -6,6 +6,8 @@ from contextlib import contextmanager
from
typing
import
Iterator
,
List
,
Optional
,
Union
import
cloudpickle
from
vllm.zero_overhead.v0.llm_engine
import
ZeroOverheadEngine
from
vllm.zero_overhead.v0.utils
import
is_zero_overhead
import
zmq
from
vllm
import
AsyncEngineArgs
,
SamplingParams
...
...
@@ -79,6 +81,9 @@ class MQLLMEngine:
# the python object to be reused again.
kwargs
[
'use_cached_outputs'
]
=
True
if
is_zero_overhead
():
self
.
engine
=
ZeroOverheadEngine
(
*
args
,
**
kwargs
)
else
:
self
.
engine
=
LLMEngine
(
*
args
,
**
kwargs
)
self
.
log_requests
=
log_requests
...
...
vllm/entrypoints/llm.py
View file @
54294854
...
...
@@ -43,6 +43,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
Counter
,
Device
,
deprecate_args
,
deprecate_kwargs
,
is_list_of
)
from
vllm.zero_overhead.v0.llm_engine
import
ZeroOverheadEngine
from
vllm.zero_overhead.v0.utils
import
is_zero_overhead
logger
=
init_logger
(
__name__
)
...
...
@@ -244,6 +246,10 @@ class LLM:
)
# Create the Engine (autoselects V0 vs V1)
if
is_zero_overhead
():
self
.
llm_engine
=
ZeroOverheadEngine
.
from_engine_args
(
engine_args
=
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
else
:
self
.
llm_engine
=
LLMEngine
.
from_engine_args
(
engine_args
=
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
self
.
engine_class
=
type
(
self
.
llm_engine
)
...
...
vllm/model_executor/layers/sampler.py
View file @
54294854
...
...
@@ -21,6 +21,8 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput
,
Logprob
,
PromptLogprobs
,
SampleLogprobs
,
SequenceOutput
)
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
from
vllm.zero_overhead.v0.sampler
import
ZeroOverheadSampler
from
vllm.zero_overhead.v0.utils
import
is_zero_overhead
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
find_spec
(
"flashinfer"
):
import
flashinfer.sampling
...
...
@@ -38,6 +40,8 @@ def get_sampler() -> torch.nn.Module:
# Lazy import: the v1 package isn't distributed
from
vllm.v1.sample.sampler
import
Sampler
as
V1Sampler
return
V1Sampler
()
if
is_zero_overhead
():
return
ZeroOverheadSampler
()
return
Sampler
()
...
...
vllm/profiler/prof.py
0 → 100644
View file @
54294854
from
ctypes
import
*
import
os
import
time
import
threading
class Prof:
    """Emit NVTX / ROCTX profiling ranges through ctypes-loaded libraries.

    A backend is enabled when its environment variable is present, with any
    value: ``VLLM_PROF_NVTX`` for NVTX and ``VLLM_PROF_ROCTX`` for ROCTX.
    With neither variable set, the ``Prof*`` methods do nothing.

    Each backend keeps its own library handle (``nvtx_lib`` / ``roctx_lib``)
    so that enabling both at once does not route NVTX calls to the ROCm
    library. ``lib`` is kept for backward compatibility and refers to the
    most recently loaded handle.
    """

    def __init__(self):
        self.use_nvtx = os.getenv('VLLM_PROF_NVTX') is not None
        self.roc_tracer_flag = False
        self.lib = None
        self.nvtx_lib = None
        self.roctx_lib = None
        if self.use_nvtx:
            self.nvtx_lib = self.lib = self._load_nvtx()
        self.use_roctx = os.getenv('VLLM_PROF_ROCTX') is not None
        if self.use_roctx:
            self.roctx_lib = self.lib = self._load_roctx()
        self.tm = time.perf_counter()
        # Maps thread ident -> number of ranges currently open on that thread.
        self.push_depth = {}

    @staticmethod
    def _load_nvtx():
        """Load libnvToolsExt and declare the push/pop prototypes."""
        lib = cdll.LoadLibrary("libnvToolsExt.so")
        lib.nvtxRangePushA.argtypes = [c_char_p]
        lib.nvtxRangePushA.restype = c_int
        lib.nvtxRangePop.restype = c_int
        return lib

    @staticmethod
    def _load_roctx():
        """Load libroctracer64 and declare the push/pop prototypes."""
        lib = cdll.LoadLibrary("libroctracer64.so")
        lib.roctxRangePushA.argtypes = [c_char_p]
        lib.roctxRangePushA.restype = c_int
        lib.roctxRangePop.restype = c_int
        return lib

    def _ensure_roctx(self):
        """Return the ROCTX handle, loading it on first use."""
        if self.roctx_lib is None:
            self.roctx_lib = self.lib = self._load_roctx()
        return self.roctx_lib

    def StartTracer(self):
        """Start roctracer collection; a no-op unless ROCTX is enabled."""
        if self.use_roctx:
            self._ensure_roctx().roctracer_start()
            self.roc_tracer_flag = True

    def StopTracer(self):
        """Stop roctracer collection; a no-op unless ROCTX is enabled."""
        if self.use_roctx:
            self._ensure_roctx().roctracer_stop()
            self.roc_tracer_flag = False

    def thread_depth_add(self, num):
        """Adjust the calling thread's open-range counter by ``num``.

        Returns:
            ``False`` (leaving the counter untouched) when ``num`` is
            negative and the thread has no open range, ``True`` otherwise.
        """
        thread_id = threading.current_thread().ident
        depth = self.push_depth.setdefault(thread_id, 0)
        if num < 0 and depth == 0:
            return False
        self.push_depth[thread_id] = depth + num
        return True

    def ProfRangePush(self, message):
        """Open a range named ``message`` on every enabled backend.

        ROCTX ranges are only emitted between ``StartTracer`` and
        ``StopTracer``.
        """
        # Use this instance's state, not the module-level singleton, so that
        # independently created Prof objects behave correctly.
        if self.use_nvtx:
            self.nvtx_lib.nvtxRangePushA(message.encode('utf-8'))
            self.thread_depth_add(1)
        if self.use_roctx and self.roc_tracer_flag:
            self._ensure_roctx().roctxRangePushA(message.encode('utf-8'))
            self.thread_depth_add(1)

    def ProfRangePop(self):
        """Close the innermost range on every enabled backend.

        Returns early, without calling into the library, when the calling
        thread has no open range.
        """
        if self.use_nvtx:
            if not self.thread_depth_add(-1):
                return
            self.nvtx_lib.nvtxRangePop()
        if self.use_roctx and self.roc_tracer_flag:
            if not self.thread_depth_add(-1):
                return
            self._ensure_roctx().roctxRangePop()

    def ProfRangeAutoPush(self, message):
        """Close the current range (if any) and open a new one."""
        self.ProfRangePop()
        self.ProfRangePush(message)


# Process-wide profiler instance shared by importers of this module.
profile = Prof()
vllm/worker/model_runner.py
View file @
54294854
...
...
@@ -60,6 +60,8 @@ from vllm.worker.model_runner_base import (
_add_sampling_metadata_broadcastable_dict
,
_init_attn_metadata_from_tensor_dict
,
_init_sampling_metadata_from_tensor_dict
)
from
vllm.zero_overhead.v0.model_runner
import
ZeroOverheadModelInputForGpuBuilder
from
vllm.zero_overhead.v0.utils
import
is_zero_overhead
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
...
...
@@ -1636,6 +1638,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
_model_input_cls
:
Type
[
ModelInputForGPUWithSamplingMetadata
]
=
(
ModelInputForGPUWithSamplingMetadata
)
_builder_cls
:
Type
[
ModelInputForGPUBuilder
]
=
ModelInputForGPUBuilder
if
is_zero_overhead
():
_builder_cls
=
ZeroOverheadModelInputForGpuBuilder
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
...
...
vllm/zero_overhead/v0/llm_engine.py
0 → 100644
View file @
54294854
This diff is collapsed.
Click to expand it.
vllm/zero_overhead/v0/model_runner.py
0 → 100644
View file @
54294854
import
torch
import
itertools
from
typing
import
List
,
Optional
,
Set
from
vllm.lora.layers
import
LoRAMapping
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.prompt_adapter.layers
import
PromptAdapterMapping
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.utils
import
async_tensor_h2d
,
flatten_2d_lists
from
vllm.worker.model_runner
import
ModelInputForGPU
,
ModelInputForGPUBuilder
from
vllm.zero_overhead.v0.sampler
import
get_last_sampler
from
vllm.zero_overhead.v0.update_input
import
UpdateInputTokens
class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
    """Model-input builder that patches in the previous step's samples.

    While sequence groups are added, the builder records their sequence ids
    in ``req_ids``. After the parent builds the model input, ``build`` asks
    ``UpdateInputTokens`` to overwrite ``input_tokens`` using the token ids
    and sequence ids stored in the last-sampler recorder.
    """

    def __init__(self, runner, finished_requests_ids=None):
        super().__init__(runner, finished_requests_ids)
        # Sequence ids of the batch being built, in insertion order.
        self.req_ids = []

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        """Forget the previous batch's sequence ids, then defer to the parent."""
        self.req_ids.clear()
        return super().prepare(finished_requests_ids)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Record every sequence id of the group, then defer to the parent."""
        self.req_ids.extend(seq_group_metadata.seq_data.keys())
        return super().add_seq_group(seq_group_metadata)

    def build(self) -> ModelInputForGPU:
        """Build the model input and patch it with the last sampled tokens."""
        model_input = super().build()
        recorder = get_last_sampler()
        recorded_tokens = recorder.sampled_token_ids_tensor
        if recorded_tokens is None:
            # Nothing was sampled yet, so there is nothing to patch in.
            return model_input

        device = self.runner.device
        pin_memory = self.runner.pin_memory
        current_ids = async_tensor_h2d(self.req_ids, torch.long, device,
                                       pin_memory)
        recorded_ids = async_tensor_h2d(recorder.seq_id.tolist(), torch.long,
                                        device, pin_memory)
        UpdateInputTokens(model_input.input_tokens, current_ids,
                          recorded_tokens, recorded_ids)
        return model_input
vllm/zero_overhead/v0/sampler.py
0 → 100644
View file @
54294854
from
importlib.util
import
find_spec
from
typing
import
Dict
,
List
,
Optional
import
torch
from
vllm
import
envs
from
vllm.model_executor.layers.rejection_sampler
import
_multinomial
from
vllm.model_executor.layers.sampler
import
MultinomialSamplesType
,
SampleMetadataType
,
\
SampleResultArgsType
,
SampleResultType
,
SampleResultsDictType
,
SampleReturnType
,
Sampler
,
\
SamplerOutput
,
_apply_min_p
,
_apply_min_tokens_penalty
,
_apply_top_k_top_p
,
_build_sampler_output
,
\
_modify_greedy_probs_inplace
,
_top_k_top_p_multinomial_with_flashinfer
,
get_logprobs
from
vllm.model_executor.layers.utils
import
apply_penalties
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
,
SamplingTensors
,
SequenceGroupToSample
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
VLLM_INVALID_TOKEN_ID
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
    import flashinfer.sampling
    # yapf: disable
    from flashinfer.sampling import (
        top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)

    # yapf: enable
else:
    # Always bind the name: the sampler code below tests it against None,
    # so leaving it undefined raises NameError whenever flashinfer is
    # disabled or not installed.
    flashinfer_top_k_top_p_sampling = None
class SampleRecorder:
    """Holder for the most recent sampling step's ids and sampled tokens.

    Both attributes start as ``None`` and are assigned by the sampling code
    in this module.
    """

    def __init__(self):
        # First sequence id of each sequence group in the last sampled batch.
        self.seq_id: Optional[torch.Tensor] = None
        # Token ids produced by the last greedy or random sampling branch.
        self.sampled_token_ids_tensor: Optional[torch.Tensor] = None
# Module-level recorder shared by the sampler and the model-input builder.
last_sampler = SampleRecorder()


def get_last_sampler() -> SampleRecorder:
    """Return the module-level ``SampleRecorder`` instance."""
    return last_sampler
class ZeroOverheadSampler(Sampler):
    """Sampler whose ``forward`` uses this module's ``_sample``."""

    def __init__(self):
        super().__init__()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Turn logits into a sampler output.

        The logits go through min-tokens and presence/frequency/repetition
        penalties, temperature scaling, and optional top-k/top-p and min-p
        filtering. Probabilities and log-probabilities are then computed in
        float32 and handed to this module's ``_sample``. Logprob
        Pythonization happens here unless
        ``sampling_metadata.skip_sampler_cpu_output`` is set.

        Args:
            logits: (num_tokens, vocab_size).
            sampling_metadata: Metadata for sampling.
        """
        assert logits is not None
        # Unpacking doubles as a check that logits is two-dimensional.
        _num_tokens, _vocab_size = logits.shape

        # Sampling tensors must be rebuilt when reuse is not requested, and
        # also when penalties are active, because penalties depend on output
        # tokens that change between decode runs.
        if (not sampling_metadata.reuse_sampling_tensors
                or self._do_penalties):
            self._init_sampling_tensors(logits, sampling_metadata)

        assert self._sampling_tensors is not None
        tensors = self._sampling_tensors
        use_penalties = self._do_penalties
        use_top_p_top_k = self._do_top_p_top_k
        use_min_p = self._do_min_p

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        if use_penalties:
            logits = apply_penalties(logits, tensors.prompt_tokens,
                                     tensors.output_tokens,
                                     tensors.presence_penalties,
                                     tensors.frequency_penalties,
                                     tensors.repetition_penalties)

        # Temperature scaling is done in float32, dividing in place to avoid
        # allocating another tensor.
        logits = logits.to(torch.float)
        logits.div_(tensors.temperatures.unsqueeze(dim=1))

        if use_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
            logits = _apply_top_k_top_p(logits, tensors.top_ps,
                                        tensors.top_ks)

        if use_min_p:
            logits = _apply_min_p(logits, tensors.min_ps)

        # Probabilities and log-probabilities are both kept in float32.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        sample_results, sampled_tokens = _sample(
            probs,
            logprobs,
            sampling_metadata,
            tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )

        if self.include_gpu_probs_tensor:
            # Keep the GPU tensors so logprobs can be Pythonized later.
            assert sampled_tokens is not None
            on_device_tensors = (probs, logprobs, sampled_tokens)
        else:
            # Pythonization already happened; drop the GPU tensors.
            on_device_tensors = None

        prompt_logprobs = None
        sample_logprobs = None
        skip_cpu_output = sampling_metadata.skip_sampler_cpu_output
        if not skip_cpu_output:
            # Pythonize logprobs now (GPU -> CPU); do not defer.
            assert not isinstance(sample_results, SampleResultArgsType)
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, sample_results)

        return _build_sampler_output(
            sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output,
            logits=logits)
def
_greedy_sample
(
selected_seq_groups
:
List
[
SequenceGroupToSample
],
samples
:
torch
.
Tensor
,
)
->
SampleResultType
:
"""Run greedy sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
samples: (num_selected_samples,) A tensor of samples. The length of
samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
sample_idx
=
0
results
:
SampleResultType
=
[]
for
seq_group
in
selected_seq_groups
:
if
not
seq_group
.
do_sample
:
results
.
append
(([],
[]))
continue
seq_ids
=
seq_group
.
seq_ids
num_parent_seqs
=
len
(
seq_ids
)
assert
num_parent_seqs
==
1
,
(
"Greedy sampling should have only one seq."
)
parent_ids
=
list
(
range
(
num_parent_seqs
))
assert
num_parent_seqs
==
1
# not support muti seqences in seqence group
next_token_ids
=
[
0
]
#place holder token id
results
.
append
((
next_token_ids
,
parent_ids
))
sample_idx
+=
num_parent_seqs
return
results
def
_random_sample
(
selected_seq_groups
:
List
[
SequenceGroupToSample
],
random_samples
:
torch
.
Tensor
,
)
->
SampleResultType
:
"""Run random sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
random_samples: (num_selected_samples,) A tensor of samples. The
length of samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
# Find the maximum n value of the prompt phase requests.
sample_idx
=
0
results
:
SampleResultType
=
[]
for
seq_group
in
selected_seq_groups
:
if
not
seq_group
.
do_sample
:
results
.
append
(([],
[]))
continue
seq_ids
=
seq_group
.
seq_ids
sampling_params
=
seq_group
.
sampling_params
is_prompt
=
seq_group
.
is_prompt
num_parent_seqs
=
len
(
seq_ids
)
if
is_prompt
:
# Prompt phase.
parent_ids
=
[
0
]
*
sampling_params
.
n
assert
num_parent_seqs
==
1
# not support muti seqences in seqence group
next_token_ids
=
[
0
]
*
sampling_params
.
n
#place holder token id
else
:
# Generation phase.
parent_ids
=
list
(
range
(
num_parent_seqs
))
assert
num_parent_seqs
==
1
# not support muti seqences in seqence group
next_token_ids
=
[
0
]
*
num_parent_seqs
#place holder token id
results
.
append
((
next_token_ids
,
parent_ids
))
sample_idx
+=
num_parent_seqs
return
results
def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> SampleReturnType:
    """Sample next tokens by delegating to ``_sample_with_torch``.

    Args:
        probs: (num_query_tokens_in_batch, num_vocab)
        logprobs: (num_query_tokens_in_batch, num_vocab)
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Tensors that include sampling related metadata.
        include_gpu_probs_tensor: Forwarded unchanged.
        modify_greedy_probs: Forwarded unchanged.

    Returns:
        Whatever ``_sample_with_torch`` returns: the per-seq-group sample
        results (or the deferred arguments) together with the sampled token
        ids tensor.
    """
    forwarded = dict(
        probs=probs,
        logprobs=logprobs,
        sampling_metadata=sampling_metadata,
        sampling_tensors=sampling_tensors,
        include_gpu_probs_tensor=include_gpu_probs_tensor,
        modify_greedy_probs=modify_greedy_probs,
    )
    return _sample_with_torch(**forwarded)
def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> SampleReturnType:
    '''Torch-oriented _sample() implementation that also records its output.

    Besides computing the samples, this function stores two things on the
    module-level ``last_sampler``: the first sequence id of every sequence
    group (``seq_id``) and the token ids produced by the greedy or random
    branch (``sampled_token_ids_tensor``).

    Single-step scheduling:
    * Perform GPU-side sampling computation
    * Immediately build the Pythonized sampling result

    Multi-step scheduling:
    * Perform GPU-side sampling computation
    * Defer Pythonization & preserve GPU-side
      tensors required for Pythonization

    Returns:
        ``(sample_results_or_deferred_args, sampled_token_ids_tensor)``; the
        tensor is ``None`` unless ``include_gpu_probs_tensor`` is set.

    Raises:
        ValueError: for a sampling type none of the branches handles.
    '''
    # Sequence-group indices grouped by the sampling type they requested.
    categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
        t: []
        for t in SamplingType
    }
    # One slot per sequence group; no device is passed, so the tensor is
    # created on torch's default device.
    last_sampler.seq_id = torch.zeros(len(sampling_metadata.seq_groups),
                                      dtype=torch.int32)
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        # Only the first sequence id of each group is recorded.
        last_sampler.seq_id[i] = seq_group.seq_ids[0]
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: SampleResultsDictType = {}
    sample_metadata: SampleMetadataType = {}
    multinomial_samples: MultinomialSamplesType = {}
    greedy_samples: Optional[torch.Tensor] = None
    beam_search_logprobs: Optional[torch.Tensor] = None

    # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1),
                                              VLLM_INVALID_TOKEN_ID,
                                              dtype=torch.long,
                                              device=logprobs.device)
    else:
        sampled_token_ids_tensor = None

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    # NOTE(review): each branch below assigns
    # last_sampler.sampled_token_ids_tensor, so when a batch mixes sampling
    # types only the last type's samples stay recorded, while
    # last_sampler.seq_id still lists every group in batch order. Confirm
    # that consumers never see mixed batches.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue

        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups)
        long_sample_indices = sample_indices.long()
        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)
            # Record the greedy picks as a column vector for the next step.
            last_sampler.sampled_token_ids_tensor = greedy_samples.unsqueeze(
                -1)

            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = greedy_samples.unsqueeze(-1)

            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling from
                # the modified distribution would always sample the argmax
                # token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)

        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            # Largest `n` requested by any prompt-phase group of this type.
            max_n_in_batch = 1
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_n_in_batch = max(max_n_in_batch, sampling_params.n)
            # The groups are only passed for RANDOM_SEED; RANDOM passes None.
            seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
                              seq_groups)

            if flashinfer_top_k_top_p_sampling is not None:
                multinomial_samples[
                    sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
                        probs[long_sample_indices],
                        sampling_tensors.top_ks[long_sample_indices],
                        sampling_tensors.top_ps[long_sample_indices],
                        max_n_in_batch,
                        seq_groups_arg,
                    )
            else:
                # NOTE(review): this file imports _multinomial from
                # rejection_sampler; confirm that function accepts the
                # `seq_groups` keyword used here.
                multinomial_samples[sampling_type] = _multinomial(
                    probs[long_sample_indices],
                    max_n_in_batch,
                    seq_groups=seq_groups_arg)

            # Record the random picks for the next step.
            last_sampler.sampled_token_ids_tensor = \
                multinomial_samples[sampling_type].to(torch.long)

            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[long_sample_indices] = \
                    multinomial_samples[sampling_type].to(torch.long)

        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # Encapsulate arguments for computing Pythonized sampler
    # results, whether deferred or otherwise.
    maybe_deferred_args = SampleResultArgsType(
        sampling_metadata=sampling_metadata,
        sample_metadata=sample_metadata,
        multinomial_samples=multinomial_samples,
        greedy_samples=greedy_samples,
        beam_search_logprobs=beam_search_logprobs,
        sample_results_dict=sample_results_dict)

    if not sampling_metadata.skip_sampler_cpu_output:
        # Convert the sampler output to a Python object right away.
        # Return Pythonized sampler result & sampled token ids
        return get_pythonized_sample_results(
            maybe_deferred_args), sampled_token_ids_tensor
    else:
        # Defer sampler result Pythonization; return deferred
        # Pythonization args & sampled token ids
        return (
            maybe_deferred_args,
            sampled_token_ids_tensor,
        )
def get_pythonized_sample_results(
        sample_result_args: SampleResultArgsType) -> SampleResultType:
    '''Build the per-sequence-group Python sample results.

    For every sampling type present in ``sample_metadata`` the matching
    helper (``_greedy_sample`` or ``_random_sample``) is called and its
    results are stored in ``sample_results_dict`` keyed by sequence-group
    index.

    Args:
        sample_result_args: GPU-side inputs to the Pythonization process

    Returns:
        One ``(next_token_ids, parent_ids)`` tuple per sequence group in
        batch order; groups with no entry get ``([], [])``.

    Raises:
        ValueError: if ``sample_metadata`` contains a sampling type other
            than GREEDY, RANDOM or RANDOM_SEED. Without this check
            ``sample_results`` would be unbound or stale from a previous
            iteration.
    '''
    sample_metadata = sample_result_args.sample_metadata
    sampling_metadata = sample_result_args.sampling_metadata
    greedy_samples = sample_result_args.greedy_samples
    multinomial_samples = sample_result_args.multinomial_samples
    sample_results_dict = sample_result_args.sample_results_dict

    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        seq_group_id, seq_groups = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(seq_groups, greedy_samples)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(seq_groups,
                                            multinomial_samples[sampling_type])
        else:
            # Fail loudly rather than reuse the previous type's results.
            raise ValueError(f"Unsupported sampling type: {sampling_type}")
        sample_results_dict.update(zip(seq_group_id, sample_results))

    return [
        sample_results_dict.get(i, ([], []))
        for i in range(len(sampling_metadata.seq_groups))
    ]
\ No newline at end of file
vllm/zero_overhead/v0/sequence.py
0 → 100644
View file @
54294854
from
typing
import
Union
from
vllm.sequence
import
Sequence
from
typing
import
Sequence
as
GenericSequence
class ZeroOverheadSequence(Sequence):
    """Sequence whose trailing output tokens may still be placeholders.

    ``effective_output_len`` counts the output tokens that have been
    confirmed through ``fix_last_token_id``. Output tokens past that count
    are placeholders, and the ``zero_overhead_*`` accessors ignore them.
    """

    def __init__(self,
                 seq_id,
                 inputs,
                 block_size,
                 eos_token_id=None,
                 lora_request=None,
                 prompt_adapter_request=None):
        super().__init__(seq_id, inputs, block_size, eos_token_id,
                         lora_request, prompt_adapter_request)
        # Number of output tokens whose real value is known.
        self.effective_output_len: int = 0

    def fix_last_token_id(self, token_id: int) -> None:
        """Replace the oldest placeholder with ``token_id`` and confirm it.

        The same end-relative offset is written into the output token
        buffer, the cached all-token list and, when that list is long
        enough to contain the position, the newly-appended-token list.
        """
        effect_offset = self.effective_output_len - len(
            self.data.output_token_ids)
        assert effect_offset < 0, "no placeholder token left to fix"
        self.data._output_token_ids[effect_offset] = token_id
        if len(self.data._new_appended_tokens) >= -effect_offset:
            self.data._new_appended_tokens[effect_offset] = token_id
        self.data._cached_all_token_ids[effect_offset] = token_id
        self.effective_output_len += 1

    def zero_overhead_get_output_token_ids(self) -> tuple[int, ...]:
        """Return the confirmed output tokens only."""
        return self.data.output_token_ids[:self.effective_output_len]

    def zero_overhead_get_output_len(self) -> int:
        """Return the number of confirmed output tokens."""
        return self.effective_output_len

    def zero_overhead_get_last_token_id(self) -> int:
        """Return the last confirmed token, or the last prompt token if none."""
        if self.effective_output_len == 0:
            return self.data._prompt_token_ids[-1]
        return self.data._output_token_ids[self.effective_output_len - 1]

    def zero_overhead_get_len(self) -> int:
        """Return prompt length plus the number of confirmed output tokens."""
        return self.effective_output_len + len(self.data._prompt_token_ids)

    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[GenericSequence[int], int]:
        """Return the confirmed output tokens, or only the new ones.

        If delta is True, only tokens confirmed since the last call to this
        method are returned: a bare ``int`` for exactly one new token, a
        list otherwise.
        """
        if not delta:
            return self.zero_overhead_get_output_token_ids()

        output_len = self.zero_overhead_get_output_len()

        # Get the number of new tokens
        num_new_tokens = output_len - self._last_output_token_ids_offset
        self._last_output_token_ids_offset = output_len

        if num_new_tokens == 0:
            return []

        # fix_last_token_id addresses the cache by an end-relative offset,
        # so the confirmed tokens end `pending` entries before its end.
        # Indexing by effective_output_len alone would land in the prompt.
        cached = self.data._cached_all_token_ids
        pending = len(self.data.output_token_ids) - output_len
        end = len(cached) - pending

        if num_new_tokens == 1:
            # Optimization for single decode token case
            # (which is what we have most of the time)
            return cached[end - 1]

        return cached[end - num_new_tokens:end]
\ No newline at end of file
vllm/zero_overhead/v0/stop_check.py
0 → 100644
View file @
54294854
from
typing
import
Optional
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
SequenceStatus
from
vllm.zero_overhead.v0.sequence
import
ZeroOverheadSequence
class ZeroOverheadStopChecker(StopChecker):
    """Stop checker that reads sequence state via ``zero_overhead_*``.

    All length and last-token checks go through the sequence's
    ``zero_overhead_*`` accessors rather than its regular ones.
    """

    def __init__(self, max_model_len, get_tokenizer_for_seq):
        super().__init__(max_model_len, get_tokenizer_for_seq)

    def maybe_stop_sequence(
        self,
        seq: ZeroOverheadSequence,
        new_char_count: int,
        sampling_params: SamplingParams,
        lora_req: Optional[LoRARequest] = None,
    ) -> None:
        """Stop the finished sequences.

        new_char_count is the number of chars added to the
        sequence's output text for the newly generated token.

        Sets ``seq.status`` (and ``seq.stop_reason`` where applicable) when
        a stop condition is met; otherwise leaves ``seq`` unchanged.
        """
        # Check if the minimum number of tokens has been generated yet;
        # skip the stop string/token checks if not
        if seq.zero_overhead_get_output_len() < sampling_params.min_tokens:
            return

        # Use the zero-overhead accessor for both token checks below, so
        # the stop-token check sees the same token as the EOS check.
        last_token_id = seq.zero_overhead_get_last_token_id()

        # Check if the sequence has generated the EOS token.
        if ((not sampling_params.ignore_eos)
                and last_token_id == seq.eos_token_id):
            # Remove the last EOS token unless explicitly specified
            # This prevents unintended exposure of the EOS token
            if new_char_count and (
                    not sampling_params.include_stop_str_in_output):
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

        # Check if a stop token was encountered.
        # This assumes a single token produced per step.
        if last_token_id in (sampling_params.stop_token_ids or ()):
            if new_char_count and (
                    not sampling_params.include_stop_str_in_output):
                # Remove last token
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = last_token_id
            return

        # Check if any stop strings are matched.
        stop = self.check_stop_strings(
            seq.output_text, new_char_count, sampling_params.stop,
            sampling_params.include_stop_str_in_output)
        if stop is not None:
            stop_str, truncate_to = stop
            if truncate_to != -1:
                seq.output_text = seq.output_text[:truncate_to]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = stop_str
            return

        # Check if the sequence has reached max_model_len.
        if seq.zero_overhead_get_len() > self._get_max_model_len(lora_req):
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Check if the sequence has reached max_tokens.
        if seq.zero_overhead_get_output_len() == sampling_params.max_tokens:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return
\ No newline at end of file
vllm/zero_overhead/v0/tokenizer.py
0 → 100644
View file @
54294854
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
VLLM_INVALID_TOKEN_ID
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer_utils
import
convert_prompt_ids_to_tokens
,
detokenize_incrementally
from
vllm.zero_overhead.v0.sequence
import
ZeroOverheadSequence
class ZeroOverheadDetokenizer(Detokenizer):
    """Detokenizer that only looks at a sequence's confirmed tokens."""

    def __init__(self, tokenizer_group):
        super().__init__(tokenizer_group)

    def decode_sequence_inplace(self, seq: ZeroOverheadSequence,
                                prms: SamplingParams) -> int:
        """Decode the newest confirmed token of ``seq`` in place.

        The token list is truncated to the prompt length plus
        ``seq.effective_output_len`` before detokenizing, so tokens beyond
        that count are ignored.

        Args:
            seq: The sequence to decode.
            prms: The sampling parameters used to generate the sequence.

        Returns:
            The number of characters added to the output text.
        """
        confirmed_len = seq.get_prompt_len() + seq.effective_output_len
        token_ids = seq.get_token_ids()[:confirmed_len]
        newest_token_id = token_ids[-1]
        tokenizer = self.get_tokenizer_for_seq(seq)

        # Convert prompt token IDs to tokens on first use, so that this is
        # not recomputed for each logprob.
        if seq.tokens is None:
            (seq.tokens, seq.prefix_offset,
             seq.read_offset) = convert_prompt_ids_to_tokens(
                 tokenizer=tokenizer,
                 prompt_ids=token_ids[:-1],
                 skip_special_tokens=prms.skip_special_tokens,
             )

        # Arguments shared by every incremental detokenization below; the
        # sequence state they read is only updated at the very end.
        shared_kwargs = dict(
            tokenizer=tokenizer,
            prev_tokens=seq.tokens,
            prefix_offset=seq.prefix_offset,
            read_offset=seq.read_offset,
            skip_special_tokens=prms.skip_special_tokens,
            spaces_between_special_tokens=prms.spaces_between_special_tokens,
        )

        (new_tokens, new_text, next_prefix_offset,
         next_read_offset) = detokenize_incrementally(
             all_input_ids=token_ids, **shared_kwargs)

        # Decode logprobs
        latest_logprobs = seq.output_logprobs[-1]
        if latest_logprobs:
            context_ids = token_ids[:-1]
            for candidate_id, candidate in latest_logprobs.items():
                if candidate_id == newest_token_id:
                    # The token generated this iteration reuses the text
                    # decoded above.
                    candidate.decoded_token = new_text
                elif (candidate.decoded_token is None
                      and candidate_id != VLLM_INVALID_TOKEN_ID):
                    (_, candidate_text, _, _) = detokenize_incrementally(
                        all_input_ids=context_ids + [candidate_id],
                        **shared_kwargs)
                    candidate.decoded_token = candidate_text

        seq.tokens.extend(new_tokens)
        seq.prefix_offset = next_prefix_offset
        seq.read_offset = next_read_offset
        seq.output_text += new_text

        return len(new_text)
\ No newline at end of file
vllm/zero_overhead/v0/update_input.py
0 → 100644
View file @
54294854
import
torch
import
triton
import
triton.language
as
tl
# Triton kernel: one program per entry of `input_seq_ids`. Program `pid`
# scans the `BATCH_SIZE1` entries of `seq_ids` for the id found at
# `input_seq_ids[pid]` and, on a match, replaces `input_tokens[pid]` with the
# value at the same position of `sample_output`. Without a match the token is
# written back unchanged.
# NOTE(review): `sample_output + i` is flat pointer arithmetic, so this
# presumably expects one contiguous sampled token per entry of `seq_ids` —
# confirm against the sampler output shape.
@triton.jit
def _update_input_tokens(
    sample_output,
    seq_ids,
    input_tokens,
    input_seq_ids,
    BATCH_SIZE1,
    BATCH_SIZE2,
):
    pid = tl.program_id(0)
    # Programs beyond the number of input entries have nothing to do.
    if pid >= BATCH_SIZE2:
        return
    # Default to the token already present at this position.
    output_token = tl.load(input_tokens + pid)
    _input_seq_id = tl.load(input_seq_ids + pid)
    # The scan does not stop at the first match, so the last one wins.
    for i in range(BATCH_SIZE1):
        _seq_ids = tl.load(seq_ids + i)
        if _seq_ids == _input_seq_id:
            output_token = tl.load(sample_output + i)
    tl.store(input_tokens + pid, output_token)
def UpdateInputTokens(input_tokens, input_seq_ids, last_sample, last_ids):
    """Launch ``_update_input_tokens`` with one program per input entry.

    Args:
        input_tokens: Token buffer updated in place by the kernel.
        input_seq_ids: Sequence id for each entry of ``input_tokens`` that
            may be updated; its first dimension sets the grid size.
        last_sample: Sampled token ids recorded by the previous step.
        last_ids: Sequence ids matching ``last_sample``.
    """
    num_inputs = input_seq_ids.shape[0]
    num_recorded = last_ids.shape[0]
    launch_grid = [num_inputs, 1, 1]
    kernel = _update_input_tokens[launch_grid]
    kernel(last_sample, last_ids, input_tokens, input_seq_ids, num_recorded,
           num_inputs)
\ No newline at end of file
vllm/zero_overhead/v0/utils.py
0 → 100644
View file @
54294854
import
os
# Evaluated once at import time: enabled only when the variable is exactly '1'.
zero_overhead = os.getenv('VLLM_ZERO_OVERHEAD') == '1'


def is_zero_overhead():
    """Return whether zero-overhead mode was enabled at import time."""
    return zero_overhead
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment