Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
645e9ec4
Commit
645e9ec4
authored
Apr 17, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'origin/v0.7.2_zero_overhead' into v0.7.2-dev
parents
d0de006f
c78f6594
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
480 additions
and
65 deletions
+480
-65
benchmarks/benchmark_serving.py
benchmarks/benchmark_serving.py
+2
-0
benchmarks/benchmark_throughput.py
benchmarks/benchmark_throughput.py
+5
-4
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+4
-5
vllm/benchmarks/benchmark_throughput.py
vllm/benchmarks/benchmark_throughput.py
+4
-1
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+0
-2
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+181
-4
vllm/engine/output_processor/stop_checker.py
vllm/engine/output_processor/stop_checker.py
+7
-5
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+5
-0
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+51
-12
vllm/model_executor/layers/update_input.py
vllm/model_executor/layers/update_input.py
+28
-0
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+0
-2
vllm/profiler/prof.py
vllm/profiler/prof.py
+73
-0
vllm/sequence.py
vllm/sequence.py
+67
-15
vllm/spec_decode/target_model_runner.py
vllm/spec_decode/target_model_runner.py
+4
-2
vllm/transformers_utils/detokenizer.py
vllm/transformers_utils/detokenizer.py
+7
-1
vllm/version.py
vllm/version.py
+6
-4
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+31
-6
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+2
-0
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+3
-2
No files found.
benchmarks/benchmark_serving.py
View file @
645e9ec4
...
...
@@ -570,6 +570,8 @@ async def benchmark(
else
:
print
(
"Initial test run completed. Starting main benchmark run..."
)
time
.
sleep
(
0.1
)
# ZERO_OVERHEAD: sleep briefly so the last warmup step can finish
if
profile
:
print
(
"Starting profiler..."
)
profile_input
=
RequestFuncInput
(
model
=
model_id
,
...
...
benchmarks/benchmark_throughput.py
View file @
645e9ec4
...
...
@@ -8,7 +8,7 @@ import time
from
pathlib
import
Path
from
functools
import
cache
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
os
import
numpy
as
np
import
torch
import
uvloop
...
...
@@ -180,7 +180,7 @@ def run_vllm(
sampling_params
:
List
[
SamplingParams
]
=
[]
for
request
in
requests
:
prompts
.
append
(
TextPrompt
(
prompt
=
request
.
prompt
,
TextPrompt
(
prompt
=
"helloword"
,
multi_modal_data
=
request
.
multi_modal_data
))
sampling_params
.
append
(
SamplingParams
(
...
...
@@ -206,15 +206,16 @@ def run_vllm(
dummy_prompts
:
List
[
PromptType
]
=
[{
"prompt_token_ids"
:
batch
}
for
batch
in
dummy_prompt_token_ids
.
tolist
()]
print
(
f
'
{
os
.
environ
.
get
(
"VLLM_ZERO_OVERHEAD"
)
==
"1"
}
'
)
print
(
"Warming up..."
)
for
_
in
tqdm
(
range
(
num_iters_warmup
),
desc
=
"Warmup iterations"
):
llm
.
generate
(
dummy_prompts
,
sampling_params
=
warmup_sampling_params
,
use_tqdm
=
False
)
use_beam_search
=
False
print
(
"testing"
)
if
not
use_beam_search
:
if
args
.
profile
:
profile_dir
=
args
.
profile_result_dir
...
...
vllm/attention/backends/utils.py
View file @
645e9ec4
...
...
@@ -14,8 +14,6 @@ from vllm.attention.backends.abstract import AttentionType
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
if
TYPE_CHECKING
:
from
vllm.worker.model_runner_base
import
ModelRunnerBase
...
...
@@ -235,8 +233,10 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
for
i
,
block_table
in
enumerate
(
self
.
block_tables
):
if
block_table
:
input_block_tables
[
i
,
:
len
(
block_table
)]
=
block_table
block_tables
=
torch
.
from_numpy
(
input_block_tables
).
to
(
device
,
non_blocking
=
True
)
# block_tables = torch.from_numpy(input_block_tables).to(
# device, non_blocking=True)
block_tables
=
async_tensor_h2d
(
input_block_tables
.
tolist
(),
torch
.
int32
,
device
,
self
.
runner
.
pin_memory
)
else
:
block_tables
=
make_tensor_with_pad
(
self
.
block_tables
,
...
...
@@ -245,7 +245,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
device
=
device
,
)
assert
max_query_len
>
0
,
"query_lens: {}"
.
format
(
query_lens
)
assert
device
is
not
None
context_lens_tensor
=
async_tensor_h2d
(
self
.
context_lens
,
torch
.
int
,
device
,
self
.
runner
.
pin_memory
)
...
...
vllm/benchmarks/benchmark_throughput.py
View file @
645e9ec4
...
...
@@ -3,6 +3,7 @@
import
argparse
import
dataclasses
import
json
import
os
import
random
import
time
from
pathlib
import
Path
...
...
@@ -214,7 +215,9 @@ def run_vllm(
use_tqdm
=
False
)
use_beam_search
=
False
if
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
:
print
(
"sleep 1"
)
time
.
sleep
(
1
)
# ZERO_OVERHEAD: sleep so the last warmup step can finish
if
not
use_beam_search
:
if
args
.
profile
:
profile_dir
=
args
.
profile_result_dir
...
...
vllm/engine/async_llm_engine.py
View file @
645e9ec4
...
...
@@ -726,7 +726,6 @@ class AsyncLLMEngine(EngineClient):
"""Kick the engine to process the waiting requests.
Returns True if there are in-progress requests."""
new_requests
,
aborted_requests
=
(
self
.
_request_tracker
.
get_new_and_aborted_requests
())
...
...
@@ -746,7 +745,6 @@ class AsyncLLMEngine(EngineClient):
await
self
.
_engine_abort
(
aborted_requests
)
request_outputs
=
await
self
.
engine
.
step_async
(
virtual_engine
)
# Put the outputs into the corresponding streams.
# If used as a callback, then already invoked inside
# LLMEngine's _process_model_outputs
...
...
vllm/engine/llm_engine.py
View file @
645e9ec4
...
...
@@ -3,11 +3,14 @@
import
os
import
copy
import
time
import
threading
import
queue
from
collections
import
Counter
as
collectionsCounter
from
collections
import
deque
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
functools
import
partial
import
traceback
from
typing
import
(
TYPE_CHECKING
,
Callable
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
...
...
@@ -61,6 +64,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message
)
from
vllm.utils
import
Counter
,
Device
,
deprecate_kwargs
,
weak_bind
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.profiler.prof
import
profile
logger
=
init_logger
(
__name__
)
_LOCAL_LOGGING_INTERVAL_SEC
=
5
...
...
@@ -407,6 +411,19 @@ class LLMEngine:
self
.
tree_decoding
=
os
.
environ
.
get
(
'VLLM_TREE_DECODING'
)
==
'1'
self
.
seq_id_to_seq_group
:
Dict
[
str
,
SequenceGroupBase
]
=
{}
self
.
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
if
self
.
zero_overhead
:
assert
os
.
environ
.
get
(
'HIP_ALLOC_INITIALIZE'
)
==
'0'
self
.
async_d2h
=
None
self
.
last_record
=
None
self
.
async_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
zero_thread
=
threading
.
Thread
(
target
=
self
.
thread_zero_overhead
)
self
.
q_recorder
=
queue
.
Queue
()
self
.
thread_running
=
True
self
.
sem_m2s
=
threading
.
Semaphore
(
0
)
# main to scheduler thread
self
.
zero_thread
.
start
()
profile
.
StartTracer
()
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
...
...
@@ -1227,6 +1244,35 @@ class LLMEngine:
return
None
def _fix_last_step(
        self, output: List[SamplerOutput],
        seq_group_metadata_list: List[SequenceGroupMetadata],
        scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
    """Write the real sampled tokens of a finished step into its sequences.

    The sampled token values are read from ``self.async_d2h`` and the
    sequence ids they belong to from ``output[0].sampler_out_ids``. For every
    scheduled sequence group that is not finished and that sampled, the
    matching token is stored in ``sample.output_token`` and forwarded to
    ``seq.fix_last_token_id``.

    Args:
        output: Sampler outputs of the step; only ``output[0]`` is used.
        seq_group_metadata_list: Metadata of the scheduled groups, iterated in
            lockstep with ``output[0]`` and ``scheduled_seq_groups``.
        scheduled_seq_groups: The groups that were scheduled for the step.
    """
    sampled_tokens = self.async_d2h.tolist()
    sampled_seq_ids = output[0].sampler_out_ids.tolist()

    # Build the seq_id -> token table once rather than rescanning both lists
    # for every sequence group. setdefault keeps the first entry for a
    # repeated id, which is what a first-match linear scan would pick.
    token_by_seq_id = {}
    for token_id, seq_id in zip(sampled_tokens, sampled_seq_ids):
        token_by_seq_id.setdefault(seq_id, token_id)

    for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \
            zip(seq_group_metadata_list, output[0], scheduled_seq_groups):
        seq_group = scheduled_seq_group.seq_group
        if seq_group.is_finished():
            continue
        if not seq_group_metadata.do_sample:
            continue

        sample = sequence_group_outputs.samples[0]
        assert len(seq_group.seqs) == 1
        seq = seq_group.seqs[0]

        if seq.seq_id not in token_by_seq_id:
            # No sample recorded for this sequence: leave it untouched.
            continue
        token_id = token_by_seq_id[seq.seq_id]
        # tolist() yields a nested list when the tensor has a trailing
        # dimension; take the single element in that case.
        if isinstance(token_id, list):
            sample.output_token = token_id[0]
        else:
            sample.output_token = token_id
        seq.fix_last_token_id(sample.output_token)
def
_advance_to_next_step
(
self
,
output
:
List
[
SamplerOutput
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
...
...
@@ -1270,6 +1316,131 @@ class LLMEngine:
seq_group
.
update_num_computed_tokens
(
1
)
else
:
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
def finish_thread(self):
    """Ask the zero-overhead worker thread to exit.

    Does nothing unless ``self.zero_overhead`` is set and the thread is
    flagged as running. Otherwise clears ``self.thread_running`` and releases
    ``self.sem_m2s`` once.
    """
    if not (self.zero_overhead and self.thread_running):
        return
    self.thread_running = False
    # One release wakes a waiter blocked on the semaphore.
    self.sem_m2s.release()
def thread_zero_overhead(self):
    """Worker-thread loop: schedule and execute one model step per wake-up.

    Each iteration blocks on ``self.sem_m2s`` and stops once
    ``self.thread_running`` is false. Otherwise it:

    1. Runs the scheduler of virtual engine 0.
    2. If a previous step is stored in ``self.last_record``, starts a
       non-blocking copy of its sampled-token tensor to the CPU
       (``self.async_d2h``), records ``self.async_event`` and puts the record
       on ``self.q_recorder``; otherwise puts ``None`` on the queue.
    3. If nothing was scheduled, clears ``self.last_record`` and waits again.
    4. Otherwise executes the model, passing the previous step's sampled ids
       and tensor in the request, and stores the new step in
       ``self.last_record``.

    Any exception is logged with its traceback and ends the loop.
    NOTE(review): nothing is put on ``self.q_recorder`` when the loop dies,
    so a consumer blocked in ``q_recorder.get()`` would presumably wait
    forever — confirm and handle at the call site.
    """
    logger.info('zero overhead thread start!')
    try:
        while True:
            self.sem_m2s.acquire()
            if not self.thread_running:
                break
            virtual_engine = 0

            # Schedule iteration. The third value returned by schedule()
            # is not used on this path.
            (seq_group_metadata_list, scheduler_outputs,
             _) = self.scheduler[virtual_engine].schedule()

            last_outputs_ids = None
            last_outputs_tensor = None
            if self.last_record is not None:
                last_output = self.last_record[0][0]
                last_outputs_ids = last_output.sampler_out_ids
                last_outputs_tensor = last_output.sampler_out_tenosr
                # Start the copy to host and mark its position with the
                # event; the consumer synchronizes on the event.
                self.async_d2h = last_outputs_tensor.to('cpu',
                                                        non_blocking=True)
                self.async_event.record()
                self.q_recorder.put(self.last_record)
            else:
                self.q_recorder.put(None)

            if len(seq_group_metadata_list) == 0:
                self.last_record = None
                continue

            finished_requests_ids = self.scheduler[
                virtual_engine].get_and_reset_finished_requests_ids()

            assert seq_group_metadata_list is not None
            assert scheduler_outputs is not None

            last_sampled_token_ids = \
                self._get_last_sampled_token_ids(virtual_engine)

            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
                finished_requests_ids=finished_requests_ids,
                # We use ExecuteModelRequest to pass the last sampled_token_ids
                # to each of the non-last PP stages for in-place prepare_input.
                last_sampled_token_ids=last_sampled_token_ids,
                last_outputs_ids=last_outputs_ids,
                last_outputs_sample=last_outputs_tensor)

            outputs = self.model_executor.execute_model(
                execute_model_req=execute_model_req)

            if len(outputs) == 1:
                self._advance_to_next_step(
                    outputs[0], seq_group_metadata_list,
                    scheduler_outputs.scheduled_seq_groups)

            # Shallow copy: a new list holding the same group objects.
            scheduler_outputs.scheduled_seq_groups = list(
                scheduler_outputs.scheduled_seq_groups)
            self.last_record = [
                outputs, seq_group_metadata_list, scheduler_outputs
            ]
    except Exception as e:
        # Top-level boundary of the worker thread: report through the module
        # logger (message plus traceback) instead of bare prints.
        logger.exception("thread_zero_overhead error : %s", e)
def zero_overhead_step(
        self) -> Optional[List[Union[RequestOutput, PoolingRequestOutput]]]:
    """Run one engine step through the zero-overhead worker thread.

    Restarts the worker thread if ``self.thread_running`` is false, releases
    ``self.sem_m2s`` once and then takes one item from ``self.q_recorder``.
    If that item is ``None``, returns ``None``. Otherwise the item is
    unpacked as ``(outputs, seq_group_metadata_list, scheduler_outputs)``,
    ``self.async_event`` is synchronized, the tokens are patched through
    ``self._fix_last_step`` and the outputs are processed.

    Returns:
        ``ctx.request_outputs`` for virtual engine 0, or ``None`` when the
        queue item was ``None``.
    """
    if not self.thread_running:
        # The previous worker thread was told to stop: wait for it and
        # start a fresh one.
        self.zero_thread.join()
        self.thread_running = True
        self.zero_thread = threading.Thread(target=self.thread_zero_overhead)
        self.zero_thread.start()
    # Wake the worker for one iteration, then block until it reports.
    self.sem_m2s.release()
    recode_output = self.q_recorder.get()
    if recode_output is None:
        # None is for the first step
        return None
    virtual_engine = 0
    ctx = self.scheduler_contexts[virtual_engine]
    ctx.request_outputs.clear()
    outputs, seq_group_metadata_list, scheduler_outputs = recode_output
    ctx.seq_group_metadata_list = seq_group_metadata_list
    ctx.scheduler_outputs = scheduler_outputs
    # Wait on the event before _fix_last_step reads self.async_d2h.
    self.async_event.synchronize()
    self._fix_last_step(outputs, seq_group_metadata_list,
                        scheduler_outputs.scheduled_seq_groups)
    # is_first_step_output is True only when the num_steps of all
    # the sequences are 1. When the num_steps > 1,
    # multi_step_model_runner does the first-step output append.
    is_first_step_output: bool = False if not seq_group_metadata_list \
        else seq_group_metadata_list[0].state.num_steps == 1

    # Add results to the output_queue
    ctx.append_output(outputs=outputs,
                      seq_group_metadata_list=seq_group_metadata_list,
                      scheduler_outputs=scheduler_outputs,
                      is_async=True,
                      is_last_step=True,
                      is_first_step_output=is_first_step_output)

    # Outputs are always processed here on this path.
    self._process_model_outputs(ctx=ctx)

    if not self.has_unfinished_requests():
        # Drain async postprocessor (if exists)
        if len(ctx.output_queue) > 0:
            self._process_model_outputs(ctx=ctx)
        assert len(ctx.output_queue) == 0

        # Stop the execute model loop in parallel workers until there are
        # more requests to process. This avoids waiting indefinitely in
        # torch.distributed ops which may otherwise timeout, and unblocks
        # the RPC thread in the workers so that they can process any other
        # queued control plane messages, such as add/remove lora adapters.
        logger.debug("Stopping remote worker execution loop.")
        self.model_executor.stop_remote_worker_execution_loop()

    return ctx.request_outputs
def
step
(
self
)
->
List
[
Union
[
RequestOutput
,
PoolingRequestOutput
]]:
"""Performs one decoding iteration and returns newly generated results.
...
...
@@ -1322,6 +1493,13 @@ class LLMEngine:
>>> if not (engine.has_unfinished_requests() or example_inputs):
>>> break
"""
#traceback.print_stack()
if
self
.
zero_overhead
:
out
=
self
.
zero_overhead_step
()
if
out
is
None
:
# the first step needs to be launched twice
out
=
self
.
zero_overhead_step
()
return
out
if
self
.
parallel_config
.
pipeline_parallel_size
>
1
:
raise
NotImplementedError
(
"Pipeline parallelism is only supported through AsyncLLMEngine "
...
...
@@ -1395,14 +1573,14 @@ class LLMEngine:
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids
=
last_sampled_token_ids
)
if
allow_async_output_proc
:
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
virtual_engine
]
#profile.ProfRangeAutoPush('model_executor')
outputs
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
#profile.ProfRangeAutoPush('end_executor')
# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if
self
.
scheduler_config
.
is_multi_step
:
...
...
@@ -1442,7 +1620,6 @@ class LLMEngine:
if
outputs
and
allow_async_output_proc
:
assert
len
(
outputs
)
==
1
,
(
"Async postprocessor expects only a single output set"
)
self
.
_advance_to_next_step
(
outputs
[
0
],
seq_group_metadata_list
,
scheduler_outputs
.
scheduled_seq_groups
)
...
...
@@ -1460,6 +1637,7 @@ class LLMEngine:
# Multi-step case
return
ctx
.
request_outputs
#profile.ProfRangeAutoPush('has_unfinish')
if
not
self
.
has_unfinished_requests
():
# Drain async postprocessor (if exists)
if
len
(
ctx
.
output_queue
)
>
0
:
...
...
@@ -1473,7 +1651,6 @@ class LLMEngine:
# queued control plane messages, such as add/remove lora adapters.
logger
.
debug
(
"Stopping remote worker execution loop."
)
self
.
model_executor
.
stop_remote_worker_execution_loop
()
return
ctx
.
request_outputs
def
_has_remaining_steps
(
...
...
vllm/engine/output_processor/stop_checker.py
View file @
645e9ec4
# SPDX-License-Identifier: Apache-2.0
import
os
from
typing
import
Callable
,
List
,
Optional
,
Tuple
from
vllm.lora.request
import
LoRARequest
...
...
@@ -20,6 +21,7 @@ class StopChecker:
# Do not use it directly, but use `self._get_max_model_len`.
self
.
_max_model_len
=
max_model_len
self
.
get_tokenizer_for_seq
=
get_tokenizer_for_seq
self
.
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
def
_get_max_model_len
(
self
,
lora_req
:
Optional
[
LoRARequest
]):
if
lora_req
and
lora_req
.
long_lora_max_len
:
...
...
@@ -42,12 +44,12 @@ class StopChecker:
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if
seq
.
get_output_len
()
<
sampling_params
.
min_tokens
:
if
seq
.
get_output_len
(
self
.
zero_overhead
)
<
sampling_params
.
min_tokens
:
return
# Check if the sequence has generated the EOS token.
if
((
not
sampling_params
.
ignore_eos
)
and
seq
.
get_last_token_id
()
==
seq
.
eos_token_id
):
and
seq
.
get_last_token_id
(
self
.
zero_overhead
)
==
seq
.
eos_token_id
):
# Remove the last EOS token unless explicitly specified
# This prevents unintended exposure of the EOS token
if
new_char_count
and
(
...
...
@@ -58,7 +60,7 @@ class StopChecker:
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id
=
seq
.
get_last_token_id
()
last_token_id
=
seq
.
get_last_token_id
(
self
.
zero_overhead
)
if
last_token_id
in
(
sampling_params
.
stop_token_ids
or
()):
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
...
...
@@ -81,12 +83,12 @@ class StopChecker:
return
# Check if the sequence has reached max_model_len.
if
seq
.
get_len
()
>
self
.
_get_max_model_len
(
lora_req
):
if
seq
.
get_len
(
self
.
zero_overhead
)
>
self
.
_get_max_model_len
(
lora_req
):
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if
seq
.
get_output_len
()
==
sampling_params
.
max_tokens
:
if
seq
.
get_output_len
(
self
.
zero_overhead
)
==
sampling_params
.
max_tokens
:
seq
.
status
=
SequenceStatus
.
FINISHED_LENGTH_CAPPED
return
...
...
vllm/entrypoints/llm.py
View file @
645e9ec4
...
...
@@ -243,6 +243,9 @@ class LLM:
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
self
.
request_counter
=
Counter
()
def __del__(self):
    """Stop the engine's worker thread when this object is finalised."""
    # __del__ also runs when __init__ raised before llm_engine was assigned,
    # so the attribute is looked up defensively to avoid an AttributeError
    # during finalisation.
    engine = getattr(self, "llm_engine", None)
    if engine is not None:
        engine.finish_thread()
@
staticmethod
def
get_engine_class
()
->
Type
[
LLMEngine
]:
...
...
@@ -1408,6 +1411,8 @@ class LLM:
if
use_tqdm
:
pbar
.
close
()
self
.
llm_engine
.
finish_thread
()
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
...
...
vllm/model_executor/layers/sampler.py
View file @
645e9ec4
# SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs."""
import
itertools
import
os
import
warnings
from
dataclasses
import
dataclass
from
importlib.util
import
find_spec
...
...
@@ -69,7 +70,15 @@ class SampleResultArgsType:
sampling_metadata
:
SamplingMetadata
greedy_samples
:
Optional
[
torch
.
Tensor
]
beam_search_logprobs
:
Optional
[
torch
.
Tensor
]
# Implemented by guanyu
@dataclass
class SampleDeviceToDevices:
    """Mutable holder for the sampler state shared through ``d2d_data``."""

    # NOTE(review): no fields are declared at class level, so @dataclass has
    # no fields to work with here — confirm the decorator is intended.
    def __init__(self):
        # Filled in _sample_with_torch with one int32 entry per sequence
        # group (the group's first seq id).
        self.seq_id: Optional[torch.Tensor] = None
        # Filled in _sample_with_torch from the greedy / multinomial samples
        # when zero_overhead is set.
        self.sampled_token_ids_tensor: Optional[torch.Tensor] = None
        # Set by Sampler.__init__ from the VLLM_ZERO_OVERHEAD env variable.
        self.zero_overhead: bool = False


# Module-level instance read and written by the sampling functions below.
d2d_data = SampleDeviceToDevices()
# Union of non-deferred (single-step scheduling)
# vs deferred (multi-step scheduling)
...
...
@@ -137,6 +146,9 @@ class SamplerOutput(
# tree-style cartesian candidates
tree_attn_masks
:
Optional
[
torch
.
Tensor
]
=
None
sampler_out_tenosr
:
Optional
[
torch
.
Tensor
]
=
None
sampler_out_ids
:
Optional
[
torch
.
Tensor
]
=
None
def
__getitem__
(
self
,
idx
:
int
)
->
CompletionSequenceGroupOutput
:
return
self
.
outputs
[
idx
]
...
...
@@ -167,7 +179,10 @@ class SamplerOutput(
f
"sampled_token_ids=
{
sampled_token_ids_repr
}
, "
f
"spec_decode_worker_metrics=
{
self
.
spec_decode_worker_metrics
}
, "
f
"logits=
{
self
.
logits
}
, "
f
"tree_attn_masks=
{
self
.
tree_attn_masks
}
)"
)
f
"tree_attn_masks=
{
self
.
tree_attn_masks
}
, "
f
"sampler_out_tenosr=
{
self
.
sampler_out_tenosr
}
, "
f
"sampler_out_ids=
{
self
.
sampler_out_ids
}
, "
f
")"
)
class
Sampler
(
nn
.
Module
):
...
...
@@ -199,6 +214,8 @@ class Sampler(nn.Module):
# speculative decoding.
self
.
include_gpu_probs_tensor
=
False
self
.
should_modify_greedy_probs_inplace
=
False
self
.
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
d2d_data
.
zero_overhead
=
self
.
zero_overhead
def
_init_sampling_tensors
(
self
,
...
...
@@ -295,7 +312,6 @@ class Sampler(nn.Module):
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float
)
# Compute the log probabilities.
logprobs
=
torch
.
log_softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float
)
# Sample the next tokens.
maybe_deferred_sample_results
,
maybe_sampled_tokens_tensor
=
_sample
(
probs
,
...
...
@@ -460,7 +476,8 @@ def _greedy_sample(
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
samples_lst
=
samples
.
tolist
()
if
not
d2d_data
.
zero_overhead
:
samples_lst
=
samples
.
tolist
()
sample_idx
=
0
results
:
SampleResultType
=
[]
for
seq_group
in
selected_seq_groups
:
...
...
@@ -473,7 +490,11 @@ def _greedy_sample(
assert
num_parent_seqs
==
1
,
(
"Greedy sampling should have only one seq."
)
parent_ids
=
list
(
range
(
num_parent_seqs
))
next_token_ids
=
[
samples_lst
[
sample_idx
]]
if
d2d_data
.
zero_overhead
:
assert
num_parent_seqs
==
1
# multiple sequences per sequence group are not supported
next_token_ids
=
[
0
]
#place holder token id
else
:
next_token_ids
=
[
samples_lst
[
sample_idx
]]
results
.
append
((
next_token_ids
,
parent_ids
))
sample_idx
+=
num_parent_seqs
return
results
...
...
@@ -496,7 +517,8 @@ def _random_sample(
seq_group has do_sample=False, tuple contains ([], [])
"""
# Find the maximum n value of the prompt phase requests.
random_samples
=
random_samples
.
cpu
()
if
not
d2d_data
.
zero_overhead
:
random_samples
=
random_samples
.
cpu
()
sample_idx
=
0
results
:
SampleResultType
=
[]
for
seq_group
in
selected_seq_groups
:
...
...
@@ -511,13 +533,21 @@ def _random_sample(
if
is_prompt
:
# Prompt phase.
parent_ids
=
[
0
]
*
sampling_params
.
n
next_token_ids
=
random_samples
[
sample_idx
,
:
sampling_params
.
n
].
tolist
()
if
d2d_data
.
zero_overhead
:
assert
num_parent_seqs
==
1
# multiple sequences per sequence group are not supported
next_token_ids
=
[
0
]
*
sampling_params
.
n
#place holder token id
else
:
next_token_ids
=
random_samples
[
sample_idx
,
:
sampling_params
.
n
].
tolist
()
else
:
# Generation phase.
parent_ids
=
list
(
range
(
num_parent_seqs
))
next_token_ids
=
random_samples
[
sample_idx
:
sample_idx
+
num_parent_seqs
,
0
].
tolist
()
if
d2d_data
.
zero_overhead
:
assert
num_parent_seqs
==
1
# multiple sequences per sequence group are not supported
next_token_ids
=
[
0
]
*
num_parent_seqs
#place holder token id
else
:
next_token_ids
=
random_samples
[
sample_idx
:
sample_idx
+
num_parent_seqs
,
0
].
tolist
()
results
.
append
((
next_token_ids
,
parent_ids
))
sample_idx
+=
num_parent_seqs
return
results
...
...
@@ -689,7 +719,6 @@ def get_pythonized_sample_results(
sample_result_args
.
beam_search_logprobs
,
sample_result_args
.
sample_results_dict
,
)
for
sampling_type
in
SamplingType
:
if
sampling_type
not
in
sample_metadata
:
continue
...
...
@@ -734,12 +763,13 @@ def _sample_with_torch(
t
:
[]
for
t
in
SamplingType
}
d2d_data
.
seq_id
=
torch
.
zeros
(
len
(
sampling_metadata
.
seq_groups
),
dtype
=
torch
.
int32
)
categorized_sample_indices
=
sampling_metadata
.
categorized_sample_indices
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
d2d_data
.
seq_id
[
i
]
=
seq_group
.
seq_ids
[
0
]
sampling_params
=
seq_group
.
sampling_params
sampling_type
=
sampling_params
.
sampling_type
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
sample_results_dict
:
SampleResultsDictType
=
{}
sample_metadata
:
SampleMetadataType
=
{}
multinomial_samples
:
MultinomialSamplesType
=
{}
...
...
@@ -770,6 +800,9 @@ def _sample_with_torch(
if
sampling_type
==
SamplingType
.
GREEDY
:
greedy_samples
=
torch
.
argmax
(
logprobs
[
long_sample_indices
],
dim
=-
1
)
if
d2d_data
.
zero_overhead
:
d2d_data
.
sampled_token_ids_tensor
=
greedy_samples
.
unsqueeze
(
-
1
)
if
sampled_token_ids_tensor
is
not
None
:
# Store sampled tokens in output tensor.
...
...
@@ -807,6 +840,10 @@ def _sample_with_torch(
probs
[
long_sample_indices
],
max_n_in_batch
,
seq_groups
=
seq_groups_arg
)
if
d2d_data
.
zero_overhead
:
d2d_data
.
sampled_token_ids_tensor
=
\
multinomial_samples
[
sampling_type
].
to
(
torch
.
long
)
if
sampled_token_ids_tensor
is
not
None
:
# Store sampled tokens in output tensor.
...
...
@@ -1271,7 +1308,9 @@ def _build_sampler_output(
sampled_token_ids
=
sampled_token_ids
,
logprobs
=
logprobs_tensor
,
deferred_sample_results_args
=
deferred_sample_results_args
,
logits
=
logits
)
logits
=
logits
,
sampler_out_tenosr
=
d2d_data
.
sampled_token_ids_tensor
,
sampler_out_ids
=
d2d_data
.
seq_id
)
def
_get_next_prompt_tokens
(
seq_group
:
SequenceGroupToSample
)
->
List
[
int
]:
...
...
vllm/model_executor/layers/update_input.py
0 → 100644
View file @
645e9ec4
import
torch
import
triton
import
triton.language
as
tl
@triton.jit
def _update_input_tokens(
    sample_output,
    seq_ids,
    input_tokens,
    input_seq_ids,
    BATCH_SIZE1,
    BATCH_SIZE2,
):
    """Overwrite one entry of ``input_tokens`` with its matching sample.

    Program ``p`` (axis 0) handles entry ``p``; programs with
    ``p >= BATCH_SIZE2`` exit immediately. The first ``BATCH_SIZE1`` entries
    of ``seq_ids`` are compared with ``input_seq_ids[p]``; on a match the
    entry of ``sample_output`` at the same index becomes the value to write
    (the last match wins). Without a match the value already in
    ``input_tokens[p]`` is written back unchanged.
    """
    slot = tl.program_id(0)
    if slot >= BATCH_SIZE2:
        return
    # Start from the current value so a miss leaves the entry as it was.
    token = tl.load(input_tokens + slot)
    wanted_seq_id = tl.load(input_seq_ids + slot)
    for src in range(BATCH_SIZE1):
        src_seq_id = tl.load(seq_ids + src)
        if src_seq_id == wanted_seq_id:
            token = tl.load(sample_output + src)
    tl.store(input_tokens + slot, token)
def UpdateInputTokens(input_tokens, input_seq_ids, last_sample, last_ids):
    """Launch ``_update_input_tokens`` with one program per input entry.

    The grid is ``[input_seq_ids.shape[0], 1, 1]``. ``last_sample`` and
    ``last_ids`` are passed as the kernel's source arrays, followed by
    ``input_tokens`` and ``input_seq_ids`` and the two leading-dimension
    sizes (``last_ids`` first, then ``input_seq_ids``).
    """
    num_inputs = input_seq_ids.shape[0]
    num_last = last_ids.shape[0]
    launch_grid = [num_inputs, 1, 1]
    _update_input_tokens[launch_grid](
        last_sample,
        last_ids,
        input_tokens,
        input_seq_ids,
        num_last,
        num_inputs,
    )
\ No newline at end of file
vllm/model_executor/sampling_metadata.py
View file @
645e9ec4
...
...
@@ -514,7 +514,6 @@ class SamplingTensors:
pin_memory
=
is_pin_memory_available
()
do_penalties
=
prompt_tokens
or
output_tokens
if
do_penalties
:
prompt_t
=
make_tensor_with_pad
(
prompt_tokens
,
...
...
@@ -534,7 +533,6 @@ class SamplingTensors:
empty_tensor
=
torch
.
empty
(
0
,
device
=
device
,
dtype
=
torch
.
long
)
prompt_t
=
empty_tensor
output_t
=
empty_tensor
temperatures_t
=
torch
.
tensor
(
temperatures
,
device
=
"cpu"
,
...
...
vllm/profiler/prof.py
0 → 100644
View file @
645e9ec4
from
ctypes
import
*
import
os
import
time
import
threading
class Prof:
    """ctypes wrapper for pushing and popping NVTX / roctx range markers.

    A backend is enabled when its environment variable is set to any value:
    ``VLLM_PROF_NVTX`` loads ``libnvToolsExt.so`` and ``VLLM_PROF_ROCTX``
    loads ``libroctracer64.so``. With neither set, every method is a no-op
    apart from the bookkeeping in ``thread_depth_add``.

    Attributes:
        use_nvtx / use_roctx: whether the respective backend is enabled.
        roc_tracer_flag: set by StartTracer and cleared by StopTracer; roctx
            ranges are only emitted while it is true.
        lib: the most recently loaded library handle (kept for existing
            readers of this attribute).
        push_depth: per-thread count of ranges pushed through this object,
            keyed by thread ident.
    """

    def __init__(self):
        self.use_nvtx = os.getenv('VLLM_PROF_NVTX') is not None
        self.roc_tracer_flag = False
        self.lib = None
        # Each backend keeps its own handle so that enabling both does not
        # leave the NVTX calls pointing at the roctracer library.
        self._nvtx_lib = None
        self._roctx_lib = None
        if self.use_nvtx:
            self._nvtx_lib = cdll.LoadLibrary("libnvToolsExt.so")
            self._nvtx_lib.nvtxRangePushA.argtypes = [c_char_p]
            self._nvtx_lib.nvtxRangePushA.restype = c_int
            self._nvtx_lib.nvtxRangePop.restype = c_int
            self.lib = self._nvtx_lib

        self.use_roctx = os.getenv('VLLM_PROF_ROCTX') is not None
        if self.use_roctx:
            self._roctx_lib = cdll.LoadLibrary("libroctracer64.so")
            self._roctx_lib.roctxRangePushA.argtypes = [c_char_p]
            self._roctx_lib.roctxRangePushA.restype = c_int
            self._roctx_lib.roctxRangePop.restype = c_int
            self.lib = self._roctx_lib

        self.tm = time.perf_counter()
        self.push_depth = {}

    def _ensure_roctx_lib(self):
        """Return the roctracer handle, loading the library if needed."""
        if self._roctx_lib is None:
            self._roctx_lib = cdll.LoadLibrary("libroctracer64.so")
            self.lib = self._roctx_lib
        return self._roctx_lib

    def StartTracer(self):
        """Call ``roctracer_start`` and enable roctx ranges (roctx only)."""
        if self.use_roctx:
            self._ensure_roctx_lib().roctracer_start()
            self.roc_tracer_flag = True

    def StopTracer(self):
        """Call ``roctracer_stop`` and disable roctx ranges (roctx only)."""
        if self.use_roctx:
            self._ensure_roctx_lib().roctracer_stop()
            self.roc_tracer_flag = False

    def thread_depth_add(self, num):
        """Adjust the calling thread's range depth by ``num``.

        Returns False, leaving the depth unchanged, when ``num`` is negative
        and the thread's depth is already 0; returns True otherwise.
        """
        thread_id = threading.current_thread().ident
        depth = self.push_depth.setdefault(thread_id, 0)
        if num < 0 and depth == 0:
            return False
        self.push_depth[thread_id] = depth + num
        return True

    def ProfRangePush(self, message):
        """Open a range named ``message`` on every active backend."""
        # State is read from this instance rather than from the module-level
        # singleton, so the method works on any Prof object.
        if self.use_nvtx:
            self._nvtx_lib.nvtxRangePushA(message.encode('utf-8'))
            self.thread_depth_add(1)
        if self.use_roctx and self.roc_tracer_flag:
            self._ensure_roctx_lib().roctxRangePushA(message.encode('utf-8'))
            self.thread_depth_add(1)

    def ProfRangePop(self):
        """Close the innermost range on every active backend.

        Returns early, without calling into a library, as soon as the calling
        thread has no open range recorded.
        """
        if self.use_nvtx:
            if not self.thread_depth_add(-1):
                return
            self._nvtx_lib.nvtxRangePop()
        if self.use_roctx and self.roc_tracer_flag:
            if not self.thread_depth_add(-1):
                return
            self._ensure_roctx_lib().roctxRangePop()

    def ProfRangeAutoPush(self, message):
        """Close the current range, if any, then open one named ``message``."""
        self.ProfRangePop()
        self.ProfRangePush(message)


# Module-level instance used by importers of this module.
profile = Prof()
vllm/sequence.py
View file @
645e9ec4
...
...
@@ -7,6 +7,7 @@ from array import array
from
collections
import
defaultdict
from
dataclasses
import
dataclass
,
field
from
functools
import
reduce
import
os
from
typing
import
Any
,
Callable
,
DefaultDict
,
Dict
,
List
,
Mapping
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Union
...
...
@@ -178,6 +179,8 @@ class SequenceData(msgspec.Struct,
_first_step_flag
:
bool
=
True
_effective_length
:
int
=
0
@
staticmethod
def
from_prompt_token_counts
(
*
token_counts
:
Tuple
[
int
,
int
])
->
"SequenceData"
:
...
...
@@ -307,16 +310,31 @@ class SequenceData(msgspec.Struct,
self
.
_new_appended_tokens
.
append
(
token_id
)
self
.
_cached_all_token_ids
.
append
(
token_id
)
self
.
_cumulative_logprob
+=
logprob
def
fix_effective_token_id
(
self
,
token_id
:
int
,):
effect_offset
=
self
.
_effective_length
-
len
(
self
.
output_token_ids
)
if
effect_offset
<
0
:
self
.
_output_token_ids
[
effect_offset
]
=
token_id
if
len
(
self
.
_new_appended_tokens
)
>=
effect_offset
*
-
1
:
self
.
_new_appended_tokens
[
effect_offset
]
=
token_id
self
.
_cached_all_token_ids
[
effect_offset
]
=
token_id
self
.
_effective_length
+=
1
def get_len(self) -> int:
    """Return the number of stored output tokens plus prompt tokens."""
    output_count = len(self._output_token_ids)
    prompt_count = len(self._prompt_token_ids)
    return output_count + prompt_count
def zero_overhead_get_len(self) -> int:
    """Return ``_effective_length`` plus the number of prompt tokens."""
    effective_outputs = self._effective_length
    prompt_count = len(self._prompt_token_ids)
    return effective_outputs + prompt_count
def get_prompt_len(self) -> int:
    """Return the number of prompt tokens."""
    prompt_tokens = self._prompt_token_ids
    return len(prompt_tokens)
def get_output_len(self) -> int:
    """Return the number of stored output tokens."""
    output_tokens = self._output_token_ids
    return len(output_tokens)
def zero_overhead_get_output_len(self) -> int:
    """Return ``_effective_length`` as the output length."""
    effective_outputs = self._effective_length
    return effective_outputs
def get_token_ids(self) -> List[int]:
    """Return the cached list of all token ids (the same object, not a copy)."""
    all_token_ids = self._cached_all_token_ids
    return all_token_ids
...
...
@@ -367,15 +385,22 @@ class SequenceData(msgspec.Struct,
# of prompt_len here. This is because during recompute we need to
# prefill for both prompt and output.
return
self
.
get_len
()
-
self
.
get_num_computed_tokens
()
def get_last_token_id(self) -> int:
    """Return the last output token id, or the last prompt token id when
    there are no output tokens."""
    # An empty output sequence is falsy, so fall back to the prompt.
    source = self._output_token_ids or self._prompt_token_ids
    return source[-1]
def zero_overhead_get_last_token_id(self) -> int:
    """Return the output token at index ``_effective_length - 1``, or the
    last prompt token id when ``_effective_length`` is 0."""
    effective_outputs = self._effective_length
    if effective_outputs != 0:
        return self._output_token_ids[effective_outputs - 1]
    return self._prompt_token_ids[-1]
def get_prompt_token_ids(self) -> Tuple[int, ...]:
    """Return ``self.prompt_token_ids``."""
    prompt_ids = self.prompt_token_ids
    return prompt_ids
def zero_overhead_get_output_token_ids(self) -> Tuple[int, ...]:
    """Return the first ``_effective_length`` entries of
    ``self.output_token_ids``."""
    effective_outputs = self._effective_length
    output_ids = self.output_token_ids
    return output_ids[:effective_outputs]
def get_output_token_ids(self) -> Tuple[int, ...]:
    """Return ``self.output_token_ids``."""
    output_ids = self.output_token_ids
    return output_ids
...
...
@@ -461,6 +486,7 @@ class Sequence:
self
.
read_offset
=
0
# Input + output tokens
self
.
tokens
:
Optional
[
List
[
str
]]
=
None
self
.
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
@
property
def
n_blocks
(
self
)
->
int
:
...
...
@@ -527,9 +553,9 @@ class Sequence:
"""If delta is True, only new tokens since the last call to
this method are returned"""
if
not
delta
:
return
self
.
get_output_token_ids
()
return
self
.
get_output_token_ids
(
self
.
zero_overhead
)
output_len
=
self
.
get_output_len
()
output_len
=
self
.
get_output_len
(
self
.
zero_overhead
)
# Get the number of new tokens
num_new_tokens
=
output_len
-
self
.
_last_output_token_ids_offset
...
...
@@ -539,11 +565,16 @@ class Sequence:
if
num_new_tokens
==
1
:
# Optimization for single decode token case
# (which is what we have most of the time)
return
self
.
data
.
_cached_all_token_ids
[
-
1
]
if
self
.
zero_overhead
:
return
self
.
data
.
_cached_all_token_ids
[
self
.
data
.
_effective_length
-
1
]
else
:
return
self
.
data
.
_cached_all_token_ids
[
-
1
]
if
num_new_tokens
==
0
:
return
[]
if
self
.
zero_overhead
:
return
self
.
data
.
_cached_all_token_ids
[
-
num_new_tokens
:
self
.
data
.
_effective_length
]
return
self
.
data
.
_cached_all_token_ids
[
-
num_new_tokens
:]
def
hash_of_block
(
self
,
logical_idx
:
int
)
->
int
:
...
...
@@ -582,13 +613,20 @@ class Sequence:
self
.
output_logprobs
.
append
(
logprobs
)
self
.
data
.
append_token_id
(
token_id
,
logprobs
[
token_id
].
logprob
)
def
get_len
(
self
)
->
int
:
def fix_last_token_id(self, token_id: int) -> None:
    """Forward ``token_id`` to ``self.data.fix_effective_token_id``.

    Args:
        token_id: Token id handed unchanged to the sequence data.
    """
    seq_data = self.data
    seq_data.fix_effective_token_id(token_id)
def get_len(self, zero_overhead=False) -> int:
    """Return the sequence length reported by the underlying data.

    Args:
        zero_overhead: When truthy, use ``data.zero_overhead_get_len()``;
            otherwise use ``data.get_len()``.
    """
    seq_data = self.data
    length_fn = (seq_data.zero_overhead_get_len
                 if zero_overhead else seq_data.get_len)
    return length_fn()
def get_prompt_len(self) -> int:
    """Return the prompt length by delegating to ``self.data``."""
    seq_data = self.data
    return seq_data.get_prompt_len()
def
get_output_len
(
self
)
->
int
:
def get_output_len(self, zero_overhead=False) -> int:
    """Return the output length reported by the underlying data.

    Args:
        zero_overhead: When truthy, use
            ``data.zero_overhead_get_output_len()``; otherwise use
            ``data.get_output_len()``.
    """
    if not zero_overhead:
        return self.data.get_output_len()
    return self.data.zero_overhead_get_output_len()
def
get_token_ids
(
self
)
->
List
[
int
]:
...
...
@@ -597,10 +635,14 @@ class Sequence:
def get_prompt_token_ids(self) -> Tuple[int, ...]:
    """Return the prompt token ids by delegating to ``self.data``."""
    seq_data = self.data
    return seq_data.get_prompt_token_ids()
def
get_last_token_id
(
self
)
->
int
:
def get_last_token_id(self, zero_overhead=False) -> int:
    """Return the last token id reported by the underlying data.

    Args:
        zero_overhead: When truthy, use
            ``data.zero_overhead_get_last_token_id()``; otherwise use
            ``data.get_last_token_id()``.
    """
    seq_data = self.data
    if zero_overhead:
        last_id = seq_data.zero_overhead_get_last_token_id()
    else:
        last_id = seq_data.get_last_token_id()
    return last_id
def
get_output_token_ids
(
self
)
->
Tuple
[
int
,
...]:
def get_output_token_ids(self, zero_overhead=False) -> Tuple[int, ...]:
    """Return the output token ids reported by the underlying data.

    Args:
        zero_overhead: When truthy, use
            ``data.zero_overhead_get_output_token_ids()``; otherwise use
            ``data.get_output_token_ids()``.
    """
    return (self.data.zero_overhead_get_output_token_ids()
            if zero_overhead else self.data.get_output_token_ids())
def
get_cumulative_logprob
(
self
)
->
float
:
...
...
@@ -807,17 +849,19 @@ class SequenceGroup:
def set_last_token_time(self, now: float) -> None:
    """Sets the last token time for Request level timings.

    Records ``now - metrics.last_token_time`` as ``last_token_latency``
    and then stores ``now`` as the new ``metrics.last_token_time``.

    Args:
        now: Timestamp of the token that was just produced.

    Raises:
        AssertionError: If the group is still in the prefill phase and the
            first sequence does not have ``zero_overhead`` set.
    """
    # The prefill check is enforced only when zero-overhead mode is off.
    # An additional unconditional assert in front of this gate would make
    # the gate ineffective, so only the gated form is kept.
    if not self.seqs[0].zero_overhead:
        assert not self.is_prefill(), (
            "seq_group.set_last_token_time() should not be called "
            "if the seq_group is in prefill phase.")
    self.last_token_latency = now - self.metrics.last_token_time
    self.metrics.last_token_time = now
def get_last_token_latency(self) -> float:
    """Returns the latency of the last token.

    Raises:
        AssertionError: If the group is still in the prefill phase and the
            first sequence does not have ``zero_overhead`` set.
    """
    # The prefill check is enforced only when zero-overhead mode is off.
    # An additional unconditional assert in front of this gate would make
    # the gate ineffective, so only the gated form is kept.
    if not self.seqs[0].zero_overhead:
        assert not self.is_prefill(), (
            "seq_group.get_last_token_latency() should not be called "
            "if the seq_group is in prefill phase.")
    return self.last_token_latency
def
maybe_set_first_token_time
(
self
,
time
:
float
)
->
None
:
...
...
@@ -1402,6 +1446,12 @@ class ExecuteModelRequest(
# Optional slot mapping of kvcache that pending to be moved generated from draft model.
kvcache_slot_to_be_moved
:
Optional
[
torch
.
Tensor
]
=
None
# for zero-overhead scheduler
last_outputs_sample
:
Optional
[
torch
.
Tensor
]
=
None
# for zero-overhead scheduler
last_outputs_ids
:
Optional
[
torch
.
Tensor
]
=
None
@
property
def
is_first_multi_step
(
self
)
->
bool
:
# TODO(will) make this be able to handle batches with variable number of
...
...
@@ -1451,7 +1501,9 @@ class ExecuteModelRequest(
async_callback
=
self
.
async_callback
,
tree_attn_masks
=
self
.
tree_attn_masks
,
tree_position_ids
=
self
.
tree_position_ids
,
kvcache_slot_to_be_moved
=
self
.
kvcache_slot_to_be_moved
)
kvcache_slot_to_be_moved
=
self
.
kvcache_slot_to_be_moved
,
last_outputs_sample
=
self
.
last_outputs_sample
,
last_outputs_ids
=
self
.
last_outputs_ids
)
@
dataclass
...
...
vllm/spec_decode/target_model_runner.py
View file @
645e9ec4
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
import
torch
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
...
...
@@ -31,10 +31,12 @@ class TargetModelRunner(ModelRunnerWrapperBase):
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
last_outputs_ids
:
torch
.
Tensor
=
None
,
last_output_sample
:
torch
.
Tensor
=
None
,
)
->
ModelRunnerInputBase
:
model_input
:
ModelRunnerInputBase
=
\
self
.
model_runner
.
prepare_model_input
(
seq_group_metadata_list
,
virtual_engine
,
finished_requests_ids
)
seq_group_metadata_list
,
virtual_engine
,
finished_requests_ids
,
last_outputs_ids
,
last_output_sample
)
# If token log probabilities is disabled then skip generating sampler
# CPU output. We directly serialize the GPU sampled_token_id tensors
# as needed. If log probabilities is enabled then synchronize all the
...
...
vllm/transformers_utils/detokenizer.py
View file @
645e9ec4
# SPDX-License-Identifier: Apache-2.0
import
os
from
typing
import
Dict
,
List
,
Optional
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
Logprob
,
SamplingParams
,
...
...
@@ -16,6 +17,7 @@ class Detokenizer:
def __init__(self, tokenizer_group: BaseTokenizerGroup):
    """Store the tokenizer group and read the zero-overhead switch.

    ``zero_overhead`` is True only when the ``VLLM_ZERO_OVERHEAD``
    environment variable is exactly ``'1'`` at construction time.
    """
    zero_overhead_flag = os.environ.get('VLLM_ZERO_OVERHEAD')
    self.tokenizer_group = tokenizer_group
    self.zero_overhead = zero_overhead_flag == '1'
def
get_tokenizer_for_seq
(
self
,
sequence
:
Sequence
)
->
AnyTokenizer
:
"""Returns the HF tokenizer to use for a given sequence."""
...
...
@@ -107,7 +109,11 @@ class Detokenizer:
Returns:
The number of characters added to the output text.
"""
all_input_ids
=
seq
.
get_token_ids
()
all_input_ids
=
seq
.
get_token_ids
()
if
self
.
zero_overhead
:
eff_length
=
seq
.
get_prompt_len
()
+
seq
.
data
.
_effective_length
all_input_ids
=
seq
.
get_token_ids
()[
:
eff_length
]
token_id_generated_this_iteration
=
all_input_ids
[
-
1
]
tokenizer
=
self
.
get_tokenizer_for_seq
(
seq
)
...
...
vllm/version.py
View file @
645e9ec4
# SPDX-License-Identifier: Apache-2.0
try:
    # Hard-coded release identifiers for this build.
    __version__ = "0.7.2"
    __version_tuple__ = (0, 7, 2)
    __hcu_version__ = '0.7.2+das.opt1.cust1.6b7651a.dtk2504'
    # NOTE(review): this looks like a self-import when this file is loaded
    # as ``vllm.version``; kept as-is — confirm whether it is still needed.
    from vllm.version import __version__, __version_tuple__, __hcu_version__
except Exception as e:
    import warnings

    # Interpolate the exception into the message; the previous text put
    # the literal characters " + str(e)" inside the string, so the real
    # failure reason was never shown.
    warnings.warn(f"Failed to read commit hash:\n{e}",
                  RuntimeWarning,
                  stacklevel=2)

    __version__ = "dev"
    __version_tuple__ = (0, 0, __version__)
vllm/worker/model_runner.py
View file @
645e9ec4
...
...
@@ -5,6 +5,7 @@ import dataclasses
import
gc
import
inspect
import
itertools
import
os
import
time
import
weakref
from
contextlib
import
contextmanager
...
...
@@ -60,6 +61,8 @@ from vllm.worker.model_runner_base import (
_init_attn_metadata_from_tensor_dict
,
_init_sampling_metadata_from_tensor_dict
)
from
vllm.model_executor.layers.update_input
import
UpdateInputTokens
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
...
...
@@ -272,7 +275,6 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
computed_block_nums
=
computed_block_nums
self
.
n_seqs
=
n_seqs
self
.
encoder_seq_len
=
encoder_seq_len
if
reinit
:
if
len
(
self
.
seq_ids
)
==
1
and
reinit_use_defaults
:
self
.
simple_reinit
()
...
...
@@ -476,6 +478,14 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self
.
sliding_window_blocks
*
self
.
block_size
self
.
is_encoder_decoder_model
=
self
.
runner
.
model_config
.
is_encoder_decoder
self
.
zero_overhead
=
os
.
environ
.
get
(
'VLLM_ZERO_OVERHEAD'
)
==
'1'
self
.
last_sample_tensor
=
None
self
.
last_sample_ids
=
None
self
.
req_ids
=
[]
def SetLastSamperData(self, last_sample_ids, last_sample_tensor):
    """Remember the previous step's sample ids and sample tensor.

    Both values are stored on the builder unchanged, as
    ``last_sample_ids`` and ``last_sample_tensor``.
    """
    self.last_sample_ids, self.last_sample_tensor = (last_sample_ids,
                                                     last_sample_tensor)
def
prepare
(
self
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
None
:
...
...
@@ -491,6 +501,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
ModelInputForGPUBuilder
.
InterDataForSeqGroup
]
=
[]
self
.
attn_metadata_builder
.
prepare
()
self
.
req_ids
.
clear
()
def
_compute_lens
(
self
,
inter_data
:
InterDataForSeqGroup
,
seq_idx
:
int
,
seq_group_metadata
:
SequenceGroupMetadata
):
...
...
@@ -756,8 +767,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
encoder_seq_len
=
encoder_seq_len
)
self
.
inter_data_list
.
append
(
inter_data
)
seq_ids
=
list
(
seq_ids
)
for
seq_idx
in
range
(
n_seqs
):
self
.
req_ids
.
append
(
seq_ids
[
seq_idx
])
for
per_seq_fn
in
self
.
per_seq_compute_fns
:
per_seq_fn
(
inter_data
,
seq_idx
,
seq_group_metadata
)
for
per_seq_group_fn
in
self
.
per_seq_group_compute_fns
:
...
...
@@ -898,9 +910,19 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if
cuda_graph_pad_size
:
input_tokens
.
extend
(
itertools
.
repeat
(
0
,
cuda_graph_pad_size
))
assert
self
.
runner
.
device
is
not
None
input_tokens_tensor
=
async_tensor_h2d
(
input_tokens
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
if
self
.
zero_overhead
and
self
.
last_sample_tensor
is
not
None
:
input_ids
=
async_tensor_h2d
(
self
.
req_ids
,
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
last_ids
=
async_tensor_h2d
(
self
.
last_sample_ids
.
tolist
(),
torch
.
long
,
self
.
runner
.
device
,
self
.
runner
.
pin_memory
)
UpdateInputTokens
(
input_tokens_tensor
,
input_ids
,
self
.
last_sample_tensor
,
last_ids
)
token_types_tensor
=
async_tensor_h2d
(
token_types
,
torch
.
long
,
self
.
runner
.
device
,
...
...
@@ -1203,7 +1225,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def
_prepare_model_input_tensors
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
last_outputs_ids
:
torch
.
Tensor
=
None
,
last_output_sample
:
torch
.
Tensor
=
None
,
)
->
TModelInputForGPU
:
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
...
...
@@ -1224,7 +1248,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
builder
.
add_seq_group
(
seq_group_metadata
)
self
.
builder
.
reset_cached_inter_data
()
self
.
builder
.
SetLastSamperData
(
last_outputs_ids
,
last_output_sample
)
return
self
.
builder
.
build
()
# type: ignore
@
contextmanager
...
...
@@ -1619,6 +1643,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
last_outputs_ids
:
torch
.
Tensor
=
None
,
last_output_sample
:
torch
.
Tensor
=
None
,
)
->
ModelInputForGPUWithSamplingMetadata
:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
...
...
@@ -1634,7 +1660,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
If cuda graph is required, this API automatically pads inputs.
"""
model_input
=
self
.
_prepare_model_input_tensors
(
seq_group_metadata_list
,
finished_requests_ids
)
seq_group_metadata_list
,
finished_requests_ids
,
last_outputs_ids
,
last_output_sample
)
if
get_pp_group
().
is_last_rank
:
# Sampling metadata is only required for the final pp group
generators
=
self
.
get_generators
(
finished_requests_ids
)
...
...
@@ -1675,7 +1701,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self
.
set_active_prompt_adapters
(
model_input
.
prompt_adapter_requests
,
model_input
.
prompt_adapter_mapping
)
self
.
attn_state
.
begin_forward
(
model_input
)
# Currently cuda graph is only supported by the decode phase.
...
...
vllm/worker/model_runner_base.py
View file @
645e9ec4
...
...
@@ -210,6 +210,8 @@ class ModelRunnerBase(ABC, Generic[T]):
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
last_outputs_ids
:
torch
.
Tensor
=
None
,
last_output_sample
:
torch
.
Tensor
=
None
,
)
->
T
:
"""
Prepare the inputs to ModelRunnerBase.execute_model from an execution
...
...
vllm/worker/worker_base.py
View file @
645e9ec4
...
...
@@ -374,7 +374,9 @@ class LocalOrDistributedWorkerBase(WorkerBase):
self
.
model_runner
.
prepare_model_input
(
execute_model_req
.
seq_group_metadata_list
,
execute_model_req
.
virtual_engine
,
execute_model_req
.
finished_requests_ids
))
execute_model_req
.
finished_requests_ids
,
last_outputs_ids
=
execute_model_req
.
last_outputs_ids
,
last_output_sample
=
execute_model_req
.
last_outputs_sample
))
if
self
.
tree_decoding
and
execute_model_req
.
tree_position_ids
is
not
None
and
\
execute_model_req
.
tree_attn_masks
is
not
None
:
...
...
@@ -462,7 +464,6 @@ class LocalOrDistributedWorkerBase(WorkerBase):
and
self
.
observability_config
.
collect_model_execute_time
):
orig_model_execute_time
=
intermediate_tensors
.
tensors
.
get
(
"model_execute_time"
,
torch
.
tensor
(
0
)).
item
()
output
=
self
.
model_runner
.
execute_model
(
model_input
=
model_input
,
kv_caches
=
self
.
kv_cache
[
worker_input
.
virtual_engine
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment