Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
31330101
Commit
31330101
authored
Apr 16, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.4' into v0.8.4-dev
parents
e8933c34
dc1b4a6f
Changes
346
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
523 additions
and
119 deletions
+523
-119
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+9
-3
vllm/v1/worker/utils.py
vllm/v1/worker/utils.py
+45
-0
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+25
-1
vllm/worker/hpu_model_runner.py
vllm/worker/hpu_model_runner.py
+314
-113
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+8
-2
vllm/worker/multi_step_hpu_worker.py
vllm/worker/multi_step_hpu_worker.py
+122
-0
No files found.
vllm/v1/worker/tpu_worker.py
View file @
31330101
...
@@ -157,13 +157,19 @@ class TPUWorker:
...
@@ -157,13 +157,19 @@ class TPUWorker:
runner_kv_caches
)
runner_kv_caches
)
self
.
model_runner
.
_dummy_run
(
self
.
model_runner
.
_dummy_run
(
runner_kv_caches
,
self
.
scheduler_config
.
max_num_batched_tokens
)
num_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
,
)
# Synchronize before measuring the memory usage.
# Synchronize before measuring the memory usage.
xm
.
wait_device_ops
()
xm
.
wait_device_ops
()
# During the profiling run, the model runs without KV cache. After
# the profiling run, the model always runs with KV cache. Here we clear
# the dynamo cache and cached bytecode to ensure the model always has
# one compiled bytecode. Having one FX graph/cached bytecode per
# compiled model is required for `support_torch_compile` decorator to
# skip dynamo guard.
self
.
model_runner
.
reset_dynamo_cache
()
# Get the maximum amount of memory used by the model weights and
# Get the maximum amount of memory used by the model weights and
# intermediate activations.
# intermediate activations.
m
=
xm
.
get_memory_info
(
self
.
device
)
m
=
xm
.
get_memory_info
(
self
.
device
)
...
...
vllm/v1/worker/utils.py
View file @
31330101
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
import
torch
import
torch
...
@@ -27,3 +29,46 @@ def sanity_check_mm_encoder_outputs(
...
@@ -27,3 +29,46 @@ def sanity_check_mm_encoder_outputs(
f
"but got tensors with shapes
{
[
e
.
shape
for
e
in
mm_embeddings
]
}
"
f
"but got tensors with shapes
{
[
e
.
shape
for
e
in
mm_embeddings
]
}
"
"instead. This is most likely due to incorrect implementation "
"instead. This is most likely due to incorrect implementation "
"of the model's `get_multimodal_embeddings` method."
)
"of the model's `get_multimodal_embeddings` method."
)
def scatter_mm_placeholders(
    embeds: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Scatter the multimodal embeddings into a contiguous tensor that represents
    the placeholder tokens.

    See :class:`vllm.multimodal.processing.PromptUpdateDetails.is_embed`.

    Args:
        embeds: The multimodal embeddings.
            Shape: `(num_embeds, embed_dim)`
        is_embed: A boolean mask indicating which positions in the placeholder
            tokens need to be filled with multimodal embeddings, or `None`
            when every placeholder position is an embedding.
            Shape: `(num_placeholders,)`, with `num_embeds` entries set to
            `True`.

    Returns:
        `embeds` itself when `is_embed` is `None`; otherwise a new tensor of
        shape `(num_placeholders, embed_dim)` whose rows selected by
        `is_embed` hold `embeds` (in order) and whose remaining rows are NaN.
    """
    if is_embed is None:
        return embeds

    # Rows that are not selected by the mask stay NaN, so an accidental read
    # of a non-embedding position is easy to spot downstream.
    placeholders = embeds.new_full(
        (is_embed.shape[0], embeds.shape[-1]),
        fill_value=torch.nan,
    )
    placeholders[is_embed] = embeds
    return placeholders
def gather_mm_placeholders(
    placeholders: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Reconstructs the embeddings from the placeholder tokens.

    This is the inverse operation of :func:`scatter_mm_placeholders`.

    Args:
        placeholders: The tensor holding one row per placeholder token.
        is_embed: A boolean mask selecting the rows of `placeholders` that
            hold multimodal embeddings, or `None` when every row does.

    Returns:
        `placeholders` itself when `is_embed` is `None`; otherwise the rows
        of `placeholders` selected by `is_embed`.
    """
    if is_embed is None:
        return placeholders

    return placeholders[is_embed]
vllm/worker/enc_dec_model_runner.py
View file @
31330101
...
@@ -16,6 +16,7 @@ from vllm.config import VllmConfig
...
@@ -16,6 +16,7 @@ from vllm.config import VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalKwargs
,
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalKwargs
,
...
@@ -34,6 +35,7 @@ from vllm.worker.model_runner_base import (
...
@@ -34,6 +35,7 @@ from vllm.worker.model_runner_base import (
from
vllm.worker.utils
import
assert_enc_dec_mr_supported_scenario
from
vllm.worker.utils
import
assert_enc_dec_mr_supported_scenario
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
LORA_WARMUP_RANK
=
8
@
dataclasses
.
dataclass
(
frozen
=
True
)
@
dataclasses
.
dataclass
(
frozen
=
True
)
...
@@ -160,7 +162,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -160,7 +162,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
if
num_steps
>
1
:
if
num_steps
>
1
:
raise
ValueError
(
"num_steps > 1 is not supported in "
raise
ValueError
(
"num_steps > 1 is not supported in "
"EncoderDecoderModelRunner"
)
"EncoderDecoderModelRunner"
)
if
self
.
lora_config
:
assert
model_input
.
lora_requests
is
not
None
assert
model_input
.
lora_mapping
is
not
None
self
.
set_active_loras
(
model_input
.
lora_requests
,
model_input
.
lora_mapping
)
if
(
model_input
.
attn_metadata
is
not
None
if
(
model_input
.
attn_metadata
is
not
None
and
model_input
.
attn_metadata
.
prefill_metadata
is
None
and
model_input
.
attn_metadata
.
prefill_metadata
is
None
and
model_input
.
attn_metadata
.
decode_metadata
.
use_cuda_graph
):
and
model_input
.
attn_metadata
.
decode_metadata
.
use_cuda_graph
):
...
@@ -268,6 +274,22 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -268,6 +274,22 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
max_num_batched_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
max_num_batched_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
# This represents the maximum number of different requests
# that will have unique loras, and therefore the max amount of
# memory consumption. Create dummy lora request copies from the
# lora request passed in, which contains a lora from the lora
# warmup path.
dummy_lora_requests
:
List
[
LoRARequest
]
=
[]
dummy_lora_requests_per_seq
:
List
[
LoRARequest
]
=
[]
if
self
.
lora_config
:
dummy_lora_requests
=
self
.
_add_dummy_loras
(
self
.
lora_config
.
max_loras
)
assert
len
(
dummy_lora_requests
)
==
self
.
lora_config
.
max_loras
dummy_lora_requests_per_seq
=
[
dummy_lora_requests
[
idx
%
len
(
dummy_lora_requests
)]
for
idx
in
range
(
max_num_seqs
)
]
# Profile memory usage with max_num_sequences sequences and the total
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
# number of tokens equal to max_num_batched_tokens.
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
...
@@ -315,6 +337,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -315,6 +337,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
block_tables
=
None
,
block_tables
=
None
,
encoder_seq_data
=
encoder_dummy_data
.
seq_data
,
encoder_seq_data
=
encoder_dummy_data
.
seq_data
,
cross_block_table
=
None
,
cross_block_table
=
None
,
lora_request
=
dummy_lora_requests_per_seq
[
group_id
]
if
dummy_lora_requests_per_seq
else
None
,
multi_modal_data
=
decoder_dummy_data
.
multi_modal_data
multi_modal_data
=
decoder_dummy_data
.
multi_modal_data
or
encoder_dummy_data
.
multi_modal_data
,
or
encoder_dummy_data
.
multi_modal_data
,
multi_modal_placeholders
=
decoder_dummy_data
.
multi_modal_placeholders
=
decoder_dummy_data
.
...
...
vllm/worker/hpu_model_runner.py
View file @
31330101
...
@@ -32,6 +32,7 @@ from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
...
@@ -32,6 +32,7 @@ from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.config
import
DeviceConfig
,
VllmConfig
from
vllm.config
import
DeviceConfig
,
VllmConfig
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.distributed.parallel_state
import
get_world_group
from
vllm.distributed.parallel_state
import
get_world_group
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -44,11 +45,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput
...
@@ -44,11 +45,13 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.sampling_metadata
import
SequenceGroupToSample
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
MultiModalKwargs
)
MultiModalKwargs
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
IntermediateTensors
,
SequenceData
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
SequenceGroupMetadata
)
Logprob
,
SequenceData
,
SequenceGroupMetadata
,
SequenceOutput
)
from
vllm.utils
import
(
bind_kv_cache
,
is_pin_memory_available
,
from
vllm.utils
import
(
bind_kv_cache
,
is_pin_memory_available
,
make_tensor_with_pad
)
make_tensor_with_pad
)
from
vllm.worker.model_runner_base
import
(
from
vllm.worker.model_runner_base
import
(
...
@@ -100,7 +103,10 @@ def subtuple(obj: object,
...
@@ -100,7 +103,10 @@ def subtuple(obj: object,
if
to_override
is
None
:
if
to_override
is
None
:
to_override
=
{}
to_override
=
{}
fields
=
set
(
to_copy
)
|
set
(
to_override
.
keys
())
fields
=
set
(
to_copy
)
|
set
(
to_override
.
keys
())
values
=
{
f
:
to_override
.
get
(
f
,
getattr
(
obj
,
f
))
for
f
in
fields
}
if
type
(
obj
)
is
dict
:
values
=
{
key
:
obj
[
key
]
for
key
in
fields
if
key
in
obj
}
else
:
values
=
{
f
:
to_override
.
get
(
f
,
getattr
(
obj
,
f
))
for
f
in
fields
}
if
typename
not
in
_TYPE_CACHE
:
if
typename
not
in
_TYPE_CACHE
:
_TYPE_CACHE
[
typename
]
=
collections
.
namedtuple
(
typename
,
_TYPE_CACHE
[
typename
]
=
collections
.
namedtuple
(
typename
,
' '
.
join
(
fields
))
' '
.
join
(
fields
))
...
@@ -533,6 +539,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
...
@@ -533,6 +539,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
virtual_engine
:
int
=
0
virtual_engine
:
int
=
0
lora_ids
:
Optional
[
List
[
int
]]
=
None
lora_ids
:
Optional
[
List
[
int
]]
=
None
async_callback
:
Optional
[
Callable
]
=
None
async_callback
:
Optional
[
Callable
]
=
None
is_first_multi_step
:
bool
=
True
is_last_step
:
bool
=
True
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
tensor_dict
=
{
tensor_dict
=
{
...
@@ -545,6 +553,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
...
@@ -545,6 +553,8 @@ class ModelInputForHPU(ModelRunnerInputBase):
"batch_size_padded"
:
self
.
batch_size_padded
,
"batch_size_padded"
:
self
.
batch_size_padded
,
"virtual_engine"
:
self
.
virtual_engine
,
"virtual_engine"
:
self
.
virtual_engine
,
"lora_ids"
:
self
.
lora_ids
,
"lora_ids"
:
self
.
lora_ids
,
"is_first_multi_step"
:
self
.
is_first_multi_step
,
"is_last_step"
:
self
.
is_last_step
,
}
}
_add_attn_metadata_broadcastable_dict
(
tensor_dict
,
self
.
attn_metadata
)
_add_attn_metadata_broadcastable_dict
(
tensor_dict
,
self
.
attn_metadata
)
return
tensor_dict
return
tensor_dict
...
@@ -656,6 +666,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -656,6 +666,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
self
.
_set_gc_threshold
()
self
.
_set_gc_threshold
()
self
.
use_contiguous_pa
=
envs
.
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
self
.
use_contiguous_pa
=
envs
.
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
# For multi-step scheduling
self
.
cached_step_outputs
:
List
[
torch
.
Tensor
]
=
[]
def
_set_gc_threshold
(
self
)
->
None
:
def
_set_gc_threshold
(
self
)
->
None
:
# Read https://docs.python.org/3/library/gc.html#gc.set_threshold
# Read https://docs.python.org/3/library/gc.html#gc.set_threshold
# for comprehensive description of gc generations.
# for comprehensive description of gc generations.
...
@@ -1005,6 +1018,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -1005,6 +1018,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
def
_prepare_decode
(
def
_prepare_decode
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
output
=
None
,
)
->
PrepareDecodeMetadata
:
)
->
PrepareDecodeMetadata
:
input_tokens
:
List
[
List
[
int
]]
=
[]
input_tokens
:
List
[
List
[
int
]]
=
[]
input_positions
:
List
[
List
[
int
]]
=
[]
input_positions
:
List
[
List
[
int
]]
=
[]
...
@@ -1035,8 +1049,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -1035,8 +1049,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
for
seq_id
in
seq_ids
:
for
seq_id
in
seq_ids
:
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
generation_token
=
seq_data
.
get_last_token_id
()
if
output
is
None
:
input_tokens
.
append
([
generation_token
])
generation_token
=
seq_data
.
get_last_token_id
()
input_tokens
.
append
([
generation_token
])
seq_len
=
seq_data
.
get_len
()
seq_len
=
seq_data
.
get_len
()
position
=
seq_len
-
1
position
=
seq_len
-
1
...
@@ -1047,6 +1062,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -1047,6 +1062,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
seq_lens
.
append
(
seq_len
)
seq_lens
.
append
(
seq_len
)
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
num_fully_occupied_blocks
=
position
//
self
.
block_size
block_table
=
block_table
[:
num_fully_occupied_blocks
+
1
]
if
len
(
block_table
)
==
0
:
if
len
(
block_table
)
==
0
:
block_number
=
_PAD_BLOCK_ID
block_number
=
_PAD_BLOCK_ID
else
:
else
:
...
@@ -1066,9 +1084,14 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -1066,9 +1084,14 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
block_table
=
block_table
[
-
sliding_window_blocks
:]
block_table
=
block_table
[
-
sliding_window_blocks
:]
block_tables
.
append
(
block_table
)
block_tables
.
append
(
block_table
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
if
output
is
None
:
dtype
=
torch
.
long
,
input_tokens
=
torch
.
tensor
(
input_tokens
,
device
=
self
.
device
)
dtype
=
torch
.
long
,
device
=
self
.
device
)
else
:
real_batch_size
=
len
(
seq_group_metadata_list
)
input_tokens
=
output
[:
real_batch_size
]
input_positions
=
torch
.
tensor
(
input_positions
,
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
device
=
self
.
device
)
...
@@ -1462,7 +1485,27 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -1462,7 +1485,27 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
profiler
.
start
()
profiler
.
start
()
for
_
in
range
(
times
):
for
_
in
range
(
times
):
inputs
=
self
.
prepare_model_input
(
seqs
)
inputs
=
self
.
prepare_model_input
(
seqs
)
self
.
execute_model
(
inputs
,
None
,
warmup_mode
=
True
)
is_single_step
=
\
self
.
vllm_config
.
scheduler_config
.
num_scheduler_steps
==
1
if
is_prompt
or
is_single_step
:
self
.
execute_model
(
inputs
,
None
,
warmup_mode
=
True
)
else
:
# decode with multi-step
inputs
=
dataclasses
.
replace
(
inputs
,
is_first_multi_step
=
True
,
is_last_step
=
False
)
self
.
execute_model
(
inputs
,
None
,
warmup_mode
=
True
,
num_steps
=
2
,
seqs
=
seqs
)
inputs
=
dataclasses
.
replace
(
inputs
,
is_first_multi_step
=
False
,
is_last_step
=
True
)
self
.
execute_model
(
inputs
,
None
,
warmup_mode
=
True
,
num_steps
=
2
,
seqs
=
seqs
)
torch
.
hpu
.
synchronize
()
torch
.
hpu
.
synchronize
()
if
profiler
:
if
profiler
:
profiler
.
step
()
profiler
.
step
()
...
@@ -1985,115 +2028,273 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
...
@@ -1985,115 +2028,273 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
num_steps
:
int
=
1
,
warmup_mode
=
False
,
warmup_mode
=
False
,
seqs
=
None
,
)
->
Optional
[
Union
[
List
[
SamplerOutput
],
IntermediateTensors
]]:
)
->
Optional
[
Union
[
List
[
SamplerOutput
],
IntermediateTensors
]]:
if
num_steps
>
1
:
if
not
model_input
.
is_first_multi_step
:
raise
ValueError
(
if
not
model_input
.
is_last_step
:
"num_steps > 1 is not supported in HPUModelRunner"
)
# not first or last multi-step
return
[]
# last multi-step
output
=
self
.
_decode_sampler_outputs
(
model_input
)
if
self
.
is_driver_worker
else
[]
torch
.
hpu
.
synchronize
()
if
model_input
.
is_first_multi_step
:
# first multi-step
if
self
.
lora_config
:
assert
model_input
.
lora_requests
is
not
None
assert
model_input
.
lora_mapping
is
not
None
self
.
set_active_loras
(
model_input
.
lora_requests
,
model_input
.
lora_mapping
)
input_tokens
=
model_input
.
input_tokens
input_positions
=
model_input
.
input_positions
attn_metadata
=
model_input
.
attn_metadata
sampling_metadata
=
model_input
.
sampling_metadata
real_batch_size
=
model_input
.
real_batch_size
batch_size_padded
=
model_input
.
batch_size_padded
assert
input_tokens
is
not
None
assert
input_positions
is
not
None
assert
sampling_metadata
is
not
None
assert
attn_metadata
is
not
None
is_prompt
=
attn_metadata
.
is_prompt
assert
is_prompt
is
not
None
batch_size
=
input_tokens
.
size
(
0
)
seq_len
=
self
.
_seq_len
(
attn_metadata
)
use_graphs
=
self
.
_use_graphs
(
batch_size
,
seq_len
,
is_prompt
)
self
.
_check_config
(
batch_size
,
seq_len
,
is_prompt
,
warmup_mode
)
lora_mask
:
torch
.
Tensor
=
None
lora_logits_mask
:
torch
.
Tensor
=
None
if
self
.
lora_config
:
assert
model_input
.
lora_ids
is
not
None
lora_mask
,
lora_logits_mask
=
self
.
create_lora_mask
(
input_tokens
,
model_input
.
lora_ids
,
attn_metadata
.
is_prompt
)
execute_model_kwargs
=
{
"input_ids"
:
input_tokens
,
"positions"
:
input_positions
,
"attn_metadata"
:
self
.
trim_attn_metadata
(
attn_metadata
),
"intermediate_tensors"
:
intermediate_tensors
,
"lora_mask"
:
lora_mask
,
"virtual_engine"
:
model_input
.
virtual_engine
,
**
(
model_input
.
multi_modal_kwargs
or
{}),
}
if
htorch
.
utils
.
internal
.
is_lazy
():
execute_model_kwargs
.
update
(
{
"bypass_hpu_graphs"
:
not
use_graphs
})
if
self
.
lora_config
:
htorch
.
core
.
mark_step
()
assert
model_input
.
lora_requests
is
not
None
if
self
.
is_driver_worker
:
assert
model_input
.
lora_mapping
is
not
None
model_event_name
=
(
"model_"
self
.
set_active_loras
(
model_input
.
lora_requests
,
f
"
{
'prompt'
if
is_prompt
else
'decode'
}
_"
model_input
.
lora_mapping
)
f
"bs
{
batch_size
}
_"
input_tokens
=
model_input
.
input_tokens
f
"seq
{
seq_len
}
_"
input_positions
=
model_input
.
input_positions
f
"graphs
{
'T'
if
use_graphs
else
'F'
}
"
)
attn_metadata
=
model_input
.
attn_metadata
else
:
sampling_metadata
=
model_input
.
sampling_metadata
model_event_name
=
'model_executable'
real_batch_size
=
model_input
.
real_batch_size
if
num_steps
>
1
:
batch_size_padded
=
model_input
.
batch_size_padded
# in case of multi-step scheduling
assert
input_tokens
is
not
None
# we only want to pythonize in the last step
assert
input_positions
is
not
None
sampling_metadata
.
skip_sampler_cpu_output
=
True
assert
sampling_metadata
is
not
None
self
.
model
.
model
.
sampler
.
include_gpu_probs_tensor
=
True
assert
attn_metadata
is
not
None
cache_orig_output_tokens_len
:
List
[
Dict
]
=
[]
is_prompt
=
attn_metadata
.
is_prompt
assert
is_prompt
is
not
None
def
try_revert_dummy_output_tokens
():
batch_size
=
input_tokens
.
size
(
0
)
if
len
(
cache_orig_output_tokens_len
)
>
0
:
seq_len
=
self
.
_seq_len
(
attn_metadata
)
# Reuse the original output token ids length
use_graphs
=
self
.
_use_graphs
(
batch_size
,
seq_len
,
is_prompt
)
for
i
,
seq_group_metadata
in
enumerate
(
self
.
_check_config
(
batch_size
,
seq_len
,
is_prompt
,
warmup_mode
)
seq_group_metadata_list
):
for
j
,
data
in
seq_group_metadata
.
seq_data
.
items
():
orig_output_tokens_len
=
\
cache_orig_output_tokens_len
[
i
][
j
]
data
.
output_token_ids
=
\
data
.
output_token_ids
[:
orig_output_tokens_len
]
for
i
in
range
(
num_steps
):
if
i
!=
0
and
not
self
.
is_driver_worker
:
broadcast_data
=
broadcast_tensor_dict
(
src
=
0
)
if
'early_exit'
in
broadcast_data
and
broadcast_data
[
'early_exit'
]:
return
[
output
]
if
num_steps
==
1
else
[]
execute_model_kwargs
.
update
({
"input_ids"
:
broadcast_data
[
"input_ids"
],
"positions"
:
broadcast_data
[
"positions"
],
"attn_metadata"
:
self
.
trim_attn_metadata
(
broadcast_data
[
"attn_metadata"
])
})
with
self
.
profiler
.
record_event
(
'internal'
,
model_event_name
):
hidden_states
=
self
.
model
.
forward
(
**
execute_model_kwargs
,
selected_token_indices
=
sampling_metadata
.
selected_token_indices
)
if
self
.
lora_config
:
LoraMask
.
setLoraMask
(
lora_logits_mask
.
index_select
(
0
,
sampling_metadata
.
selected_token_indices
))
# Compute the logits.
with
self
.
profiler
.
record_event
(
'internal'
,
(
'compute_logits_'
f
'
{
"prompt"
if
is_prompt
else
"decode"
}
_bs'
f
'
{
batch_size
}
_'
f
'seq
{
seq_len
}
'
)):
if
num_steps
==
1
:
sampling_metadata
.
selected_token_indices
=
None
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
htorch
.
core
.
mark_step
()
# Only perform sampling in the driver worker.
if
not
self
.
is_driver_worker
:
continue
lora_mask
:
torch
.
Tensor
=
None
if
model_input
.
async_callback
is
not
None
:
lora_logits_mask
:
torch
.
Tensor
=
None
model_input
.
async_callback
()
if
self
.
lora_config
:
# Sample the next token.
assert
model_input
.
lora_ids
is
not
None
with
self
.
profiler
.
record_event
(
lora_mask
,
lora_logits_mask
=
self
.
create_lora_mask
(
'internal'
,
(
'sample_'
input_tokens
,
model_input
.
lora_ids
,
attn_metadata
.
is_prompt
)
f
'
{
"prompt"
if
is_prompt
else
"decode"
}
_'
f
'bs
{
batch_size
}
_'
execute_model_kwargs
=
{
f
'seq
{
seq_len
}
'
)):
"input_ids"
:
input_tokens
,
output
=
self
.
model
.
sample
(
"positions"
:
input_positions
,
logits
=
logits
,
"attn_metadata"
:
self
.
trim_attn_metadata
(
attn_metadata
),
sampling_metadata
=
sampling_metadata
,
"intermediate_tensors"
:
intermediate_tensors
,
)
"lora_mask"
:
lora_mask
,
if
num_steps
>
1
:
"virtual_engine"
:
model_input
.
virtual_engine
,
output
=
output
.
sampled_token_ids
**
(
model_input
.
multi_modal_kwargs
or
{}),
self
.
cached_step_outputs
.
append
(
}
output
.
detach
().
clone
())
if
htorch
.
utils
.
internal
.
is_lazy
():
htorch
.
core
.
mark_step
()
execute_model_kwargs
.
update
({
"bypass_hpu_graphs"
:
not
use_graphs
})
if
i
<
num_steps
-
1
:
if
i
==
0
:
htorch
.
core
.
mark_step
()
if
model_input
.
async_callback
is
not
None
:
if
self
.
is_driver_worker
:
ctx
=
model_input
.
async_callback
.
keywords
[
# type: ignore
model_event_name
=
(
"model_"
"ctx"
]
f
"
{
'prompt'
if
is_prompt
else
'decode'
}
_"
seq_group_metadata_list
=
\
f
"bs
{
batch_size
}
_"
ctx
.
seq_group_metadata_list
f
"seq
{
seq_len
}
_"
elif
seqs
is
not
None
:
f
"graphs
{
'T'
if
use_graphs
else
'F'
}
"
)
seq_group_metadata_list
=
seqs
else
:
raise
RuntimeError
(
"seq_group_metadata_list is uninitialized"
)
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
# Skip empty steps
seq_group_metadata
.
state
.
current_step
+=
(
num_steps
-
2
)
# Cache the original output token ids
cache_orig_output_tokens_len
.
append
({})
for
j
,
data
in
seq_group_metadata
.
seq_data
.
items
():
cache_orig_output_tokens_len
[
i
][
j
]
=
\
len
(
data
.
output_token_ids
)
for
seq_group_metadata
in
seq_group_metadata_list
:
for
data
in
seq_group_metadata
.
seq_data
.
values
():
max_output_len
=
sampling_metadata
.
seq_groups
[
0
].
sampling_params
.
max_tokens
if
len
(
data
.
output_token_ids
)
<
max_output_len
-
1
:
# add a place holder for prepare_decode
# arbitrary value, this could be any token
dummy_token
=
(
540
,
)
data
.
output_token_ids
+=
(
dummy_token
)
else
:
broadcast_tensor_dict
({
'early_exit'
:
True
},
src
=
0
)
if
num_steps
==
1
:
return
[
output
]
else
:
try_revert_dummy_output_tokens
()
return
[]
result
=
self
.
_prepare_decode
(
seq_group_metadata_list
,
output
=
output
)
execute_model_kwargs
.
update
({
"input_ids"
:
result
.
input_tokens
,
"positions"
:
result
.
input_positions
,
"attn_metadata"
:
self
.
trim_attn_metadata
(
result
.
attn_metadata
)
})
model_kwargs_broadcast_data
=
{
"input_ids"
:
result
.
input_tokens
,
"positions"
:
result
.
input_positions
,
"attn_metadata"
:
vars
(
result
.
attn_metadata
)
}
broadcast_tensor_dict
(
model_kwargs_broadcast_data
,
src
=
0
)
else
:
try_revert_dummy_output_tokens
()
if
self
.
is_driver_worker
and
self
.
profiler
.
enabled
:
# Stop recording 'execute_model' event
self
.
profiler
.
end
()
event_end
=
self
.
profiler
.
get_timestamp_us
()
counters
=
self
.
profiler_counter_helper
.
get_counter_dict
(
cache_config
=
self
.
cache_config
,
duration
=
event_end
-
self
.
event_start
,
seq_len
=
seq_len
,
batch_size_padded
=
batch_size_padded
,
real_batch_size
=
real_batch_size
,
is_prompt
=
is_prompt
)
self
.
profiler
.
record_counter
(
self
.
event_start
,
counters
)
if
num_steps
==
1
:
return
[
output
]
if
self
.
is_driver_worker
else
[]
else
:
return
[]
return
output
if
type
(
output
)
is
list
else
[
output
]
def
_decode_sampler_outputs
(
self
,
model_input
):
use_async_out_proc
=
model_input
.
async_callback
is
not
None
sampler_outputs
=
[]
num_outputs
=
len
(
self
.
cached_step_outputs
)
for
i
in
range
(
num_outputs
):
next_token_ids
=
self
.
cached_step_outputs
.
pop
(
0
)
next_token_ids
=
next_token_ids
.
cpu
().
tolist
()
sampler_output
=
self
.
_make_decode_output
(
next_token_ids
,
model_input
.
sampling_metadata
.
seq_groups
)
sampler_outputs
.
append
(
sampler_output
)
if
i
<
num_outputs
-
1
and
use_async_out_proc
:
assert
model_input
.
async_callback
is
not
None
ctx
=
model_input
.
async_callback
.
keywords
[
# type: ignore
"ctx"
]
ctx
.
append_output
(
outputs
=
[
sampler_output
],
seq_group_metadata_list
=
ctx
.
seq_group_metadata_list
,
scheduler_outputs
=
ctx
.
scheduler_outputs
,
is_async
=
False
,
is_last_step
=
False
,
is_first_step_output
=
False
)
model_input
.
async_callback
()
if
use_async_out_proc
:
return
[
sampler_outputs
[
-
1
]]
else
:
else
:
model_event_name
=
'model_executable'
return
sampler_outputs
with
self
.
profiler
.
record_event
(
'internal'
,
model_event_name
):
hidden_states
=
self
.
model
.
forward
(
**
execute_model_kwargs
,
selected_token_indices
=
sampling_metadata
.
selected_token_indices
)
if
self
.
lora_config
:
def
_make_decode_output
(
LoraMask
.
setLoraMask
(
self
,
lora_logits_mask
.
index_select
(
next_token_ids
:
List
[
List
[
int
]],
0
,
sampling_metadata
.
selected_token_indices
))
seq_groups
:
List
[
SequenceGroupToSample
],
)
->
SamplerOutput
:
# Compute the logits.
zero_logprob
=
Logprob
(
0.0
)
with
self
.
profiler
.
record_event
(
sampler_outputs
=
[]
'internal'
,
(
'compute_logits_'
batch_idx
=
0
f
'
{
"prompt"
if
is_prompt
else
"decode"
}
_bs'
for
seq_group
in
seq_groups
:
f
'
{
batch_size
}
_'
seq_ids
=
seq_group
.
seq_ids
f
'seq
{
seq_len
}
'
)):
seq_outputs
=
[]
sampling_metadata
.
selected_token_indices
=
None
for
seq_id
in
seq_ids
:
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
next_token_id
=
next_token_ids
[
batch_idx
][
0
]
sampling_metadata
)
seq_outputs
.
append
(
htorch
.
core
.
mark_step
()
SequenceOutput
(
seq_id
,
next_token_id
,
# Only perform sampling in the driver worker.
{
next_token_id
:
zero_logprob
}))
if
not
self
.
is_driver_worker
:
batch_idx
+=
1
return
[]
sampler_outputs
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
None
))
if
model_input
.
async_callback
is
not
None
:
return
SamplerOutput
(
sampler_outputs
)
model_input
.
async_callback
()
# Sample the next token.
with
self
.
profiler
.
record_event
(
'internal'
,
(
'sample_'
f
'
{
"prompt"
if
is_prompt
else
"decode"
}
_'
f
'bs
{
batch_size
}
_'
f
'seq
{
seq_len
}
'
)):
output
=
self
.
model
.
sample
(
logits
=
logits
,
sampling_metadata
=
sampling_metadata
,
)
output
.
outputs
=
output
.
outputs
[:
real_batch_size
]
htorch
.
core
.
mark_step
()
if
self
.
is_driver_worker
and
self
.
profiler
.
enabled
:
# Stop recording 'execute_model' event
self
.
profiler
.
end
()
event_end
=
self
.
profiler
.
get_timestamp_us
()
counters
=
self
.
profiler_counter_helper
.
get_counter_dict
(
cache_config
=
self
.
cache_config
,
duration
=
event_end
-
self
.
event_start
,
seq_len
=
seq_len
,
batch_size_padded
=
batch_size_padded
,
real_batch_size
=
real_batch_size
,
is_prompt
=
is_prompt
)
self
.
profiler
.
record_counter
(
self
.
event_start
,
counters
)
return
[
output
]
def
shutdown_inc
(
self
):
def
shutdown_inc
(
self
):
can_finalize_inc
=
False
can_finalize_inc
=
False
...
...
vllm/worker/model_runner.py
View file @
31330101
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
sys
import
dataclasses
import
dataclasses
import
gc
import
gc
import
inspect
import
inspect
...
@@ -15,7 +16,7 @@ import numpy as np
...
@@ -15,7 +16,7 @@ import numpy as np
import
torch
import
torch
import
torch.distributed
import
torch.distributed
import
torch.nn
as
nn
import
torch.nn
as
nn
from
tqdm
import
tqdm
from
tqdm
.auto
import
tqdm
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
...
@@ -1108,6 +1109,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1108,6 +1109,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
if
hasattr
(
self
,
"_builder_cls"
):
if
hasattr
(
self
,
"_builder_cls"
):
# multi-step model runner does not have `_builder_cls`
# multi-step model runner does not have `_builder_cls`
self
.
builder
=
self
.
_builder_cls
(
weakref
.
proxy
(
self
))
self
.
builder
=
self
.
_builder_cls
(
weakref
.
proxy
(
self
))
self
.
enforce_eager_bs_threshould
=
sys
.
maxsize
if
envs
.
VLLM_ENFORCE_EAGER_BS_THRESHOLD
is
not
None
and
envs
.
VLLM_ENFORCE_EAGER_BS_THRESHOLD
>
0
:
self
.
enforce_eager_bs_threshould
=
envs
.
VLLM_ENFORCE_EAGER_BS_THRESHOLD
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
...
@@ -1717,7 +1722,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1717,7 +1722,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
# virtual engines share the same kv cache.
# virtual engines share the same kv cache.
virtual_engine
=
model_input
.
virtual_engine
virtual_engine
=
model_input
.
virtual_engine
previous_hidden_states
=
kwargs
.
get
(
"previous_hidden_states"
)
previous_hidden_states
=
kwargs
.
get
(
"previous_hidden_states"
)
if
prefill_meta
is
None
and
decode_meta
.
use_cuda_graph
:
if
prefill_meta
is
None
and
decode_meta
.
use_cuda_graph
and
\
model_input
.
input_tokens
.
shape
[
0
]
<=
self
.
enforce_eager_bs_threshould
:
assert
model_input
.
input_tokens
is
not
None
assert
model_input
.
input_tokens
is
not
None
graph_batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
graph_batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
model_executable
=
self
.
graph_runners
[
virtual_engine
][
model_executable
=
self
.
graph_runners
[
virtual_engine
][
...
...
vllm/worker/multi_step_hpu_worker.py
0 → 100644
View file @
31330101
# SPDX-License-Identifier: Apache-2.0
###############################################################################
# Copyright (C) 2025 Habana Labs, Ltd. an Intel Company
###############################################################################
import
dataclasses
from
typing
import
Dict
,
Optional
,
Tuple
import
torch
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.worker.hpu_model_runner
import
ModelInputForHPU
from
vllm.worker.hpu_worker
import
HPUWorker
from
vllm.worker.worker_base
import
WorkerInput
class MultiStepHPUWorker(HPUWorker):
    """HPU worker that reuses one model input across the steps of a request.

    The model input prepared (driver) or received via broadcast (non-driver)
    on the first step is stored in ``cached_model_input``. On later steps only
    its ``is_first_multi_step`` / ``is_last_step`` flags are refreshed.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``HPUWorker`` and start with no cache."""
        super().__init__(*args, **kwargs)
        # Filled in by `prepare_input` once a first-step input is available.
        self.cached_model_input: Optional[ModelInputForHPU] = None

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[ModelInputForHPU, WorkerInput, Dict[str, torch.Tensor]]:
        """Build the driver-side inputs for one step and broadcast them.

        Args:
            execute_model_req: Request for the current step. Its
                ``virtual_engine`` must be 0.

        Returns:
            ``(model_input, worker_input, {})``. The trailing dict is always
            empty; it only keeps the return shape compatible with
            `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`.
        """
        assert self.is_driver_worker
        assert execute_model_req.virtual_engine == 0

        first_step = execute_model_req.is_first_multi_step
        last_step = execute_model_req.is_last_step

        if not first_step:
            # Later steps: reuse the cached model input and send an empty
            # worker input.
            assert self.cached_model_input is not None
            model_input = self.cached_model_input
            worker_input = WorkerInput()
        else:
            # First step: prepare worker input and model input normally.
            prepared = self.prepare_worker_input(
                execute_model_req=execute_model_req)
            worker_input = dataclasses.replace(
                prepared,
                num_steps=execute_model_req.num_lookahead_slots + 1,
            )
            model_input = self.model_runner.prepare_model_input(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.virtual_engine,
                execute_model_req.finished_requests_ids,
            )
            callback = execute_model_req.async_callback
            if callback:
                model_input = dataclasses.replace(model_input,
                                                  async_callback=callback)

        # Stamp the step flags of this request onto the model input.
        model_input = dataclasses.replace(
            model_input,
            is_first_multi_step=first_step,
            is_last_step=last_step,
        )

        if self.do_metadata_broadcast:
            if first_step:
                # Full payload: worker input fields plus model input fields.
                payload = worker_input.as_broadcastable_tensor_dict()
                payload.update(model_input.as_broadcastable_tensor_dict())
            else:
                # Later steps only need the two step flags; the receivers
                # apply them to their own cached model input.
                payload = {
                    "is_first_multi_step": first_step,
                    "is_last_step": last_step,
                }
            broadcast_tensor_dict(payload, src=0)

        return model_input, worker_input, {}

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[Tuple[ModelInputForHPU, WorkerInput, Dict[str,
                                                            torch.Tensor]]]:
        """Produce the inputs for the next step on this worker.

        On the driver the inputs come from `_get_driver_input_and_broadcast`;
        on other workers they are rebuilt from the driver's broadcast.

        Args:
            execute_model_req: Request to run. ``None`` on the driver means
                there is nothing more to process for now.

        Returns:
            ``(model_input, worker_input, {})``, or ``None`` when the driver
            has no request or a non-driver receives an empty broadcast.
        """
        if not self.is_driver_worker:
            received = broadcast_tensor_dict(src=0)
            if not received:
                # Empty broadcast: nothing to execute.
                return None

            if len(received) != 2:
                # Full payload: rebuild both inputs and cache the model input
                # for the following steps.
                worker_input = WorkerInput.from_broadcasted_tensor_dict(
                    received)
                rebuild = (self.model_runner.
                           make_model_input_from_broadcasted_tensor_dict)
                model_input = rebuild(received)
                self.cached_model_input = model_input
                return model_input, worker_input, {}

            # Two entries: the driver's later-step message that carries only
            # the step flags. Apply them to the cached model input.
            assert self.cached_model_input is not None
            self.cached_model_input = dataclasses.replace(
                self.cached_model_input,
                is_first_multi_step=received["is_first_multi_step"],
                is_last_step=received["is_last_step"],
            )
            return self.cached_model_input, WorkerInput(), {}

        if execute_model_req is None:
            if self.do_metadata_broadcast:
                # This signals that there are no more requests to process
                # for now. All workers run an infinite loop around
                # broadcast_tensor_dict and stop it when the driver
                # broadcasts an empty input, so send one to notify them.
                broadcast_tensor_dict({}, src=0)
            return None

        model_input, worker_input, _ = self._get_driver_input_and_broadcast(
            execute_model_req)
        if model_input.is_first_multi_step:
            # Keep the first-step input so later steps can reuse it.
            self.cached_model_input = model_input
        return model_input, worker_input, {}
Prev
1
…
14
15
16
17
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment