Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
54294854
Commit
54294854
authored
Apr 11, 2025
by
lizhigong
Browse files
add v0 zero overhead
parent
a0c212c0
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
1350 additions
and
3 deletions
+1350
-3
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+6
-1
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+8
-2
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+4
-0
vllm/profiler/prof.py
vllm/profiler/prof.py
+73
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+4
-0
vllm/zero_overhead/v0/llm_engine.py
vllm/zero_overhead/v0/llm_engine.py
+517
-0
vllm/zero_overhead/v0/model_runner.py
vllm/zero_overhead/v0/model_runner.py
+46
-0
vllm/zero_overhead/v0/sampler.py
vllm/zero_overhead/v0/sampler.py
+435
-0
vllm/zero_overhead/v0/sequence.py
vllm/zero_overhead/v0/sequence.py
+60
-0
vllm/zero_overhead/v0/stop_check.py
vllm/zero_overhead/v0/stop_check.py
+77
-0
vllm/zero_overhead/v0/tokenizer.py
vllm/zero_overhead/v0/tokenizer.py
+84
-0
vllm/zero_overhead/v0/update_input.py
vllm/zero_overhead/v0/update_input.py
+28
-0
vllm/zero_overhead/v0/utils.py
vllm/zero_overhead/v0/utils.py
+8
-0
No files found.
vllm/engine/multiprocessing/engine.py
View file @
54294854
...
@@ -6,6 +6,8 @@ from contextlib import contextmanager
...
@@ -6,6 +6,8 @@ from contextlib import contextmanager
from
typing
import
Iterator
,
List
,
Optional
,
Union
from
typing
import
Iterator
,
List
,
Optional
,
Union
import
cloudpickle
import
cloudpickle
from
vllm.zero_overhead.v0.llm_engine
import
ZeroOverheadEngine
from
vllm.zero_overhead.v0.utils
import
is_zero_overhead
import
zmq
import
zmq
from
vllm
import
AsyncEngineArgs
,
SamplingParams
from
vllm
import
AsyncEngineArgs
,
SamplingParams
...
@@ -79,6 +81,9 @@ class MQLLMEngine:
...
@@ -79,6 +81,9 @@ class MQLLMEngine:
# the python object to be reused again.
# the python object to be reused again.
kwargs
[
'use_cached_outputs'
]
=
True
kwargs
[
'use_cached_outputs'
]
=
True
if
is_zero_overhead
():
self
.
engine
=
ZeroOverheadEngine
(
*
args
,
**
kwargs
)
else
:
self
.
engine
=
LLMEngine
(
*
args
,
**
kwargs
)
self
.
engine
=
LLMEngine
(
*
args
,
**
kwargs
)
self
.
log_requests
=
log_requests
self
.
log_requests
=
log_requests
...
...
vllm/entrypoints/llm.py
View file @
54294854
...
@@ -43,6 +43,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
...
@@ -43,6 +43,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
Counter
,
Device
,
deprecate_args
,
deprecate_kwargs
,
from
vllm.utils
import
(
Counter
,
Device
,
deprecate_args
,
deprecate_kwargs
,
is_list_of
)
is_list_of
)
from
vllm.zero_overhead.v0.llm_engine
import
ZeroOverheadEngine
from
vllm.zero_overhead.v0.utils
import
is_zero_overhead
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -244,6 +246,10 @@ class LLM:
...
@@ -244,6 +246,10 @@ class LLM:
)
)
# Create the Engine (autoselects V0 vs V1)
# Create the Engine (autoselects V0 vs V1)
if
is_zero_overhead
():
self
.
llm_engine
=
ZeroOverheadEngine
.
from_engine_args
(
engine_args
=
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
else
:
self
.
llm_engine
=
LLMEngine
.
from_engine_args
(
self
.
llm_engine
=
LLMEngine
.
from_engine_args
(
engine_args
=
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
engine_args
=
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
self
.
engine_class
=
type
(
self
.
llm_engine
)
self
.
engine_class
=
type
(
self
.
llm_engine
)
...
...
vllm/model_executor/layers/sampler.py
View file @
54294854
...
@@ -21,6 +21,8 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
...
@@ -21,6 +21,8 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput
,
Logprob
,
CompletionSequenceGroupOutput
,
Logprob
,
PromptLogprobs
,
SampleLogprobs
,
SequenceOutput
)
PromptLogprobs
,
SampleLogprobs
,
SequenceOutput
)
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
from
vllm.zero_overhead.v0.sampler
import
ZeroOverheadSampler
from
vllm.zero_overhead.v0.utils
import
is_zero_overhead
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
find_spec
(
"flashinfer"
):
if
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
find_spec
(
"flashinfer"
):
import
flashinfer.sampling
import
flashinfer.sampling
...
@@ -38,6 +40,8 @@ def get_sampler() -> torch.nn.Module:
...
@@ -38,6 +40,8 @@ def get_sampler() -> torch.nn.Module:
# Lazy import: the v1 package isn't distributed
# Lazy import: the v1 package isn't distributed
from
vllm.v1.sample.sampler
import
Sampler
as
V1Sampler
from
vllm.v1.sample.sampler
import
Sampler
as
V1Sampler
return
V1Sampler
()
return
V1Sampler
()
if
is_zero_overhead
():
return
ZeroOverheadSampler
()
return
Sampler
()
return
Sampler
()
...
...
vllm/profiler/prof.py
0 → 100644
View file @
54294854
from
ctypes
import
*
import
os
import
time
import
threading
class Prof:
    """Emit NVTX / ROCTX profiling range markers through ctypes.

    Backends are chosen from the environment at construction time:
    ``VLLM_PROF_NVTX`` enables NVTX (``libnvToolsExt.so``) and
    ``VLLM_PROF_ROCTX`` enables ROCTX (``libroctracer64.so``). Only the
    presence of the variable is checked, not its value. With neither
    variable set no library is loaded and the range/tracer methods make
    no native calls.
    """

    def __init__(self):
        self.use_nvtx = os.getenv('VLLM_PROF_NVTX') is not None
        self.roc_tracer_flag = False
        # Kept for backward compatibility: holds the most recently loaded
        # library, exactly as before.
        self.lib = None
        # One handle per backend. Sharing a single handle meant that with
        # both variables set the ROCTX library replaced the NVTX one and
        # the NVTX symbols could no longer be resolved.
        self._nvtx_lib = None
        self._roctx_lib = None
        if self.use_nvtx:
            self._nvtx_lib = self.lib = cdll.LoadLibrary("libnvToolsExt.so")
            self._nvtx_lib.nvtxRangePushA.argtypes = [c_char_p]
            self._nvtx_lib.nvtxRangePushA.restype = c_int
            self._nvtx_lib.nvtxRangePop.restype = c_int

        self.use_roctx = os.getenv('VLLM_PROF_ROCTX') is not None
        if self.use_roctx:
            self._roctx_lib = self.lib = self._load_roctx()

        self.tm = time.perf_counter()
        # thread ident -> number of ranges that thread currently has open
        self.push_depth = {}

    @staticmethod
    def _load_roctx():
        """Load libroctracer64 and declare the ROCTX range prototypes."""
        lib = cdll.LoadLibrary("libroctracer64.so")
        lib.roctxRangePushA.argtypes = [c_char_p]
        lib.roctxRangePushA.restype = c_int
        lib.roctxRangePop.restype = c_int
        return lib

    def _ensure_roctx(self):
        """Return the ROCTX handle, loading it on first use."""
        if self._roctx_lib is None:
            self._roctx_lib = self.lib = self._load_roctx()
        return self._roctx_lib

    def StartTracer(self):
        """Start roctracer and allow ROCTX ranges; no-op without ROCTX."""
        if self.use_roctx:
            self._ensure_roctx().roctracer_start()
            self.roc_tracer_flag = True

    def StopTracer(self):
        """Stop roctracer and suppress ROCTX ranges; no-op without ROCTX."""
        if self.use_roctx:
            self._ensure_roctx().roctracer_stop()
            self.roc_tracer_flag = False

    def thread_depth_add(self, num):
        """Adjust the calling thread's open-range counter by ``num``.

        Returns:
            ``False`` (counter untouched) when ``num`` is negative and the
            thread has no open range, otherwise ``True``.
        """
        thread_id = threading.current_thread().ident
        depth = self.push_depth.setdefault(thread_id, 0)
        if num < 0 and depth == 0:
            return False
        self.push_depth[thread_id] = depth + num
        return True

    def ProfRangePush(self, message):
        """Open a range named ``message`` on every active backend."""
        # Use this instance's state rather than the module-level `profile`
        # global, so the class also works before/without that global.
        if self.use_nvtx:
            self._nvtx_lib.nvtxRangePushA(message.encode('utf-8'))
            self.thread_depth_add(1)
        if self.use_roctx and self.roc_tracer_flag:
            self._roctx_lib.roctxRangePushA(message.encode('utf-8'))
            self.thread_depth_add(1)

    def ProfRangePop(self):
        """Close the innermost range; does nothing if none is open."""
        if self.use_nvtx:
            if not self.thread_depth_add(-1):
                return
            self._nvtx_lib.nvtxRangePop()
        if self.use_roctx and self.roc_tracer_flag:
            if not self.thread_depth_add(-1):
                return
            self._roctx_lib.roctxRangePop()

    def ProfRangeAutoPush(self, message):
        """Close the current range (if any) and open a new one."""
        self.ProfRangePop()
        self.ProfRangePush(message)
# Module-level Prof instance; other modules import this name
# (e.g. `from vllm.profiler.prof import profile`).
profile = Prof()
vllm/worker/model_runner.py
View file @
54294854
...
@@ -60,6 +60,8 @@ from vllm.worker.model_runner_base import (
...
@@ -60,6 +60,8 @@ from vllm.worker.model_runner_base import (
_add_sampling_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
,
_init_attn_metadata_from_tensor_dict
,
_init_attn_metadata_from_tensor_dict
,
_init_sampling_metadata_from_tensor_dict
)
_init_sampling_metadata_from_tensor_dict
)
from
vllm.zero_overhead.v0.model_runner
import
ZeroOverheadModelInputForGpuBuilder
from
vllm.zero_overhead.v0.utils
import
is_zero_overhead
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
...
@@ -1636,6 +1638,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1636,6 +1638,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
_model_input_cls
:
Type
[
ModelInputForGPUWithSamplingMetadata
]
=
(
_model_input_cls
:
Type
[
ModelInputForGPUWithSamplingMetadata
]
=
(
ModelInputForGPUWithSamplingMetadata
)
ModelInputForGPUWithSamplingMetadata
)
_builder_cls
:
Type
[
ModelInputForGPUBuilder
]
=
ModelInputForGPUBuilder
_builder_cls
:
Type
[
ModelInputForGPUBuilder
]
=
ModelInputForGPUBuilder
if
is_zero_overhead
():
_builder_cls
=
ZeroOverheadModelInputForGpuBuilder
def
make_model_input_from_broadcasted_tensor_dict
(
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
self
,
...
...
vllm/zero_overhead/v0/llm_engine.py
0 → 100644
View file @
54294854
from
collections
import
Counter
from
functools
import
partial
import
os
import
queue
import
threading
import
traceback
from
typing
import
Callable
,
Dict
,
List
,
Mapping
,
Optional
,
Type
,
Union
from
zlib
import
ZLIB_VERSION
import
torch
from
vllm
import
envs
from
vllm.config
import
DecodingConfig
,
ObservabilityConfig
,
VllmConfig
from
vllm.core.scheduler
import
ScheduledSequenceGroup
from
vllm.engine.llm_engine
import
_LOCAL_LOGGING_INTERVAL_SEC
,
LLMEngine
,
SchedulerContext
,
SchedulerOutputState
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.engine.output_processor.interfaces
import
SequenceGroupOutputProcessor
from
vllm.entrypoints
import
logger
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.inputs.data
import
ProcessorInputs
from
vllm.inputs.parse
import
is_encoder_decoder_inputs
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.inputs.registry
import
InputRegistry
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.registry
import
MultiModalRegistry
from
vllm.outputs
import
PoolingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
,
ParallelSampleSequenceGroup
,
SequenceGroup
,
SequenceGroupBase
,
SequenceGroupMetadata
from
vllm.tracing
import
init_tracer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.usage.usage_lib
import
UsageContext
,
is_usage_stats_enabled
from
vllm.utils
import
resolve_obj_by_qualname
,
weak_bind
from
vllm.zero_overhead.v0.sequence
import
ZeroOverheadSequence
from
vllm.zero_overhead.v0.stop_check
import
ZeroOverheadStopChecker
from
vllm.zero_overhead.v0.tokenizer
import
ZeroOverheadDetokenizer
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
from
vllm.profiler.prof
import
profile
class ZeroOverheadEngine(LLMEngine):
    """LLMEngine variant that schedules and executes on a worker thread.

    ``thread_zero_overhead`` runs ``schedule()`` and ``execute_model()``
    for virtual engine 0 each time the caller releases ``sem_m2s``. It
    hands the *previous* iteration's record to the caller through
    ``q_recorder``, so ``zero_overhead_step`` post-processes step ``k-1``
    while step ``k`` is being launched.
    """

    # Queue sentinel meaning "the worker thread died"; see
    # _raise_if_worker_failed().
    _THREAD_FAILED = object()

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
    ) -> None:
        """Build the engine state and start the worker thread.

        ``LLMEngine.__init__`` is not called; this method sets the engine
        attributes itself, using the zero-overhead detokenizer, stop
        checker and sequence classes.

        Raises:
            ValueError: if ``envs.VLLM_USE_V1`` is set.
        """
        # Local imports: the module-level `Counter` is collections.Counter,
        # which is not an iterator, so `next(self.seq_counter)` in
        # _add_processed_request raised TypeError. vllm.utils.Counter is
        # the sequence-id counter. The version string replaces
        # zlib.ZLIB_VERSION, which was being logged as the engine version.
        from vllm.utils import Counter as SequenceIdCounter
        from vllm.version import __version__ as VLLM_VERSION

        if envs.VLLM_USE_V1:
            raise ValueError(
                "Using V0 LLMEngine, but envs.VLLM_USE_V1=True. "
                "This should not happen. As a workaround, try using "
                "LLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config  # noqa
        self.load_config = vllm_config.load_config
        self.decoding_config = vllm_config.decoding_config or DecodingConfig(  # noqa
        )
        self.prompt_adapter_config = vllm_config.prompt_adapter_config  # noqa
        self.observability_config = vllm_config.observability_config or ObservabilityConfig(  # noqa
        )

        # NOTE(review): `logger` is imported from `vllm.entrypoints` at the
        # top of this file - confirm that resolves to a logger object and
        # not to the `vllm.entrypoints.logger` module.
        logger.info(
            "Initializing a V0 LLM engine (v%s) with config: %s, "
            "use_cached_outputs=%s, ",
            VLLM_VERSION,
            vllm_config,
            use_cached_outputs,
        )

        self.log_stats = log_stats
        self.use_cached_outputs = use_cached_outputs

        if not self.model_config.skip_tokenizer_init:
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = ZeroOverheadDetokenizer(self.tokenizer)
            tokenizer_group = self.get_tokenizer_group()
        else:
            self.tokenizer = None
            self.detokenizer = None
            tokenizer_group = None

        # Ensure that the function doesn't contain a reference to self,
        # to avoid engine GC issues
        def get_tokenizer_for_seq(
                sequence: ZeroOverheadSequence) -> AnyTokenizer:
            assert tokenizer_group, ("tokenizer_group cannot be None, "
                                     "make sure skip_tokenizer_init is False")
            return tokenizer_group.get_lora_tokenizer(sequence.lora_request)

        self.seq_counter = SequenceIdCounter()
        self.generation_config_fields = (
            self.model_config.try_get_generation_config())

        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                    self.tokenizer,
                                                    mm_registry)

        self.input_registry = input_registry
        self.input_processor = input_registry.create_input_processor(
            self.model_config)

        self.model_executor = executor_class(vllm_config=vllm_config, )

        if self.model_config.runner_type != "pooling":
            self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(self.model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(self.model_config.dtype),
                    "tensor_parallel_size":
                    self.parallel_config.tensor_parallel_size,
                    "block_size":
                    self.cache_config.block_size,
                    "gpu_memory_utilization":
                    self.cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization":
                    self.model_config.quantization,
                    "kv_cache_dtype":
                    str(self.cache_config.cache_dtype),

                    # Feature flags
                    "enable_lora":
                    bool(self.lora_config),
                    "enable_prompt_adapter":
                    bool(self.prompt_adapter_config),
                    "enable_prefix_caching":
                    self.cache_config.enable_prefix_caching,
                    "enforce_eager":
                    self.model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    self.parallel_config.disable_custom_all_reduce,
                })

        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()

        self.cached_scheduler_outputs = [
            SchedulerOutputState()
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        self.scheduler_contexts = [
            SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
                             multi_step_stream_outputs)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]

        if self.model_config.use_async_output_proc:
            process_model_outputs = weak_bind(self._process_model_outputs)

            self.async_callbacks = [
                partial(process_model_outputs,
                        ctx=self.scheduler_contexts[v_id])
                for v_id in range(self.parallel_config.pipeline_parallel_size)
            ]
        else:
            self.async_callbacks = []

        # Currently used by AsyncLLMEngine to ensure quick append
        # of request outputs to asyncio queues
        self.process_request_outputs_callback: Optional[Callable] = None

        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                self.vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = self.vllm_config.scheduler_config.scheduler_cls
        self.scheduler = [
            Scheduler(
                self.scheduler_config, self.cache_config, self.lora_config,
                self.parallel_config.pipeline_parallel_size,
                self.async_callbacks[v_id]
                if self.model_config.use_async_output_proc else None)
            for v_id in range(self.parallel_config.pipeline_parallel_size)
        ]

        # Metric Logging.
        if self.log_stats:
            if stat_loggers is not None:
                self.stat_loggers = stat_loggers
            else:
                # Lazy import for prometheus multiprocessing.
                # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
                # before prometheus_client is imported.
                # See https://prometheus.github.io/client_python/multiprocess/
                from vllm.engine.metrics import (LoggingStatLogger,
                                                 PrometheusStatLogger)

                self.stat_loggers = {
                    "logging":
                    LoggingStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        vllm_config=vllm_config),
                    "prometheus":
                    PrometheusStatLogger(
                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                        labels=dict(
                            model_name=self.model_config.served_model_name),
                        vllm_config=vllm_config),
                }
                self.stat_loggers["prometheus"].info("cache_config",
                                                     self.cache_config)

        self.tracer = None
        if self.observability_config.otlp_traces_endpoint:
            self.tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                get_tokenizer_for_seq,
                stop_checker=ZeroOverheadStopChecker(
                    self.scheduler_config.max_model_len,
                    get_tokenizer_for_seq,
                ),
            ))
        self.tree_decoding = os.environ.get('VLLM_TREE_DECODING') == '1'

        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

        # Flag to set when an input fails to process and the engine should run
        # the next step without re-scheduling.
        self._skip_scheduling_next_step = False

        # CPU copy of the previous step's sampled tokens, started with
        # non_blocking=True by the worker; `async_event` marks when the
        # copy has been enqueued.
        self.async_d2h = None
        # [outputs, seq_group_metadata_list, scheduler_outputs] of the most
        # recently launched step, or None.
        self.last_record = None
        self.async_event = torch.cuda.Event(enable_timing=False)
        # Exception that killed the worker thread, if any.
        self._thread_error: Optional[BaseException] = None
        self.q_recorder = queue.Queue()
        self.thread_running = True
        self.sem_m2s = threading.Semaphore(0)  # main to scheduler thread
        # Daemon: the thread's target is a bound method, so the thread
        # keeps this engine alive and __del__ cannot be triggered by
        # refcount; a non-daemon thread parked on the semaphore would
        # block interpreter exit.
        self.zero_thread = threading.Thread(target=self.thread_zero_overhead,
                                            daemon=True)
        self.zero_thread.start()
        profile.StartTracer()

    def __del__(self):
        # __init__ may have raised before the thread state was created.
        if getattr(self, "thread_running", False):
            self.finish_thread()
        return super().__del__()

    def finish_thread(self):
        """Ask the worker thread to exit and wake it so it can observe that."""
        if self.thread_running:
            self.thread_running = False
            self.sem_m2s.release()

    def _raise_if_worker_failed(self, item=None) -> None:
        """Raise if the worker thread has died (flag set or sentinel seen)."""
        if self._thread_error is not None or item is self._THREAD_FAILED:
            raise RuntimeError(
                "zero-overhead scheduler thread terminated with an error"
            ) from self._thread_error

    def thread_zero_overhead(self):
        """Worker loop: schedule, publish the previous record, execute.

        Each iteration waits on ``sem_m2s``. Before launching the new
        step it puts the previous ``last_record`` (or ``None`` when there
        is none) on ``q_recorder`` for ``zero_overhead_step``.
        """
        try:
            while True:
                self.sem_m2s.acquire()
                if not self.thread_running:
                    break
                virtual_engine = 0

                # Schedule iteration
                (seq_group_metadata_list, scheduler_outputs,
                 _) = self.scheduler[virtual_engine].schedule()

                last_outputs_ids = None
                last_outputs_tensor = None
                if self.last_record is not None:
                    last_output = self.last_record[0][0]
                    last_outputs_ids = last_output.sampler_out_ids
                    # "tenosr" (sic) is the attribute name this code reads;
                    # it must match whatever sets it.
                    last_outputs_tensor = last_output.sampler_out_tenosr
                    self.async_d2h = last_outputs_tensor.to('cpu',
                                                            non_blocking=True)
                    self.async_event.record()
                    self.q_recorder.put(self.last_record)
                else:
                    self.q_recorder.put(None)

                if len(seq_group_metadata_list) == 0:
                    self.last_record = None
                    continue

                finished_requests_ids = self.scheduler[
                    virtual_engine].get_and_reset_finished_requests_ids()

                assert seq_group_metadata_list is not None
                assert scheduler_outputs is not None

                last_sampled_token_ids = \
                    self._get_last_sampled_token_ids(virtual_engine)

                execute_model_req = ExecuteModelRequest(
                    seq_group_metadata_list=seq_group_metadata_list,
                    blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                    blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                    blocks_to_copy=scheduler_outputs.blocks_to_copy,
                    num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                    running_queue_size=scheduler_outputs.running_queue_size,
                    finished_requests_ids=finished_requests_ids,
                    # We use ExecuteModelRequest to pass the last sampled_token_ids
                    # to each of the non-last PP stages for in-place prepare_input.
                    last_sampled_token_ids=last_sampled_token_ids,
                    last_outputs_ids=last_outputs_ids,
                    last_outputs_sample=last_outputs_tensor)

                outputs = self.model_executor.execute_model(
                    execute_model_req=execute_model_req)

                if len(outputs) == 1:
                    self._advance_to_next_step(
                        outputs[0], seq_group_metadata_list,
                        scheduler_outputs.scheduled_seq_groups)
                # Shallow copy of the list (the elements are shared), so the
                # record keeps its own list object.
                scheduler_outputs.scheduled_seq_groups = list(
                    scheduler_outputs.scheduled_seq_groups)
                self.last_record = [
                    outputs, seq_group_metadata_list, scheduler_outputs
                ]
        except Exception as e:
            print(f"thread_zero_overhead error : {e}")
            traceback.print_exc()
            # Publish the failure: without this the next step() would block
            # forever on q_recorder.get() because nobody is left to put().
            self._thread_error = e
            self.q_recorder.put(self._THREAD_FAILED)

    def zero_overhead_step(
        self
    ) -> Optional[List[Union[RequestOutput, PoolingRequestOutput]]]:
        """Trigger one worker iteration and post-process the previous step.

        Returns:
            The request outputs of the previously launched step, or
            ``None`` when the worker had no previous record to hand over.

        Raises:
            RuntimeError: if the worker thread has terminated with an error.
        """
        self._raise_if_worker_failed()
        if not self.thread_running:
            # finish_thread() was called earlier: restart the worker.
            self.zero_thread.join()
            self.thread_running = True
            self.zero_thread = threading.Thread(
                target=self.thread_zero_overhead, daemon=True)
            self.zero_thread.start()
        self.sem_m2s.release()
        recode_output = self.q_recorder.get()
        self._raise_if_worker_failed(recode_output)
        if recode_output is None:
            # None is for the first step
            return None

        virtual_engine = 0
        ctx = self.scheduler_contexts[virtual_engine]
        ctx.request_outputs.clear()
        outputs, seq_group_metadata_list, scheduler_outputs = recode_output
        ctx.seq_group_metadata_list = seq_group_metadata_list
        ctx.scheduler_outputs = scheduler_outputs

        # Wait for the event recorded after the non-blocking device-to-host
        # copy before reading `async_d2h`.
        self.async_event.synchronize()
        self._fix_last_step(outputs, seq_group_metadata_list,
                            scheduler_outputs.scheduled_seq_groups)

        # is_first_step_output is True only when the num_steps of all
        # the sequences are 1. When the num_steps > 1,
        # multi_step_model_runner does the first-step output append.
        is_first_step_output: bool = False if not seq_group_metadata_list \
            else seq_group_metadata_list[0].state.num_steps == 1

        # Add results to the output_queue
        ctx.append_output(outputs=outputs,
                          seq_group_metadata_list=seq_group_metadata_list,
                          scheduler_outputs=scheduler_outputs,
                          is_async=True,
                          is_last_step=True,
                          is_first_step_output=is_first_step_output)

        self._process_model_outputs(ctx=ctx)

        if not self.has_unfinished_requests():
            # Drain async postprocessor (if exists)
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            assert len(ctx.output_queue) == 0

            # Stop the execute model loop in parallel workers until there are
            # more requests to process. This avoids waiting indefinitely in
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            logger.debug("Stopping remote worker execution loop.")
            self.model_executor.stop_remote_worker_execution_loop()

        return ctx.request_outputs

    def _fix_last_step(
            self, output: List[SamplerOutput],
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
        """Write the real sampled token ids into the recorded step.

        For every unfinished, sampling group, the entry of
        ``output[0].sampler_out_ids`` equal to the sequence's ``seq_id``
        selects the token from ``self.async_d2h``; that token replaces
        ``sample.output_token`` and is passed to
        ``seq.fix_last_token_id``.
        """
        sample_out_list = self.async_d2h.tolist()
        sample_out_ids = output[0].sampler_out_ids.tolist()
        for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
                zip(seq_group_metadata_list, output[0], scheduled_seq_groups):
            seq_group = scheduled_seq_group.seq_group
            if seq_group.is_finished():
                continue
            if not seq_group_metadata.do_sample:
                continue
            sample = sequence_group_outputs.samples[0]
            assert len(seq_group.seqs) == 1
            seq: ZeroOverheadSequence = seq_group.seqs[0]
            for token_id, seq_id in zip(sample_out_list, sample_out_ids):
                if seq.seq_id != seq_id:
                    continue
                # tolist() yields a nested list when the tensor has a
                # trailing dimension; take the first element then.
                if isinstance(token_id, list):
                    sample.output_token = token_id[0]
                else:
                    sample.output_token = token_id
                seq.fix_last_token_id(sample.output_token)
                break

    def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Run one engine step and return the available request outputs.

        The first call after an idle period yields no record, so the step
        is launched a second time. If that also yields nothing, an empty
        list is returned so callers can always iterate the result.
        """
        out = self.zero_overhead_step()
        if out is None:
            # the first step needs to launch twice
            out = self.zero_overhead_step()
        return out if out is not None else []

    def _add_processed_request(
        self,
        request_id: str,
        processed_inputs: ProcessorInputs,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: float,
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> Optional[SequenceGroup]:
        """Add a processed request to the engine's request pool.

        Sequences are created as ``ZeroOverheadSequence``. Returns the
        created sequence group, or ``None`` when the request is handed to
        ``ParallelSampleSequenceGroup`` (``SamplingParams.n > 1``).

        Raises:
            ValueError: if ``params`` is neither SamplingParams nor
                PoolingParams.
        """
        if isinstance(params, SamplingParams) and params.n > 1:
            ParallelSampleSequenceGroup.add_request(
                request_id,
                self,
                params,
                processed_inputs=processed_inputs,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
            )
            return None

        self._validate_model_inputs(processed_inputs, lora_request)
        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        if is_encoder_decoder_inputs(processed_inputs):
            decoder_inputs = processed_inputs["decoder"]
            encoder_inputs = processed_inputs["encoder"]
        else:
            decoder_inputs = processed_inputs
            encoder_inputs = None

        seq = ZeroOverheadSequence(seq_id, decoder_inputs, block_size,
                                   eos_token_id, lora_request,
                                   prompt_adapter_request)

        encoder_seq = (None if encoder_inputs is None else
                       ZeroOverheadSequence(seq_id, encoder_inputs,
                                            block_size, eos_token_id,
                                            lora_request,
                                            prompt_adapter_request))

        # Create a SequenceGroup based on SamplingParams or PoolingParams
        if isinstance(params, SamplingParams):
            seq_group = self._create_sequence_group_with_sampling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        elif isinstance(params, PoolingParams):
            seq_group = self._create_sequence_group_with_pooling(
                request_id,
                seq,
                params,
                arrival_time=arrival_time,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
                encoder_seq=encoder_seq,
                priority=priority)
        else:
            raise ValueError(
                "Either SamplingParams or PoolingParams must be provided.")

        # Add the sequence group to the scheduler with least unfinished seqs.
        costs = [
            scheduler.get_num_unfinished_seq_groups()
            for scheduler in self.scheduler
        ]
        min_cost_scheduler = self.scheduler[costs.index(min(costs))]
        min_cost_scheduler.add_seq_group(seq_group)

        return seq_group
\ No newline at end of file
vllm/zero_overhead/v0/model_runner.py
0 → 100644
View file @
54294854
import
torch
import
itertools
from
typing
import
List
,
Optional
,
Set
from
vllm.lora.layers
import
LoRAMapping
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.prompt_adapter.layers
import
PromptAdapterMapping
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.utils
import
async_tensor_h2d
,
flatten_2d_lists
from
vllm.worker.model_runner
import
ModelInputForGPU
,
ModelInputForGPUBuilder
from
vllm.zero_overhead.v0.sampler
import
get_last_sampler
from
vllm.zero_overhead.v0.update_input
import
UpdateInputTokens
class ZeroOverheadModelInputForGpuBuilder(ModelInputForGPUBuilder):
    """Input builder that also records the sequence ids of the batch.

    After the parent class has built the model input, ``build`` calls
    ``UpdateInputTokens`` with the recorded ids and the tensors held by
    the last sampler record, if one is available.
    """

    def __init__(self, runner, finished_requests_ids=None):
        super().__init__(runner, finished_requests_ids)
        # Sequence ids collected by add_seq_group() since the last
        # prepare(), in insertion order.
        self.req_ids = []

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        """Forget the recorded sequence ids, then defer to the parent."""
        self.req_ids.clear()
        return super().prepare(finished_requests_ids)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Record every sequence id of the group, then defer to the parent."""
        self.req_ids.extend(seq_group_metadata.seq_data.keys())
        return super().add_seq_group(seq_group_metadata)

    def build(self) -> ModelInputForGPU:
        """Build the model input and apply the last sampler record to it."""
        model_input = super().build()
        recorder = get_last_sampler()
        if recorder.sampled_token_ids_tensor is None:
            # No sampler record yet: the parent's input is used unchanged.
            return model_input

        device = self.runner.device
        pin_memory = self.runner.pin_memory
        current_ids = async_tensor_h2d(self.req_ids, torch.long, device,
                                       pin_memory)
        previous_ids = async_tensor_h2d(recorder.seq_id.tolist(), torch.long,
                                        device, pin_memory)
        UpdateInputTokens(model_input.input_tokens, current_ids,
                          recorder.sampled_token_ids_tensor, previous_ids)
        return model_input
vllm/zero_overhead/v0/sampler.py
0 → 100644
View file @
54294854
from
importlib.util
import
find_spec
from
typing
import
Dict
,
List
,
Optional
import
torch
from
vllm
import
envs
from
vllm.model_executor.layers.rejection_sampler
import
_multinomial
from
vllm.model_executor.layers.sampler
import
MultinomialSamplesType
,
SampleMetadataType
,
\
SampleResultArgsType
,
SampleResultType
,
SampleResultsDictType
,
SampleReturnType
,
Sampler
,
\
SamplerOutput
,
_apply_min_p
,
_apply_min_tokens_penalty
,
_apply_top_k_top_p
,
_build_sampler_output
,
\
_modify_greedy_probs_inplace
,
_top_k_top_p_multinomial_with_flashinfer
,
get_logprobs
from
vllm.model_executor.layers.utils
import
apply_penalties
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
,
SamplingTensors
,
SequenceGroupToSample
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
VLLM_INVALID_TOKEN_ID
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
    import flashinfer.sampling
    # yapf: disable
    from flashinfer.sampling import (
        top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)

    # yapf: enable
else:
    # ZeroOverheadSampler.forward() compares this name against None, so it
    # must be bound even when flashinfer is disabled or not installed;
    # otherwise that comparison raises NameError.
    flashinfer_top_k_top_p_sampling = None
class SampleRecorder:
    """Holder for a sampler record: sequence ids and sampled token ids.

    Both fields start as ``None``; readers must check for that before
    using them.
    """

    def __init__(self):
        # Tensor of sequence ids, or None until something assigns it.
        self.seq_id: Optional[torch.Tensor] = None
        # Tensor of sampled token ids, or None until something assigns it.
        self.sampled_token_ids_tensor: Optional[torch.Tensor] = None
# Module-level SampleRecorder instance returned by get_last_sampler().
last_sampler = SampleRecorder()


def get_last_sampler() -> SampleRecorder:
    """Return the module-level ``last_sampler`` instance."""
    return last_sampler
class ZeroOverheadSampler(Sampler):
    """Sampler whose ``forward`` uses this module's ``_sample``.

    Construction is inherited unchanged from ``Sampler``.
    """

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Apply penalties and filters to the logits, sample, build output.

        Args:
            logits: (num_tokens, vocab_size).
            sampling_metadata: Metadata for sampling.

        Returns:
            The result of ``_build_sampler_output``; the float32,
            temperature-scaled logits are passed along as ``logits``.
        """
        assert logits is not None
        _, vocab_size = logits.shape

        # The sampling tensors are rebuilt unless reuse was requested.
        # With penalties they are always rebuilt, because they depend on
        # each sequence's output tokens, which change between decode runs.
        rebuild_tensors = (not sampling_metadata.reuse_sampling_tensors
                           or self._do_penalties)
        if rebuild_tensors:
            self._init_sampling_tensors(logits, sampling_metadata)

        assert self._sampling_tensors is not None
        tensors = self._sampling_tensors
        use_penalties = self._do_penalties
        use_top_p_top_k = self._do_top_p_top_k
        use_min_p = self._do_min_p

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Presence, frequency and repetition penalties.
        if use_penalties:
            logits = apply_penalties(logits, tensors.prompt_tokens,
                                     tensors.output_tokens,
                                     tensors.presence_penalties,
                                     tensors.frequency_penalties,
                                     tensors.repetition_penalties)

        # Temperature scaling is done in float32, dividing in place so no
        # additional tensor is created.
        logits = logits.to(torch.float)
        logits.div_(tensors.temperatures.unsqueeze(dim=1))

        if use_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
            logits = _apply_top_k_top_p(logits, tensors.top_ps,
                                        tensors.top_ks)

        if use_min_p:
            logits = _apply_min_p(logits, tensors.min_ps)

        # Probabilities and log-probabilities are both kept in float32.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Draw the next tokens.
        sample_results, sampled_tokens_tensor = _sample(
            probs,
            logprobs,
            sampling_metadata,
            tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )

        on_device_tensors = None
        if self.include_gpu_probs_tensor:
            # Keep the GPU-side tensors so logprobs can be converted to
            # Python objects later.
            assert sampled_tokens_tensor is not None
            on_device_tensors = (probs, logprobs, sampled_tokens_tensor)

        prompt_logprobs = None
        sample_logprobs = None
        skip_cpu_output = sampling_metadata.skip_sampler_cpu_output
        if not skip_cpu_output:
            # Convert logprobs to Python objects now (GPU -> CPU).
            assert not isinstance(sample_results, SampleResultArgsType)
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, sample_results)

        return _build_sampler_output(
            sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output,
            logits=logits)
def _greedy_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    samples: torch.Tensor,
) -> SampleResultType:
    """Build placeholder greedy-sampling results for a batch.

    ``samples`` is not read: every sampled group gets the placeholder
    token id ``0`` instead of a real token.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        samples: (num_selected_samples,) A tensor of samples (unused).

    Returns:
        One ``(next_token_ids, parent_ids)`` tuple per sequence group, in
        order. Groups with ``do_sample=False`` contribute ``([], [])``.
    """
    placeholder_results: SampleResultType = []
    for group in selected_seq_groups:
        if not group.do_sample:
            placeholder_results.append(([], []))
            continue

        parent_count = len(group.seq_ids)
        # Multiple sequences per sequence group are not supported.
        assert parent_count == 1, (
            "Greedy sampling should have only one seq.")
        # [0] is the placeholder token id.
        placeholder_results.append(([0], list(range(parent_count))))
    return placeholder_results
def _random_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    random_samples: torch.Tensor,
) -> SampleResultType:
    """Build placeholder results for randomly sampled sequence groups.

    This zero-overhead variant never reads ``random_samples``: every
    sampled group receives placeholder token ids (``0``) instead of values
    copied from the tensor.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        random_samples: (num_selected_samples,) A tensor of samples. Unused
            here; kept so the signature stays compatible with callers.

    Returns:
        One ``(next_token_ids, parent_ids)`` tuple per entry of
        ``selected_seq_groups``. Groups with ``do_sample=False`` yield
        ``([], [])``. Prompt-phase groups yield ``sampling_params.n``
        placeholders, all with parent id 0; generation-phase groups yield
        one placeholder per parent sequence.

    Raises:
        AssertionError: If a sampled group holds more than one sequence.
    """
    results: SampleResultType = []
    for seq_group in selected_seq_groups:
        if not seq_group.do_sample:
            results.append(([], []))
            continue

        num_parent_seqs = len(seq_group.seq_ids)
        if seq_group.is_prompt:
            # Prompt phase: one placeholder per requested sample, each
            # descending from the single prompt sequence.
            num_samples = seq_group.sampling_params.n
            parent_ids = [0] * num_samples
            next_token_ids = [0] * num_samples  # placeholder token ids
        else:
            # Generation phase: one placeholder per parent sequence.
            parent_ids = list(range(num_parent_seqs))
            next_token_ids = [0] * num_parent_seqs  # placeholder token ids

        # Multiple sequences per sequence group are not supported.
        assert num_parent_seqs == 1
        results.append((next_token_ids, parent_ids))
    return results
def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> SampleReturnType:
    """Sample next tokens for a batch by delegating to the torch backend.

    Args:
        probs: (num_query_tokens_in_batch, num_vocab)
        logprobs: (num_query_tokens_in_batch, num_vocab)
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Tensors that include sampling related metadata.
        include_gpu_probs_tensor: Forwarded unchanged to
            ``_sample_with_torch``.
        modify_greedy_probs: Forwarded unchanged to ``_sample_with_torch``.

    Returns:
        Whatever ``_sample_with_torch`` returns: the per-group sample
        results (or the deferred arguments for producing them) together
        with the sampled-token-id tensor, which may be ``None``.
    """
    sample_outcome = _sample_with_torch(
        probs,
        logprobs,
        sampling_metadata,
        sampling_tensors,
        include_gpu_probs_tensor=include_gpu_probs_tensor,
        modify_greedy_probs=modify_greedy_probs,
    )
    return sample_outcome
def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> SampleReturnType:
    '''Torch-oriented _sample() implementation.

    Single-step scheduling:
    * Perform GPU-side sampling computation
    * Immediately Pythonize sampling result

    Multi-step scheduling:
    * Perform GPU-side sampling computation
    * Defer Pythonization & preserve GPU-side
      tensors required for Pythonization

    Besides returning its result, this function records two attributes on
    the module-level ``last_sampler`` object: ``seq_id`` (the first
    sequence id of every sequence group, in batch order) and
    ``sampled_token_ids_tensor`` (the tokens sampled in this call).

    Args:
        probs: Probabilities indexed by the categorized sample indices.
        logprobs: Log-probabilities indexed the same way; also determines
            the size and device of the optional output tensor.
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Supplies ``top_ks`` / ``top_ps`` for the
            flashinfer sampling path.
        include_gpu_probs_tensor: When true, allocate and fill a
            ``(logprobs.shape[0], 1)`` tensor of sampled token ids.
        modify_greedy_probs: When true, rewrite probs/logprobs in place for
            greedily sampled rows via ``_modify_greedy_probs_inplace``.

    Returns:
        A 2-tuple. The first item is the Pythonized sample results when
        ``sampling_metadata.skip_sampler_cpu_output`` is false, otherwise
        the ``SampleResultArgsType`` needed to Pythonize later. The second
        item is the sampled-token-id tensor, or ``None``.

    Raises:
        ValueError: If a sampling type with pending samples is unsupported.
    '''
    # Indices (into sampling_metadata.seq_groups) grouped by sampling type.
    categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
        t: []
        for t in SamplingType
    }
    # One int32 slot per sequence group, filled in the loop below.
    last_sampler.seq_id = torch.zeros(len(sampling_metadata.seq_groups),
                                      dtype=torch.int32)
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        # Only the first sequence id of each group is recorded.
        last_sampler.seq_id[i] = seq_group.seq_ids[0]
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: SampleResultsDictType = {}
    sample_metadata: SampleMetadataType = {}
    multinomial_samples: MultinomialSamplesType = {}
    greedy_samples: Optional[torch.Tensor] = None
    beam_search_logprobs: Optional[torch.Tensor] = None

    # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1),
                                              VLLM_INVALID_TOKEN_ID,
                                              dtype=torch.long,
                                              device=logprobs.device)
    else:
        sampled_token_ids_tensor = None

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue

        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups)
        long_sample_indices = sample_indices.long()
        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)
            # NOTE(review): this attribute is reassigned for every sampling
            # type handled in this loop, so with a batch that mixes types
            # only the last type's tokens remain, while last_sampler.seq_id
            # covers every group -- confirm batches never mix types.
            last_sampler.sampled_token_ids_tensor = greedy_samples.unsqueeze(
                -1)

            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = greedy_samples.unsqueeze(-1)

            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling from
                # the modified distribution would always sample the argmax
                # token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)

        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            # Largest `n` requested by any prompt-phase group of this type.
            max_n_in_batch = 1
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_n_in_batch = max(max_n_in_batch, sampling_params.n)
            # The sequence groups are only passed along for RANDOM_SEED.
            seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
                              seq_groups)

            if flashinfer_top_k_top_p_sampling is not None:
                multinomial_samples[
                    sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
                        probs[long_sample_indices],
                        sampling_tensors.top_ks[long_sample_indices],
                        sampling_tensors.top_ps[long_sample_indices],
                        max_n_in_batch,
                        seq_groups_arg,
                    )
            else:
                multinomial_samples[sampling_type] = _multinomial(
                    probs[long_sample_indices],
                    max_n_in_batch,
                    seq_groups=seq_groups_arg)

            # See the NOTE(review) in the greedy branch: this overwrites any
            # tensor stored for a previously handled sampling type.
            last_sampler.sampled_token_ids_tensor = \
                multinomial_samples[sampling_type].to(torch.long)

            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[long_sample_indices] = \
                    multinomial_samples[sampling_type].to(torch.long)

        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # Encapsulate arguments for computing Pythonized sampler
    # results, whether deferred or otherwise.
    maybe_deferred_args = SampleResultArgsType(
        sampling_metadata=sampling_metadata,
        sample_metadata=sample_metadata,
        multinomial_samples=multinomial_samples,
        greedy_samples=greedy_samples,
        beam_search_logprobs=beam_search_logprobs,
        sample_results_dict=sample_results_dict)

    if not sampling_metadata.skip_sampler_cpu_output:
        # Convert the sampler output to a Python object right away and
        # return it together with the sampled token ids tensor.
        return get_pythonized_sample_results(
            maybe_deferred_args), sampled_token_ids_tensor
    else:
        # Defer sampler result Pythonization; return deferred
        # Pythonization args & sampled token ids
        return (
            maybe_deferred_args,
            sampled_token_ids_tensor,
        )
def get_pythonized_sample_results(
        sample_result_args: SampleResultArgsType) -> SampleResultType:
    '''Turn the arguments captured at sampling time into Python results.

    For every sampling type that has pending samples, the matching helper
    (``_greedy_sample`` or ``_random_sample``) builds one
    ``(next_token_ids, parent_ids)`` tuple per sequence group. The tuples
    are stored in ``sample_result_args.sample_results_dict`` keyed by the
    group's index in the batch.

    Args:
        sample_result_args: Inputs to the Pythonization process, as built
            by ``_sample_with_torch``.

    Returns:
        One ``(next_token_ids, parent_ids)`` tuple per sequence group in
        ``sampling_metadata.seq_groups``, in batch order. Groups without a
        result get ``([], [])``.

    Raises:
        ValueError: If samples are pending for a sampling type that is
            neither greedy nor random. Previously this either hit an
            unbound local or silently reused the previous type's results.
    '''
    sample_metadata = sample_result_args.sample_metadata
    sampling_metadata = sample_result_args.sampling_metadata
    greedy_samples = sample_result_args.greedy_samples
    multinomial_samples = sample_result_args.multinomial_samples
    sample_results_dict = sample_result_args.sample_results_dict

    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue

        (seq_group_id, seq_groups) = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(seq_groups, greedy_samples)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(seq_groups,
                                            multinomial_samples[sampling_type])
        else:
            # Same error as _sample_with_torch so unsupported types fail
            # loudly instead of producing mismatched results.
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

        sample_results_dict.update(zip(seq_group_id, sample_results))

    return [
        sample_results_dict.get(i, ([], []))
        for i in range(len(sampling_metadata.seq_groups))
    ]
\ No newline at end of file
vllm/zero_overhead/v0/sequence.py
0 → 100644
View file @
54294854
from
typing
import
Union
from
vllm.sequence
import
Sequence
from
typing
import
Sequence
as
GenericSequence
class ZeroOverheadSequence(Sequence):
    """Sequence that separates confirmed output tokens from pending ones.

    ``effective_output_len`` counts the leading output tokens whose ids
    have been written by ``fix_last_token_id``. Output tokens beyond that
    count are treated as not yet valid by every ``zero_overhead_*``
    accessor.
    """

    def __init__(self,
                 seq_id,
                 inputs,
                 block_size,
                 eos_token_id=None,
                 lora_request=None,
                 prompt_adapter_request=None):
        super().__init__(seq_id, inputs, block_size, eos_token_id,
                         lora_request, prompt_adapter_request)
        # Number of leading output tokens that hold confirmed token ids.
        self.effective_output_len: int = 0

    def fix_last_token_id(self, token_id: int) -> None:
        """Write ``token_id`` into the oldest unconfirmed output slot.

        The slot is addressed from the end of the token buffers, so the
        same negative offset is valid for the output ids, the cached
        all-token list and (when long enough) the newly appended tokens.

        Raises:
            AssertionError: If there is no unconfirmed output token.
        """
        effect_offset = self.effective_output_len - len(
            self.data.output_token_ids)
        assert effect_offset < 0
        self.data._output_token_ids[effect_offset] = token_id
        # _new_appended_tokens may be shorter than the number of pending
        # tokens; only patch it when the slot exists.
        if len(self.data._new_appended_tokens) >= -effect_offset:
            self.data._new_appended_tokens[effect_offset] = token_id
        self.data._cached_all_token_ids[effect_offset] = token_id
        self.effective_output_len += 1

    def zero_overhead_get_output_token_ids(self) -> tuple[int, ...]:
        """Return only the confirmed output token ids."""
        return self.data.output_token_ids[:self.effective_output_len]

    def zero_overhead_get_output_len(self) -> int:
        """Return the number of confirmed output tokens."""
        return self.effective_output_len

    def zero_overhead_get_last_token_id(self) -> int:
        """Return the last confirmed token id.

        Falls back to the last prompt token while no output token has been
        confirmed yet.
        """
        if self.effective_output_len == 0:
            return self.data._prompt_token_ids[-1]
        return self.data._output_token_ids[self.effective_output_len - 1]

    def zero_overhead_get_len(self) -> int:
        """Return prompt length plus the number of confirmed output tokens."""
        return self.effective_output_len + len(self.data._prompt_token_ids)

    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[GenericSequence[int], int]:
        """Return confirmed output token ids, optionally only the new ones.

        Args:
            delta: If True, return only the tokens confirmed since the last
                call to this method; otherwise return all confirmed tokens.

        Returns:
            All confirmed output ids when ``delta`` is False. With
            ``delta`` True: a single int when exactly one token is new, an
            empty list when none is, otherwise a list of the new ids.
        """
        if not delta:
            return self.zero_overhead_get_output_token_ids()

        output_len = self.zero_overhead_get_output_len()

        # Get the number of new tokens
        num_new_tokens = output_len - self._last_output_token_ids_offset
        self._last_output_token_ids_offset = output_len

        # Return new tokens
        if num_new_tokens == 1:
            # Optimization for single decode token case
            # (which is what we have most of the time).
            # Index the output ids directly: the cached all-token list is
            # addressed from its end elsewhere in this class, so an
            # output-relative index is not valid for it.
            return self.data._output_token_ids[output_len - 1]

        if num_new_tokens == 0:
            return []

        # The new tokens are the last `num_new_tokens` confirmed ones; any
        # pending tokens after `output_len` are excluded.
        first_new = output_len - num_new_tokens
        return list(self.data._output_token_ids[first_new:output_len])
\ No newline at end of file
vllm/zero_overhead/v0/stop_check.py
0 → 100644
View file @
54294854
from
typing
import
Optional
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
SequenceStatus
from
vllm.zero_overhead.v0.sequence
import
ZeroOverheadSequence
class ZeroOverheadStopChecker(StopChecker):
    """Stop checker that only looks at a sequence's confirmed tokens.

    All length and last-token queries go through the ``zero_overhead_*``
    accessors of ``ZeroOverheadSequence``.
    """

    def __init__(self, max_model_len, get_tokenizer_for_seq):
        super().__init__(max_model_len, get_tokenizer_for_seq)

    def maybe_stop_sequence(
        self,
        seq: ZeroOverheadSequence,
        new_char_count: int,
        sampling_params: SamplingParams,
        lora_req: Optional[LoRARequest] = None,
    ) -> None:
        """Stop the finished sequences.

        Checks, in order: minimum token count, EOS token, stop token ids,
        stop strings, maximum model length and ``max_tokens``. The first
        matching condition sets ``seq.status`` (and, where applicable,
        ``seq.stop_reason`` and a truncated ``seq.output_text``).

        Args:
            seq: The sequence to inspect and possibly mark as finished.
            new_char_count: The number of chars added to the sequence's
                output text for the newly generated token.
            sampling_params: Sampling parameters carrying the stop
                conditions.
            lora_req: Optional LoRA request, forwarded to
                ``_get_max_model_len``.
        """
        # Check if the minimum number of tokens has been generated yet;
        # skip the stop string/token checks if not
        if seq.zero_overhead_get_output_len() < sampling_params.min_tokens:
            return

        # Last confirmed token. The EOS and stop-token checks must both use
        # it; the stop-token check previously went through
        # `seq.get_last_token_id(self.zero_overhead)`, an attribute this
        # class never sets.
        last_token_id = seq.zero_overhead_get_last_token_id()

        # Check if the sequence has generated the EOS token.
        if ((not sampling_params.ignore_eos)
                and last_token_id == seq.eos_token_id):
            # Remove the last EOS token unless explicitly specified
            # This prevents unintended exposure of the EOS token
            if new_char_count and (
                    not sampling_params.include_stop_str_in_output):
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

        # Check if a stop token was encountered.
        # This assumes a single token produced per step.
        if last_token_id in (sampling_params.stop_token_ids or ()):
            if new_char_count and (
                    not sampling_params.include_stop_str_in_output):
                # Remove last token
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = last_token_id
            return

        # Check if any stop strings are matched.
        stop = self.check_stop_strings(
            seq.output_text, new_char_count, sampling_params.stop,
            sampling_params.include_stop_str_in_output)
        if stop is not None:
            stop_str, truncate_to = stop
            # -1 means the output text is left untouched.
            if truncate_to != -1:
                seq.output_text = seq.output_text[:truncate_to]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = stop_str
            return

        # Check if the sequence has reached max_model_len.
        if seq.zero_overhead_get_len() > self._get_max_model_len(lora_req):
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Check if the sequence has reached max_tokens.
        if seq.zero_overhead_get_output_len() == sampling_params.max_tokens:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return
\ No newline at end of file
vllm/zero_overhead/v0/tokenizer.py
0 → 100644
View file @
54294854
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
VLLM_INVALID_TOKEN_ID
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer_utils
import
convert_prompt_ids_to_tokens
,
detokenize_incrementally
from
vllm.zero_overhead.v0.sequence
import
ZeroOverheadSequence
class ZeroOverheadDetokenizer(Detokenizer):
    """Detokenizer that decodes only the confirmed tokens of a sequence."""

    def __init__(self, tokenizer_group):
        super().__init__(tokenizer_group)

    def decode_sequence_inplace(self, seq: ZeroOverheadSequence,
                                prms: SamplingParams) -> int:
        """Decode the newest confirmed token of ``seq`` in place.

        Token ids past ``prompt_len + seq.effective_output_len`` are
        ignored. The sequence's ``tokens``, ``prefix_offset``,
        ``read_offset`` and ``output_text`` are updated, and the entries of
        the latest logprob dict receive their decoded text.

        Args:
            seq: The sequence to decode.
            prms: The sampling parameters used to generate the sequence.

        Returns:
            The number of characters added to the output text.
        """
        confirmed_len = seq.get_prompt_len() + seq.effective_output_len
        confirmed_ids = seq.get_token_ids()[:confirmed_len]
        newest_token_id = confirmed_ids[-1]
        ids_before_newest = confirmed_ids[:-1]
        tokenizer = self.get_tokenizer_for_seq(seq)

        # Convert prompt token IDs to tokens if necessary.
        # Doing it once here avoids repeating the work for each logprob.
        if seq.tokens is None:
            prompt_state = convert_prompt_ids_to_tokens(
                tokenizer=tokenizer,
                prompt_ids=ids_before_newest,
                skip_special_tokens=prms.skip_special_tokens,
            )
            (seq.tokens, seq.prefix_offset, seq.read_offset) = prompt_state

        # The sequence state is only written back at the end, so the same
        # keyword arguments serve every incremental decode below.
        decode_kwargs = dict(
            tokenizer=tokenizer,
            prev_tokens=seq.tokens,
            prefix_offset=seq.prefix_offset,
            read_offset=seq.read_offset,
            skip_special_tokens=prms.skip_special_tokens,
            spaces_between_special_tokens=prms.spaces_between_special_tokens,
        )

        (new_tokens, new_decoded_token_text, prefix_offset,
         read_offset) = detokenize_incrementally(all_input_ids=confirmed_ids,
                                                 **decode_kwargs)

        # Decode logprobs
        # NOTE(review): this reads the most recently appended logprob entry;
        # confirm it belongs to the newest *confirmed* token when further
        # tokens are still pending.
        step_logprobs = seq.output_logprobs[-1]
        if step_logprobs:
            for candidate_id, candidate in step_logprobs.items():
                if candidate_id == newest_token_id:
                    # The token generated this iteration reuses the text
                    # decoded above.
                    candidate.decoded_token = new_decoded_token_text
                elif (candidate.decoded_token is None
                      and candidate_id != VLLM_INVALID_TOKEN_ID):
                    (_, candidate_text, _, _) = detokenize_incrementally(
                        all_input_ids=ids_before_newest + [candidate_id],
                        **decode_kwargs)
                    candidate.decoded_token = candidate_text

        seq.tokens.extend(new_tokens)
        seq.prefix_offset = prefix_offset
        seq.read_offset = read_offset
        seq.output_text += new_decoded_token_text

        return len(new_decoded_token_text)
\ No newline at end of file
vllm/zero_overhead/v0/update_input.py
0 → 100644
View file @
54294854
import
torch
import
triton
import
triton.language
as
tl
# Triton kernel: for each entry of `input_tokens`, look up its sequence id in
# `seq_ids` and, on a match, overwrite the entry with the matching element of
# `sample_output`. Entries without a match are written back unchanged.
#
# Arguments (all pointers are indexed element-wise from their base):
#   sample_output: values to copy from, one per entry of `seq_ids`.
#   seq_ids:       sequence ids aligned with `sample_output`.
#   input_tokens:  buffer updated in place, one entry per program.
#   input_seq_ids: sequence ids aligned with `input_tokens`.
#   BATCH_SIZE1:   number of entries scanned in `seq_ids`.
#   BATCH_SIZE2:   number of entries in `input_tokens` / `input_seq_ids`.
# NOTE(review): presumably all four buffers live on the same device and hold
# integer data -- confirm against the caller.
@triton.jit
def _update_input_tokens(
    sample_output,
    seq_ids,
    input_tokens,
    input_seq_ids,
    BATCH_SIZE1,
    BATCH_SIZE2,
):
    # Each program along axis 0 handles one entry of input_tokens.
    pid = tl.program_id(0)
    # Programs beyond the number of input entries have nothing to do.
    if pid >= BATCH_SIZE2:
        return
    # Start from the current value so it is preserved when no id matches.
    output_token = tl.load(input_tokens + pid)
    _input_seq_id = tl.load(input_seq_ids + pid)
    # Linear scan over all candidate ids; if several match, the last wins.
    for i in range(BATCH_SIZE1):
        _seq_ids = tl.load(seq_ids + i)
        if _seq_ids == _input_seq_id:
            output_token = tl.load(sample_output + i)
    tl.store(input_tokens + pid, output_token)
def UpdateInputTokens(input_tokens, input_seq_ids, last_sample, last_ids):
    """Patch ``input_tokens`` in place with matching entries of ``last_sample``.

    Launches the ``_update_input_tokens`` kernel with one program per entry
    of ``input_seq_ids``. Each program replaces its token with the
    ``last_sample`` entry whose id in ``last_ids`` equals the token's
    sequence id, if there is one.

    Args:
        input_tokens: Buffer to update in place.
        input_seq_ids: Sequence ids aligned with ``input_tokens``; its first
            dimension sets the launch grid size.
        last_sample: Values to copy from.
        last_ids: Sequence ids aligned with ``last_sample``.
    """
    num_inputs = input_seq_ids.shape[0]
    num_previous = last_ids.shape[0]
    launch_grid = [num_inputs, 1, 1]
    _update_input_tokens[launch_grid](
        last_sample,
        last_ids,
        input_tokens,
        input_seq_ids,
        num_previous,
        num_inputs,
    )
\ No newline at end of file
vllm/zero_overhead/v0/utils.py
0 → 100644
View file @
54294854
import
os
# Read once at import time: the mode is on only when the environment
# variable is set to exactly '1'.
zero_overhead: bool = os.getenv('VLLM_ZERO_OVERHEAD') == '1'


def is_zero_overhead() -> bool:
    """Return whether zero-overhead mode was enabled at import time."""
    return zero_overhead
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment