Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
31330101
Commit
31330101
authored
Apr 16, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.4' into v0.8.4-dev
parents
e8933c34
dc1b4a6f
Changes
346
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
523 additions
and
119 deletions
+523
-119
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+9
-3
vllm/v1/worker/utils.py
vllm/v1/worker/utils.py
+45
-0
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+25
-1
vllm/worker/hpu_model_runner.py
vllm/worker/hpu_model_runner.py
+314
-113
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+8
-2
vllm/worker/multi_step_hpu_worker.py
vllm/worker/multi_step_hpu_worker.py
+122
-0
No files found.
vllm/v1/worker/tpu_worker.py
View file @
31330101
...
...
@@ -157,13 +157,19 @@ class TPUWorker:
runner_kv_caches
)
self
.
model_runner
.
_dummy_run
(
runner_kv_caches
,
num_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
,
)
self
.
scheduler_config
.
max_num_batched_tokens
)
# Synchronize before measuring the memory usage.
xm
.
wait_device_ops
()
# During the profiling run, the model runs without KV cache. After
# the profiling run, the model always runs with KV cache. Here we clear
# the dynamo cache and cached bytecode to ensure the model always has
# one compiled bytecode. Having one FX graph/cached bytecode per
# compiled model is required for `support_torch_compile` decorator to
# skip dynamo guard.
self
.
model_runner
.
reset_dynamo_cache
()
# Get the maximum amount of memory used by the model weights and
# intermediate activations.
m
=
xm
.
get_memory_info
(
self
.
device
)
...
...
vllm/v1/worker/utils.py
View file @
31330101
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
import
torch
...
...
@@ -27,3 +29,46 @@ def sanity_check_mm_encoder_outputs(
f
"but got tensors with shapes
{
[
e
.
shape
for
e
in
mm_embeddings
]
}
"
"instead. This is most likely due to incorrect implementation "
"of the model's `get_multimodal_embeddings` method."
)
def
scatter_mm_placeholders
(
embeds
:
torch
.
Tensor
,
is_embed
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
"""
Scatter the multimodal embeddings into a contiguous tensor that represents
the placeholder tokens.
:class:`vllm.multimodal.processing.PromptUpdateDetails.is_embed`.
Args:
embeds: The multimodal embeddings.
Shape: `(num_embeds, embed_dim)`
is_embed: A boolean mask indicating which positions in the placeholder
tokens need to be filled with multimodal embeddings.
Shape: `(num_placeholders, num_embeds)`
"""
if
is_embed
is
None
:
return
embeds
placeholders
=
embeds
.
new_full
(
(
is_embed
.
shape
[
0
],
embeds
.
shape
[
-
1
]),
fill_value
=
torch
.
nan
,
)
placeholders
[
is_embed
]
=
embeds
return
placeholders
def
gather_mm_placeholders
(
placeholders
:
torch
.
Tensor
,
is_embed
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
"""
Reconstructs the embeddings from the placeholder tokens.
This is the operation of :func:`scatter_mm_placeholders`.
"""
if
is_embed
is
None
:
return
placeholders
return
placeholders
[
is_embed
]
vllm/worker/enc_dec_model_runner.py
View file @
31330101
...
...
@@ -16,6 +16,7 @@ from vllm.config import VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalKwargs
,
...
...
@@ -34,6 +35,7 @@ from vllm.worker.model_runner_base import (
from
vllm.worker.utils
import
assert_enc_dec_mr_supported_scenario
logger
=
init_logger
(
__name__
)
LORA_WARMUP_RANK
=
8
@
dataclasses
.
dataclass
(
frozen
=
True
)
...
...
@@ -160,7 +162,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
if
num_steps
>
1
:
raise
ValueError
(
"num_steps > 1 is not supported in "
"EncoderDecoderModelRunner"
)
if
self
.
lora_config
:
assert
model_input
.
lora_requests
is
not
None
assert
model_input
.
lora_mapping
is
not
None
self
.
set_active_loras
(
model_input
.
lora_requests
,
model_input
.
lora_mapping
)
if
(
model_input
.
attn_metadata
is
not
None
and
model_input
.
attn_metadata
.
prefill_metadata
is
None
and
model_input
.
attn_metadata
.
decode_metadata
.
use_cuda_graph
):
...
...
@@ -268,6 +274,22 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
max_num_batched_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
# This represents the maximum number of different requests
# that will have unique loras, and therefore the max amount of
# memory consumption. Create dummy lora request copies from the
# lora request passed in, which contains a lora from the lora
# warmup path.
dummy_lora_requests
:
List
[
LoRARequest
]
=
[]
dummy_lora_requests_per_seq
:
List
[
LoRARequest
]
=
[]
if
self
.
lora_config
:
dummy_lora_requests
=
self
.
_add_dummy_loras
(
self
.
lora_config
.
max_loras
)
assert
len
(
dummy_lora_requests
)
==
self
.
lora_config
.
max_loras
dummy_lora_requests_per_seq
=
[
dummy_lora_requests
[
idx
%
len
(
dummy_lora_requests
)]
for
idx
in
range
(
max_num_seqs
)
]
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
...
...
@@ -315,6 +337,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
block_tables
=
None
,
encoder_seq_data
=
encoder_dummy_data
.
seq_data
,
cross_block_table
=
None
,
lora_request
=
dummy_lora_requests_per_seq
[
group_id
]
if
dummy_lora_requests_per_seq
else
None
,
multi_modal_data
=
decoder_dummy_data
.
multi_modal_data
or
encoder_dummy_data
.
multi_modal_data
,
multi_modal_placeholders
=
decoder_dummy_data
.
...
...
vllm/worker/hpu_model_runner.py
View file @
31330101
...
...
@@ -32,6 +32,7 @@ from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.config
import
DeviceConfig
,
VllmConfig
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.distributed.parallel_state
import
get_world_group
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
...
...
@@ -44,11 +45,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.sampling_metadata
import
SequenceGroupToSample
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
MultiModalKwargs
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
IntermediateTensors
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
Logprob
,
SequenceData
,
SequenceGroupMetadata
,
SequenceOutput
)
from
vllm.utils
import
(
bind_kv_cache
,
is_pin_memory_available
,
make_tensor_with_pad
)
from
vllm.worker.model_runner_base
import
(
...
...
@@ -100,7 +103,10 @@ def subtuple(obj: object,
if
to_override
is
None
:
to_override
=
{}
fields
=
set
(
to_copy
)
|
set
(
to_override
.
keys
())
values
=
{
f
:
to_override
.
get
(
f
,
getattr
(
obj
,
f
))
for
f
in
fields
}
if
type
(
obj
)
is
dict
:
values
=
{
key
:
obj
[
key
]
for
key
in
fields
if
key
in
obj
}
else
:
values
=
{
f
:
to_override
.
get
(
f
,
getattr
(
obj
,
f
))
for
f
in
fields
}
if
typename
not
in
_TYPE_CACHE
:
_TYPE_CACHE
[
typename
]
=
collections
.
namedtuple
(
typename
,
' '
.
join
(
fields
))
...
...
@@ -533,6 +539,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
virtual_engine
:
int
=
0
lora_ids
:
Optional
[
List
[
int
]]
=
None
async_callback
:
Optional
[
Callable
]
=
None
is_first_multi_step
:
bool
=
True
is_last_step
:
bool
=
True
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
tensor_dict
=
{
...
...
@@ -545,6 +553,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
"batch_size_padded"
:
self
.
batch_size_padded
,
"virtual_engine"
:
self
.
virtual_engine
,
"lora_ids"
:
self
.
lora_ids
,
"is_first_multi_step"
:
self
.
is_first_multi_step
,
"is_last_step"
:
self
.
is_last_step
,
}
_add_attn_metadata_broadcastable_dict
(
tensor_dict
,
self
.
attn_metadata
)
return
tensor_dict
...
...
@@ -656,6 +666,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
self
.
_set_gc_threshold
()
self
.
use_contiguous_pa
=
envs
.
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
# For multi-step scheduling
self
.
cached_step_outputs
:
List
[
torch
.
Tensor
]
=
[]
def
_set_gc_threshold
(
self
)
->
None
:
# Read https://docs.python.org/3/library/gc.html#gc.set_threshold
# for comprehensive description of gc generations.
...
...
@@ -1005,6 +1018,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
def
_prepare_decode
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
output
=
None
,
)
->
PrepareDecodeMetadata
:
input_tokens
:
List
[
List
[
int
]]
=
[]
input_positions
:
List
[
List
[
int
]]
=
[]
...
...
@@ -1035,8 +1049,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
for
seq_id
in
seq_ids
:
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
generation_token
=
seq_data
.
get_last_token_id
()
input_tokens
.
append
([
generation_token
])
if
output
is
None
:
generation_token
=
seq_data
.
get_last_token_id
()
input_tokens
.
append
([
generation_token
])
seq_len
=
seq_data
.
get_len
()
position
=
seq_len
-
1
...
...
@@ -1047,6 +1062,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
seq_lens
.
append
(
seq_len
)
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
num_fully_occupied_blocks
=
position
//
self
.
block_size
block_table
=
block_table
[:
num_fully_occupied_blocks
+
1
]
if
len
(
block_table
)
==
0
:
block_number
=
_PAD_BLOCK_ID
else
:
...
...
@@ -1066,9 +1084,14 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
block_table
=
block_table
[
-
sliding_window_blocks
:]
block_tables
.
append
(
block_table
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
if
output
is
None
:
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
else
:
real_batch_size
=
len
(
seq_group_metadata_list
)
input_tokens
=
output
[:
real_batch_size
]
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
...
...
@@ -1462,7 +1485,27 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
profiler
.
start
()
for
_
in
range
(
times
):
inputs
=
self
.
prepare_model_input
(
seqs
)
self
.
execute_model
(
inputs
,
None
,
warmup_mode
=
True
)
is_single_step
=
\
self
.
vllm_config
.
scheduler_config
.
num_scheduler_steps
==
1
if
is_prompt
or
is_single_step
:
self
.
execute_model
(
inputs
,
None
,
warmup_mode
=
True
)
else
:
# decode with multi-step
inputs
=
dataclasses
.
replace
(
inputs
,
is_first_multi_step
=
True
,
is_last_step
=
False
)
self
.
execute_model
(
inputs
,
None
,
warmup_mode
=
True
,
num_steps
=
2
,
seqs
=
seqs
)
inputs
=
dataclasses
.
replace
(
inputs
,
is_first_multi_step
=
False
,
is_last_step
=
True
)
self
.
execute_model
(
inputs
,
None
,
warmup_mode
=
True
,
num_steps
=
2
,
seqs
=
seqs
)
torch
.
hpu
.
synchronize
()
if
profiler
:
profiler
.
step
()
...
...
@@ -1985,115 +2028,273 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
warmup_mode
=
False
,
seqs
=
None
,
)
->
Optional
[
Union
[
List
[
SamplerOutput
],
IntermediateTensors
]]:
if
num_steps
>
1
:
raise
ValueError
(
"num_steps > 1 is not supported in HPUModelRunner"
)
if
not
model_input
.
is_first_multi_step
:
if
not
model_input
.
is_last_step
:
# not first or last multi-step
return
[]
# last multi-step
output
=
self
.
_decode_sampler_outputs
(
model_input
)
if
self
.
is_driver_worker
else
[]
torch
.
hpu
.
synchronize
()
if
model_input
.
is_first_multi_step
:
# first multi-step
if
self
.
lora_config
:
assert
model_input
.
lora_requests
is
not
None
assert
model_input
.
lora_mapping
is
not
None
self
.
set_active_loras
(
model_input
.
lora_requests
,
model_input
.
lora_mapping
)
input_tokens
=
model_input
.
input_tokens
input_positions
=
model_input
.
input_positions
attn_metadata
=
model_input
.
attn_metadata
sampling_metadata
=
model_input
.
sampling_metadata
real_batch_size
=
model_input
.
real_batch_size
batch_size_padded
=
model_input
.
batch_size_padded
assert
input_tokens
is
not
None
assert
input_positions
is
not
None
assert
sampling_metadata
is
not
None
assert
attn_metadata
is
not
None
is_prompt
=
attn_metadata
.
is_prompt
assert
is_prompt
is
not
None
batch_size
=
input_tokens
.
size
(
0
)
seq_len
=
self
.
_seq_len
(
attn_metadata
)
use_graphs
=
self
.
_use_graphs
(
batch_size
,
seq_len
,
is_prompt
)
self
.
_check_config
(
batch_size
,
seq_len
,
is_prompt
,
warmup_mode
)
lora_mask
:
torch
.
Tensor
=
None
lora_logits_mask
:
torch
.
Tensor
=
None
if
self
.
lora_config
:
assert
model_input
.
lora_ids
is
not
None
lora_mask
,
lora_logits_mask
=
self
.
create_lora_mask
(
input_tokens
,
model_input
.
lora_ids
,
attn_metadata
.
is_prompt
)
execute_model_kwargs
=
{
"input_ids"
:
input_tokens
,
"positions"
:
input_positions
,
"attn_metadata"
:
self
.
trim_attn_metadata
(
attn_metadata
),
"intermediate_tensors"
:
intermediate_tensors
,
"lora_mask"
:
lora_mask
,
"virtual_engine"
:
model_input
.
virtual_engine
,
**
(
model_input
.
multi_modal_kwargs
or
{}),
}
if
htorch
.
utils
.
internal
.
is_lazy
():
execute_model_kwargs
.
update
(
{
"bypass_hpu_graphs"
:
not
use_graphs
})
if
self
.
lora_config
:
assert
model_input
.
lora_requests
is
not
None
assert
model_input
.
lora_mapping
is
not
None
self
.
set_active_loras
(
model_input
.
lora_requests
,
model_input
.
lora_mapping
)
input_tokens
=
model_input
.
input_tokens
input_positions
=
model_input
.
input_positions
attn_metadata
=
model_input
.
attn_metadata
sampling_metadata
=
model_input
.
sampling_metadata
real_batch_size
=
model_input
.
real_batch_size
batch_size_padded
=
model_input
.
batch_size_padded
assert
input_tokens
is
not
None
assert
input_positions
is
not
None
assert
sampling_metadata
is
not
None
assert
attn_metadata
is
not
None
is_prompt
=
attn_metadata
.
is_prompt
assert
is_prompt
is
not
None
batch_size
=
input_tokens
.
size
(
0
)
seq_len
=
self
.
_seq_len
(
attn_metadata
)
use_graphs
=
self
.
_use_graphs
(
batch_size
,
seq_len
,
is_prompt
)
self
.
_check_config
(
batch_size
,
seq_len
,
is_prompt
,
warmup_mode
)
htorch
.
core
.
mark_step
()
if
self
.
is_driver_worker
:
model_event_name
=
(
"model_"
f
"
{
'prompt'
if
is_prompt
else
'decode'
}
_"
f
"bs
{
batch_size
}
_"
f
"seq
{
seq_len
}
_"
f
"graphs
{
'T'
if
use_graphs
else
'F'
}
"
)
else
:
model_event_name
=
'model_executable'
if
num_steps
>
1
:
# in case of multi-step scheduling
# we only want to pythonize in the last step
sampling_metadata
.
skip_sampler_cpu_output
=
True
self
.
model
.
model
.
sampler
.
include_gpu_probs_tensor
=
True
cache_orig_output_tokens_len
:
List
[
Dict
]
=
[]
def
try_revert_dummy_output_tokens
():
if
len
(
cache_orig_output_tokens_len
)
>
0
:
# Reuse the original output token ids length
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
for
j
,
data
in
seq_group_metadata
.
seq_data
.
items
():
orig_output_tokens_len
=
\
cache_orig_output_tokens_len
[
i
][
j
]
data
.
output_token_ids
=
\
data
.
output_token_ids
[:
orig_output_tokens_len
]
for
i
in
range
(
num_steps
):
if
i
!=
0
and
not
self
.
is_driver_worker
:
broadcast_data
=
broadcast_tensor_dict
(
src
=
0
)
if
'early_exit'
in
broadcast_data
and
broadcast_data
[
'early_exit'
]:
return
[
output
]
if
num_steps
==
1
else
[]
execute_model_kwargs
.
update
({
"input_ids"
:
broadcast_data
[
"input_ids"
],
"positions"
:
broadcast_data
[
"positions"
],
"attn_metadata"
:
self
.
trim_attn_metadata
(
broadcast_data
[
"attn_metadata"
])
})
with
self
.
profiler
.
record_event
(
'internal'
,
model_event_name
):
hidden_states
=
self
.
model
.
forward
(
**
execute_model_kwargs
,
selected_token_indices
=
sampling_metadata
.
selected_token_indices
)
if
self
.
lora_config
:
LoraMask
.
setLoraMask
(
lora_logits_mask
.
index_select
(
0
,
sampling_metadata
.
selected_token_indices
))
# Compute the logits.
with
self
.
profiler
.
record_event
(
'internal'
,
(
'compute_logits_'
f
'
{
"prompt"
if
is_prompt
else
"decode"
}
_bs'
f
'
{
batch_size
}
_'
f
'seq
{
seq_len
}
'
)):
if
num_steps
==
1
:
sampling_metadata
.
selected_token_indices
=
None
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
htorch
.
core
.
mark_step
()
# Only perform sampling in the driver worker.
if
not
self
.
is_driver_worker
:
continue
lora_mask
:
torch
.
Tensor
=
None
lora_logits_mask
:
torch
.
Tensor
=
None
if
self
.
lora_config
:
assert
model_input
.
lora_ids
is
not
None
lora_mask
,
lora_logits_mask
=
self
.
create_lora_mask
(
input_tokens
,
model_input
.
lora_ids
,
attn_metadata
.
is_prompt
)
execute_model_kwargs
=
{
"input_ids"
:
input_tokens
,
"positions"
:
input_positions
,
"attn_metadata"
:
self
.
trim_attn_metadata
(
attn_metadata
),
"intermediate_tensors"
:
intermediate_tensors
,
"lora_mask"
:
lora_mask
,
"virtual_engine"
:
model_input
.
virtual_engine
,
**
(
model_input
.
multi_modal_kwargs
or
{}),
}
if
htorch
.
utils
.
internal
.
is_lazy
():
execute_model_kwargs
.
update
({
"bypass_hpu_graphs"
:
not
use_graphs
})
htorch
.
core
.
mark_step
()
if
self
.
is_driver_worker
:
model_event_name
=
(
"model_"
f
"
{
'prompt'
if
is_prompt
else
'decode'
}
_"
f
"bs
{
batch_size
}
_"
f
"seq
{
seq_len
}
_"
f
"graphs
{
'T'
if
use_graphs
else
'F'
}
"
)
if
model_input
.
async_callback
is
not
None
:
model_input
.
async_callback
()
# Sample the next token.
with
self
.
profiler
.
record_event
(
'internal'
,
(
'sample_'
f
'
{
"prompt"
if
is_prompt
else
"decode"
}
_'
f
'bs
{
batch_size
}
_'
f
'seq
{
seq_len
}
'
)):
output
=
self
.
model
.
sample
(
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
)
if
num_steps
>
1
:
output
=
output
.
sampled_token_ids
self
.
cached_step_outputs
.
append
(
output
.
detach
().
clone
())
htorch
.
core
.
mark_step
()
if
i
<
num_steps
-
1
:
if
i
==
0
:
if
model_input
.
async_callback
is
not
None
:
ctx
=
model_input
.
async_callback
.
keywords
[
# type: ignore
"ctx"
]
seq_group_metadata_list
=
\
ctx
.
seq_group_metadata_list
elif
seqs
is
not
None
:
seq_group_metadata_list
=
seqs
else
:
raise
RuntimeError
(
"seq_group_metadata_list is uninitialized"
)
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
# Skip empty steps
seq_group_metadata
.
state
.
current_step
+=
(
num_steps
-
2
)
# Cache the original output token ids
cache_orig_output_tokens_len
.
append
({})
for
j
,
data
in
seq_group_metadata
.
seq_data
.
items
():
cache_orig_output_tokens_len
[
i
][
j
]
=
\
len
(
data
.
output_token_ids
)
for
seq_group_metadata
in
seq_group_metadata_list
:
for
data
in
seq_group_metadata
.
seq_data
.
values
():
max_output_len
=
sampling_metadata
.
seq_groups
[
0
].
sampling_params
.
max_tokens
if
len
(
data
.
output_token_ids
)
<
max_output_len
-
1
:
# add a place holder for prepare_decode
# arbitrary value, this could be any token
dummy_token
=
(
540
,
)
data
.
output_token_ids
+=
(
dummy_token
)
else
:
broadcast_tensor_dict
({
'early_exit'
:
True
},
src
=
0
)
if
num_steps
==
1
:
return
[
output
]
else
:
try_revert_dummy_output_tokens
()
return
[]
result
=
self
.
_prepare_decode
(
seq_group_metadata_list
,
output
=
output
)
execute_model_kwargs
.
update
({
"input_ids"
:
result
.
input_tokens
,
"positions"
:
result
.
input_positions
,
"attn_metadata"
:
self
.
trim_attn_metadata
(
result
.
attn_metadata
)
})
model_kwargs_broadcast_data
=
{
"input_ids"
:
result
.
input_tokens
,
"positions"
:
result
.
input_positions
,
"attn_metadata"
:
vars
(
result
.
attn_metadata
)
}
broadcast_tensor_dict
(
model_kwargs_broadcast_data
,
src
=
0
)
else
:
try_revert_dummy_output_tokens
()
if
self
.
is_driver_worker
and
self
.
profiler
.
enabled
:
# Stop recording 'execute_model' event
self
.
profiler
.
end
()
event_end
=
self
.
profiler
.
get_timestamp_us
()
counters
=
self
.
profiler_counter_helper
.
get_counter_dict
(
cache_config
=
self
.
cache_config
,
duration
=
event_end
-
self
.
event_start
,
seq_len
=
seq_len
,
batch_size_padded
=
batch_size_padded
,
real_batch_size
=
real_batch_size
,
is_prompt
=
is_prompt
)
self
.
profiler
.
record_counter
(
self
.
event_start
,
counters
)
if
num_steps
==
1
:
return
[
output
]
if
self
.
is_driver_worker
else
[]
else
:
return
[]
return
output
if
type
(
output
)
is
list
else
[
output
]
def
_decode_sampler_outputs
(
self
,
model_input
):
use_async_out_proc
=
model_input
.
async_callback
is
not
None
sampler_outputs
=
[]
num_outputs
=
len
(
self
.
cached_step_outputs
)
for
i
in
range
(
num_outputs
):
next_token_ids
=
self
.
cached_step_outputs
.
pop
(
0
)
next_token_ids
=
next_token_ids
.
cpu
().
tolist
()
sampler_output
=
self
.
_make_decode_output
(
next_token_ids
,
model_input
.
sampling_metadata
.
seq_groups
)
sampler_outputs
.
append
(
sampler_output
)
if
i
<
num_outputs
-
1
and
use_async_out_proc
:
assert
model_input
.
async_callback
is
not
None
ctx
=
model_input
.
async_callback
.
keywords
[
# type: ignore
"ctx"
]
ctx
.
append_output
(
outputs
=
[
sampler_output
],
seq_group_metadata_list
=
ctx
.
seq_group_metadata_list
,
scheduler_outputs
=
ctx
.
scheduler_outputs
,
is_async
=
False
,
is_last_step
=
False
,
is_first_step_output
=
False
)
model_input
.
async_callback
()
if
use_async_out_proc
:
return
[
sampler_outputs
[
-
1
]]
else
:
model_event_name
=
'model_executable'
with
self
.
profiler
.
record_event
(
'internal'
,
model_event_name
):
hidden_states
=
self
.
model
.
forward
(
**
execute_model_kwargs
,
selected_token_indices
=
sampling_metadata
.
selected_token_indices
)
return
sampler_outputs
if
self
.
lora_config
:
LoraMask
.
setLoraMask
(
lora_logits_mask
.
index_select
(
0
,
sampling_metadata
.
selected_token_indices
))
# Compute the logits.
with
self
.
profiler
.
record_event
(
'internal'
,
(
'compute_logits_'
f
'
{
"prompt"
if
is_prompt
else
"decode"
}
_bs'
f
'
{
batch_size
}
_'
f
'seq
{
seq_len
}
'
)):
sampling_metadata
.
selected_token_indices
=
None
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
htorch
.
core
.
mark_step
()
# Only perform sampling in the driver worker.
if
not
self
.
is_driver_worker
:
return
[]
if
model_input
.
async_callback
is
not
None
:
model_input
.
async_callback
()
# Sample the next token.
with
self
.
profiler
.
record_event
(
'internal'
,
(
'sample_'
f
'
{
"prompt"
if
is_prompt
else
"decode"
}
_'
f
'bs
{
batch_size
}
_'
f
'seq
{
seq_len
}
'
)):
output
=
self
.
model
.
sample
(
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
)
output
.
outputs
=
output
.
outputs
[:
real_batch_size
]
htorch
.
core
.
mark_step
()
if
self
.
is_driver_worker
and
self
.
profiler
.
enabled
:
# Stop recording 'execute_model' event
self
.
profiler
.
end
()
event_end
=
self
.
profiler
.
get_timestamp_us
()
counters
=
self
.
profiler_counter_helper
.
get_counter_dict
(
cache_config
=
self
.
cache_config
,
duration
=
event_end
-
self
.
event_start
,
seq_len
=
seq_len
,
batch_size_padded
=
batch_size_padded
,
real_batch_size
=
real_batch_size
,
is_prompt
=
is_prompt
)
self
.
profiler
.
record_counter
(
self
.
event_start
,
counters
)
return
[
output
]
def
_make_decode_output
(
self
,
next_token_ids
:
List
[
List
[
int
]],
seq_groups
:
List
[
SequenceGroupToSample
],
)
->
SamplerOutput
:
zero_logprob
=
Logprob
(
0.0
)
sampler_outputs
=
[]
batch_idx
=
0
for
seq_group
in
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
seq_outputs
=
[]
for
seq_id
in
seq_ids
:
next_token_id
=
next_token_ids
[
batch_idx
][
0
]
seq_outputs
.
append
(
SequenceOutput
(
seq_id
,
next_token_id
,
{
next_token_id
:
zero_logprob
}))
batch_idx
+=
1
sampler_outputs
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
None
))
return
SamplerOutput
(
sampler_outputs
)
def
shutdown_inc
(
self
):
can_finalize_inc
=
False
...
...
vllm/worker/model_runner.py
View file @
31330101
# SPDX-License-Identifier: Apache-2.0
import
sys
import
dataclasses
import
gc
import
inspect
...
...
@@ -15,7 +16,7 @@ import numpy as np
import
torch
import
torch.distributed
import
torch.nn
as
nn
from
tqdm
import
tqdm
from
tqdm
.auto
import
tqdm
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
...
...
@@ -1108,6 +1109,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
if
hasattr
(
self
,
"_builder_cls"
):
# multi-step model runner does not have `_builder_cls`
self
.
builder
=
self
.
_builder_cls
(
weakref
.
proxy
(
self
))
self
.
enforce_eager_bs_threshould
=
sys
.
maxsize
if
envs
.
VLLM_ENFORCE_EAGER_BS_THRESHOLD
is
not
None
and
envs
.
VLLM_ENFORCE_EAGER_BS_THRESHOLD
>
0
:
self
.
enforce_eager_bs_threshould
=
envs
.
VLLM_ENFORCE_EAGER_BS_THRESHOLD
def
load_model
(
self
)
->
None
:
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
...
...
@@ -1717,7 +1722,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
# virtual engines share the same kv cache.
virtual_engine
=
model_input
.
virtual_engine
previous_hidden_states
=
kwargs
.
get
(
"previous_hidden_states"
)
if
prefill_meta
is
None
and
decode_meta
.
use_cuda_graph
:
if
prefill_meta
is
None
and
decode_meta
.
use_cuda_graph
and
\
model_input
.
input_tokens
.
shape
[
0
]
<=
self
.
enforce_eager_bs_threshould
:
assert
model_input
.
input_tokens
is
not
None
graph_batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
model_executable
=
self
.
graph_runners
[
virtual_engine
][
...
...
vllm/worker/multi_step_hpu_worker.py
0 → 100644
View file @
31330101
# SPDX-License-Identifier: Apache-2.0
###############################################################################
# Copyright (C) 2025 Habana Labs, Ltd. an Intel Company
###############################################################################
import
dataclasses
from
typing
import
Dict
,
Optional
,
Tuple
import
torch
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.worker.hpu_model_runner
import
ModelInputForHPU
from
vllm.worker.hpu_worker
import
HPUWorker
from
vllm.worker.worker_base
import
WorkerInput
class
MultiStepHPUWorker
(
HPUWorker
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
cached_model_input
:
Optional
[
ModelInputForHPU
]
=
None
def
_get_driver_input_and_broadcast
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
Tuple
[
ModelInputForHPU
,
WorkerInput
,
Dict
[
str
,
torch
.
Tensor
]]:
"""
Get the driver input and broadcast it to other workers.
"""
assert
self
.
is_driver_worker
assert
execute_model_req
.
virtual_engine
==
0
is_first_multi_step
=
execute_model_req
.
is_first_multi_step
is_last_step
=
execute_model_req
.
is_last_step
if
is_first_multi_step
:
# on first step we prepare the worker input and model input normally
worker_input
:
WorkerInput
=
self
.
prepare_worker_input
(
execute_model_req
=
execute_model_req
)
worker_input
=
dataclasses
.
replace
(
worker_input
,
num_steps
=
execute_model_req
.
num_lookahead_slots
+
1
)
model_input
:
ModelInputForHPU
=
(
self
.
model_runner
.
prepare_model_input
(
execute_model_req
.
seq_group_metadata_list
,
execute_model_req
.
virtual_engine
,
execute_model_req
.
finished_requests_ids
))
if
execute_model_req
.
async_callback
:
model_input
=
dataclasses
.
replace
(
model_input
,
async_callback
=
execute_model_req
.
async_callback
)
else
:
# on subsequent steps we reuse the worker input and model input
assert
self
.
cached_model_input
is
not
None
model_input
=
self
.
cached_model_input
worker_input
=
WorkerInput
()
model_input
=
dataclasses
.
replace
(
model_input
,
is_first_multi_step
=
is_first_multi_step
,
is_last_step
=
is_last_step
)
if
self
.
do_metadata_broadcast
:
if
is_first_multi_step
:
broadcast_data
=
worker_input
.
as_broadcastable_tensor_dict
()
broadcast_data
.
update
(
model_input
.
as_broadcastable_tensor_dict
())
broadcast_tensor_dict
(
broadcast_data
,
src
=
0
)
else
:
broadcast_data
=
{
"is_first_multi_step"
:
is_first_multi_step
,
"is_last_step"
:
is_last_step
,
}
broadcast_tensor_dict
(
broadcast_data
,
src
=
0
)
# Returning empty dict here to keep this compatible with
# `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
return
model_input
,
worker_input
,
{}
def
prepare_input
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
,
)
->
Optional
[
Tuple
[
ModelInputForHPU
,
WorkerInput
,
Dict
[
str
,
torch
.
Tensor
]]]:
if
self
.
is_driver_worker
:
if
execute_model_req
is
None
:
if
self
.
do_metadata_broadcast
:
# This signals that there's no more requests to process for
# now. All workers are running infinite loop with
# broadcast_tensor_dict, and it stops the loop when the
# driver broadcasts an empty input. Send an empty input to
# notify all other workers to stop their execution loop.
broadcast_tensor_dict
({},
src
=
0
)
return
None
model_input
,
worker_input
,
_
=
self
.
_get_driver_input_and_broadcast
(
execute_model_req
)
if
model_input
.
is_first_multi_step
:
self
.
cached_model_input
=
model_input
return
model_input
,
worker_input
,
{}
else
:
broadcast_data
=
broadcast_tensor_dict
(
src
=
0
)
if
not
broadcast_data
:
return
None
if
len
(
broadcast_data
)
==
2
:
assert
self
.
cached_model_input
is
not
None
self
.
cached_model_input
=
dataclasses
.
replace
(
self
.
cached_model_input
,
is_first_multi_step
=
broadcast_data
[
"is_first_multi_step"
],
is_last_step
=
broadcast_data
[
"is_last_step"
])
empty_worker_input
=
WorkerInput
()
return
self
.
cached_model_input
,
empty_worker_input
,
{}
worker_input
=
WorkerInput
.
from_broadcasted_tensor_dict
(
broadcast_data
)
model_input
=
(
self
.
model_runner
.
make_model_input_from_broadcasted_tensor_dict
(
broadcast_data
))
self
.
cached_model_input
=
model_input
return
model_input
,
worker_input
,
{}
Prev
1
…
14
15
16
17
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment