Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b8025f24
Commit
b8025f24
authored
Jan 23, 2026
by
zhuwenwen
Browse files
remove unused code
parent
3b2aefb1
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
2 additions
and
3730 deletions
+2
-3730
CMakeLists.txt
CMakeLists.txt
+1
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+1
-9
vllm/worker/cpu_enc_dec_model_runner.py
vllm/worker/cpu_enc_dec_model_runner.py
+0
-326
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+0
-671
vllm/worker/cpu_pooling_model_runner.py
vllm/worker/cpu_pooling_model_runner.py
+0
-125
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+0
-452
vllm/worker/multi_step_tpu_worker.py
vllm/worker/multi_step_tpu_worker.py
+0
-108
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+0
-909
vllm/worker/tpu_worker.py
vllm/worker/tpu_worker.py
+0
-337
vllm/worker/xpu_model_runner.py
vllm/worker/xpu_model_runner.py
+0
-606
vllm/worker/xpu_worker.py
vllm/worker/xpu_worker.py
+0
-186
No files found.
CMakeLists.txt
View file @
b8025f24
...
@@ -37,7 +37,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
...
@@ -37,7 +37,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
set
(
PYTHON_SUPPORTED_VERSIONS
"3.10"
"3.11"
"3.12"
"3.13"
)
set
(
PYTHON_SUPPORTED_VERSIONS
"3.10"
"3.11"
"3.12"
"3.13"
)
# Supported AMD GPU architectures.
# Supported AMD GPU architectures.
set
(
HIP_SUPPORTED_ARCHS
"gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151;gfx928;gfx936"
)
set
(
HIP_SUPPORTED_ARCHS
"gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151;gfx928;gfx936
;gfx938
"
)
# ROCm installation prefix. Default to /opt/rocm but allow override via
# ROCm installation prefix. Default to /opt/rocm but allow override via
# -DROCM_PATH=/your/rocm/path when invoking cmake.
# -DROCM_PATH=/your/rocm/path when invoking cmake.
...
...
vllm/_custom_ops.py
View file @
b8025f24
...
@@ -415,6 +415,7 @@ def apply_repetition_penalties(
...
@@ -415,6 +415,7 @@ def apply_repetition_penalties(
logits
,
prompt_mask
,
output_mask
,
repetition_penalties
logits
,
prompt_mask
,
output_mask
,
repetition_penalties
)
)
# fused quant layer norm ops
# fused quant layer norm ops
def
rms_norm_dynamic_per_token_quant
(
def
rms_norm_dynamic_per_token_quant
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
...
@@ -2539,15 +2540,6 @@ def cp_gather_indexer_k_quant_cache(
...
@@ -2539,15 +2540,6 @@ def cp_gather_indexer_k_quant_cache(
)
)
def indexer_k_quant_and_cache(
    k: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    quant_block_size: int,
    kv_cache_dtype: str,
) -> None:
    """Invoke the ``_C_cache_ops.indexer_k_quant_and_cache`` custom op.

    All five arguments are forwarded positionally and unchanged, and nothing
    is returned to the caller.

    NOTE(review): judging by the name, the op presumably quantizes ``k`` in
    blocks of ``quant_block_size`` and writes it into ``kv_cache`` at the
    positions given by ``slot_mapping`` -- confirm against the kernel.
    """
    cache_op = torch.ops._C_cache_ops.indexer_k_quant_and_cache
    cache_op(k, kv_cache, slot_mapping, quant_block_size, kv_cache_dtype)
def
get_device_attribute
(
attribute
:
int
,
device
:
int
)
->
int
:
def
get_device_attribute
(
attribute
:
int
,
device
:
int
)
->
int
:
return
torch
.
ops
.
_C_cuda_utils
.
get_device_attribute
(
attribute
,
device
)
return
torch
.
ops
.
_C_cuda_utils
.
get_device_attribute
(
attribute
,
device
)
...
...
vllm/worker/cpu_enc_dec_model_runner.py
deleted
100644 → 0
View file @
3b2aefb1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
cast
import
torch
from
vllm.attention
import
AttentionMetadata
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
from
vllm.worker.cpu_model_runner
import
(
CPUModelRunnerBase
,
ModelInputForCPUBuilder
,
ModelInputForCPUWithSamplingMetadata
)
from
vllm.worker.model_runner_base
import
(
_add_attn_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
)
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
@dataclasses.dataclass(frozen=True)
class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata):
    """
    Model input used by the CPU encoder/decoder model runner.

    Extends the sampling-aware CPU model input with the encoder-side token
    and position tensors.
    """
    encoder_input_tokens: Optional[torch.Tensor] = None
    encoder_input_positions: Optional[torch.Tensor] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return the fields of this input as a plain dict.

        The dict holds the decoder and encoder token/position tensors and the
        multi-modal kwargs; the attention and sampling metadata are then
        merged in by the shared ``_add_*_broadcastable_dict`` helpers.
        """
        # Insertion order matches the field order used by the other CPU
        # model inputs.
        exported_fields = (
            "input_tokens",
            "input_positions",
            "encoder_input_tokens",
            "encoder_input_positions",
            "multi_modal_kwargs",
        )
        payload = {name: getattr(self, name) for name in exported_fields}
        _add_attn_metadata_broadcastable_dict(payload, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(payload,
                                                  self.sampling_metadata)
        return payload

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "EncoderDecoderModelInputForCPU":
        """Rebuild an instance from a broadcast dict.

        Construction is delegated to the parent class; the ``cast`` only
        narrows the static return type.
        """
        rebuilt = super().from_broadcasted_tensor_dict(tensor_dict,
                                                       attn_backend)
        return cast(EncoderDecoderModelInputForCPU, rebuilt)
class CPUEncoderDecoderModelRunner(
        CPUModelRunnerBase[EncoderDecoderModelInputForCPU]):
    """CPU model runner for encoder/decoder models.

    Builds the decoder-side inputs through the shared CPU builder, augments
    them with encoder and cross-attention inputs, and runs a single forward
    and sampling step.
    """
    _model_input_cls: Type[EncoderDecoderModelInputForCPU] = (
        EncoderDecoderModelInputForCPU)
    _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder

    def _list_to_int32_tensor(
        self,
        _list: List[int],
    ) -> torch.Tensor:
        """Return ``_list`` as an int32 tensor on ``self.device``."""
        return torch.tensor(_list, dtype=torch.int32, device=self.device)

    def _list_to_long_tensor(
        self,
        _list: List[int],
    ) -> torch.Tensor:
        """Return ``_list`` as a long (int64) tensor on ``self.device``."""
        return torch.tensor(_list, dtype=torch.long, device=self.device)

    def _empty_int32_tensor(self) -> torch.Tensor:
        """Return an empty int32 tensor on ``self.device``."""
        return self._list_to_int32_tensor([])

    def _empty_long_tensor(self) -> torch.Tensor:
        """Return an empty long tensor on ``self.device``."""
        return self._list_to_long_tensor([])

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str,
                                    Any]) -> EncoderDecoderModelInputForCPU:
        """Rebuild a model input from a broadcast dict using this runner's
        attention backend."""
        return EncoderDecoderModelInputForCPU.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> EncoderDecoderModelInputForCPU:
        """Prepare the full model input for one step.

        Args:
            seq_group_metadata_list: sequence groups scheduled for this step.
            virtual_engine: stored on the returned model input.
            finished_requests_ids: forwarded to the input builder and to
                ``get_generators``.

        Returns:
            A copy of the decoder-side model input with sampling metadata,
            attention metadata, encoder tokens/positions and
            ``virtual_engine`` filled in.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        (
            attn_metadata,
            encoder_input_tokens_tensor,
            encoder_input_positions_tensor,
        ) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
                                                      model_input)
        # Sampling metadata is only required for the final pp group
        generators = self.get_generators(finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                     model_input.seq_lens,
                                                     model_input.query_lens,
                                                     self.device,
                                                     pin_memory=False,
                                                     generators=generators)
        # The model input is a frozen dataclass, so a modified copy is made
        # instead of assigning to its fields.
        return dataclasses.replace(
            model_input,
            sampling_metadata=sampling_metadata,
            attn_metadata=attn_metadata,
            encoder_input_tokens=encoder_input_tokens_tensor,
            encoder_input_positions=encoder_input_positions_tensor,
            virtual_engine=virtual_engine,
        )

    def _prepare_encoder_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        model_input: EncoderDecoderModelInputForCPU,
    ) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
               Optional[torch.Tensor]]:
        """Prepare the encoder- and cross-attention-related model inputs.

        These inputs augment an already-computed model input whose
        decoder-related fields are populated.

        Sets the following fields on ``model_input.attn_metadata`` (the
        metadata object is updated in place):

        * `num_encoder_tokens`
        * `encoder_seq_lens`
        * `encoder_seq_lens_tensor`
        * `max_encoder_seq_len`
        * `cross_slot_mapping`
        * `cross_block_tables`

        Arguments:

        * seq_group_metadata_list: list of sequence groups for which to
                                   compute inputs
        * model_input: model input data structure with decoder-oriented
                       fields already computed.

        Return:

        * Tuple ``(attn_metadata, encoder_input_tokens,
          encoder_input_positions)``. For an empty
          ``seq_group_metadata_list`` the attention metadata is returned
          untouched and both tensors are ``None``.
        """

        if not seq_group_metadata_list:
            return (model_input.attn_metadata, None, None)

        # Since we are not supporting chunked prefill either the entire
        # batch is prefill or it is decode
        is_prompt = seq_group_metadata_list[0].is_prompt

        block_size = self.block_size

        # Build encoder inputs
        encoder_seq_lens: List[int] = []
        if is_prompt:
            # Prefill phase.
            cross_block_tables = self._empty_int32_tensor().view(
                len(seq_group_metadata_list), -1)

            # Extract input tokens/positions, cross-attention slot-mapping,
            # & seq len from each sequence group metadata
            encoder_input_tokens: List[int] = []
            encoder_input_positions: List[int] = []
            cross_slot_mapping: List[int] = []
            for seq_group_metadata in seq_group_metadata_list:
                # Build seq lens
                encoder_seq_data = seq_group_metadata.encoder_seq_data
                seq_len = encoder_seq_data.get_len()
                token_ids = encoder_seq_data.get_token_ids()
                encoder_seq_lens.append(seq_len)

                # Build slot mapping: token position -> physical slot in the
                # block that holds it.
                cross_block_table = seq_group_metadata.cross_block_table
                cross_slot_mapping.extend(
                    cross_block_table[pos // block_size] * block_size +
                    pos % block_size for pos in range(seq_len))

                # Build encoder input tokens
                encoder_input_tokens.extend(token_ids)
                encoder_input_positions.extend(range(seq_len))

            # Convert tokens/positions & cross-attention
            # slot-mapping to encoder input tensors
            encoder_input_tokens_tensor = self._list_to_long_tensor(
                encoder_input_tokens)
            encoder_input_positions_tensor = self._list_to_long_tensor(
                encoder_input_positions)
            cross_slot_mapping_tensor = self._list_to_long_tensor(
                cross_slot_mapping)

        else:
            # Decode phase.
            encoder_input_tokens_tensor = self._empty_long_tensor()
            encoder_input_positions_tensor = self._empty_long_tensor()
            cross_slot_mapping_tensor = self._empty_long_tensor()
            # Extract cross-attention block tables &
            # seq len from each sequence group metadata.
            # Cross-attention block tables are empty
            # during vLLM memory profiling.
            cross_block_tables = []
            for seq_group_metadata in seq_group_metadata_list:
                for _ in range(len(seq_group_metadata.seq_data)):
                    encoder_seq_lens.append(
                        seq_group_metadata.encoder_seq_data.get_len())
                    cross_block_table = seq_group_metadata.cross_block_table
                    cross_block_tables.append([] if (
                        cross_block_table is None) else cross_block_table)

            # `default=0` keeps this from raising ValueError when no
            # sequence contributed a block table.
            max_len_of_block_table = max(
                (len(block_table) for block_table in cross_block_tables),
                default=0)

            cross_block_tables = make_tensor_with_pad(
                cross_block_tables,
                max_len=max_len_of_block_table,
                pad=0,
                dtype=torch.int32,
                device=self.device,
            )

        # Compute encoder sequence lengths
        max_encoder_seq_len = max(encoder_seq_lens, default=0)
        encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)

        # Update attention metadata with encoder-oriented attributes
        attn_metadata = model_input.attn_metadata
        assert attn_metadata is not None
        attn_metadata.num_encoder_tokens = sum(encoder_seq_lens)
        attn_metadata.encoder_seq_lens = encoder_seq_lens
        attn_metadata.encoder_seq_lens_tensor = encoder_seq_lens_tensor
        attn_metadata.max_encoder_seq_len = max_encoder_seq_len
        attn_metadata.cross_slot_mapping = cross_slot_mapping_tensor
        attn_metadata.cross_block_tables = cross_block_tables

        return (attn_metadata, encoder_input_tokens_tensor,
                encoder_input_positions_tensor)

    @torch.no_grad()
    def execute_model(
        self,
        model_input: EncoderDecoderModelInputForCPU,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run one forward pass and, on the driver worker, sample.

        Args:
            model_input: prepared model input for this step.
            kv_caches: accepted for interface compatibility; not read here.
            intermediate_tensors: passed through to the model call.
            num_steps: must be 1.

        Returns:
            ``[]`` on non-driver workers, otherwise a one-element list with
            the sampler output.

        Raises:
            ValueError: if ``num_steps`` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError(
                "CPU worker does not support multi-step execution.")

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids":
            model_input.input_tokens,
            "positions":
            model_input.input_positions,
            "encoder_input_ids":
            model_input.encoder_input_tokens,
            "encoder_positions":
            model_input.encoder_input_positions,
            **MultiModalKwargs.as_kwargs(
                model_input.multi_modal_kwargs or {},
                device=self.device,
            ),
            "intermediate_tensors":
            intermediate_tensors,
        }

        with set_forward_context(model_input.attn_metadata, self.vllm_config,
                                 model_input.virtual_engine):
            hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output = self.sampler(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        return [output]
vllm/worker/cpu_model_runner.py
deleted
100644 → 0
View file @
3b2aefb1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
import
weakref
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Set
,
Type
,
TypeVar
,
Union
)
import
torch
from
torch
import
nn
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.models
import
supports_lora
,
supports_multimodal
from
vllm.multimodal
import
(
BatchedTensorInputs
,
MultiModalKwargs
,
MultiModalPlaceholderMap
)
from
vllm.sequence
import
(
IntermediateTensors
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
_add_attn_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
,
_init_attn_metadata_from_tensor_dict
,
_init_sampling_metadata_from_tensor_dict
)
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
# Module-level logger created through vLLM's `init_logger`.
logger = init_logger(__name__)

# Type variable bound to `ModelInputForCPU`; used to type the classmethod
# constructor and the generic CPU runner base class in this module.
TModelInputForCPU = TypeVar('TModelInputForCPU', bound="ModelInputForCPU")
# Placeholder value used to pre-fill per-token slot-mapping lists before the
# real slot indices are written during prompt processing.
_PAD_SLOT_ID = -1
@dataclass(frozen=True)
class ModelInputForCPU(ModelRunnerInputBase):
    """
    Immutable container for the metadata the base model forward pass needs
    on CPU.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    token_type_ids: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    virtual_engine: Optional[int] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None
    lora_mapping: Optional["LoRAMapping"] = None
    lora_requests: Optional[Set[LoRARequest]] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """Return the broadcastable fields as a dict.

        Token, position and token-type tensors, multi-modal kwargs and the
        LoRA fields are copied by name; the attention metadata is merged in
        by ``_add_attn_metadata_broadcastable_dict``.
        """
        exported_fields = (
            "input_tokens",
            "input_positions",
            "token_type_ids",
            "multi_modal_kwargs",
            "lora_requests",
            "lora_mapping",
        )
        payload = {name: getattr(self, name) for name in exported_fields}
        _add_attn_metadata_broadcastable_dict(payload, self.attn_metadata)

        return payload

    @classmethod
    def from_broadcasted_tensor_dict(
            cls: Type[TModelInputForCPU],
            tensor_dict: Dict[str, Any],
            attn_backend: Optional["AttentionBackend"] = None
    ) -> TModelInputForCPU:
        """Construct an instance from a broadcast dict.

        When an attention backend is supplied, the dict is first passed
        through ``_init_attn_metadata_from_tensor_dict``; its entries are
        then used as constructor keyword arguments.
        """
        init_kwargs = tensor_dict
        if attn_backend is not None:
            init_kwargs = _init_attn_metadata_from_tensor_dict(
                attn_backend, init_kwargs)
        return cls(**init_kwargs)
@dataclass(frozen=True)
class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
    """
    CPU model input extended with sampling metadata; used by the model
    runner.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None
    is_prompt: Optional[bool] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return the broadcastable fields as a dict.

        Copies the token, position and token-type tensors and the
        multi-modal kwargs by name, then lets the shared helpers merge in
        the attention metadata and the sampling metadata.
        """
        exported_fields = (
            "input_tokens",
            "input_positions",
            "token_type_ids",
            "multi_modal_kwargs",
        )
        payload = {name: getattr(self, name) for name in exported_fields}
        _add_attn_metadata_broadcastable_dict(payload, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(payload,
                                                  self.sampling_metadata)
        return payload

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForCPUWithSamplingMetadata":
        """Construct an instance from a broadcast dict.

        The dict is passed through the sampling-metadata initializer first
        and, when a backend is supplied, through the attention-metadata
        initializer; the result is used as constructor keyword arguments.
        """
        init_kwargs = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            init_kwargs = _init_attn_metadata_from_tensor_dict(
                attn_backend, init_kwargs)
        return cls(**init_kwargs)
class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
    """Turns a list of sequence groups into a CPU model input.

    Typical use, as done by the CPU runner: ``prepare()``, then
    ``set_seq_group_list()`` / ``add_seq_group()``, then ``build()``.
    """

    class ModelInputData:
        """Mutable accumulator filled while walking the scheduled sequences."""

        def __init__(self, use_mrope: bool):
            self.use_mrope = use_mrope

            # Flattened per-token inputs.
            self.input_tokens: List[int] = []
            self.input_positions: List[int] = []
            self.token_type_ids: Optional[List[int]] = []

            # Per-sequence lengths.
            self.seq_lens: List[int] = []
            self.query_lens: List[int] = []

            # Block tables and decode/prefill counters.
            self.prefill_block_tables: List[List[int]] = []
            self.decode_block_tables: List[List[int]] = []
            self.max_decode_seq_len: int = 0
            self.num_prefills: int = 0
            self.num_prefill_tokens: int = 0
            self.num_decode_tokens: int = 0
            self.slot_mapping: List[int] = []

            # Multi-modal inputs.
            self.multi_modal_inputs_list: List[MultiModalKwargs] = []
            self.multi_modal_placeholder_maps: Dict[
                str, MultiModalPlaceholderMap] = defaultdict(
                    MultiModalPlaceholderMap)

            # Three parallel position lists, one per mrope component.
            self.input_mrope_positions: List[List[int]] = [[], [], []]

    def __init__(self,
                 runner: "CPUModelRunner",
                 finished_requests_ids: Optional[List[str]] = None) -> None:
        """Cache the runner settings this builder reads.

        ``finished_requests_ids`` is accepted but not read here.
        """
        super().__init__()
        self.runner = runner

        # Prefix caching is handled on the same path as chunked prefill.
        self.chunked_prefill = (
            runner.scheduler_config.chunked_prefill_enabled
            or runner.cache_config.enable_prefix_caching)

        self.model_input_cls = runner._model_input_cls
        backend = runner.attn_backend
        self.attn_backend = backend
        self.sliding_window = runner.sliding_window
        self.block_size = runner.block_size
        self.device = runner.device
        self.enable_lora = runner.lora_config is not None

        # spec decode (e.g. Medusa) does not have atten backend
        if backend is not None:
            builder_cls = backend.get_builder_cls()
            self.att_metadata_builder = builder_cls(self)

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        """Reset the builder state for a new batch.

        ``finished_requests_ids`` is accepted but not read here.
        """
        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
        uses_mrope = self.runner.model_config.uses_mrope
        self.input_data = ModelInputForCPUBuilder.ModelInputData(uses_mrope)
        self.att_metadata_builder.prepare()

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Queue one more sequence group for the next ``build()``."""
        self.seq_group_metadata_list.append(seq_group_metadata)

    def set_seq_group_list(
            self, seq_group_metadata_list: List[SequenceGroupMetadata]):
        """Replace the queued sequence groups with the given list."""
        self.seq_group_metadata_list = seq_group_metadata_list

    def build(self) -> ModelInputForCPU:
        """Assemble the model input from the queued sequence groups."""
        self._build_input_data()
        data = self.input_data

        token_tensor = torch.tensor(data.input_tokens,
                                    dtype=torch.long,
                                    device="cpu")

        # mrope positions take precedence as soon as any component list is
        # non-empty.
        if any(data.input_mrope_positions):
            position_source = data.input_mrope_positions
        else:
            position_source = data.input_positions
        position_tensor = torch.tensor(position_source,
                                       dtype=torch.long,
                                       device="cpu")

        type_id_tensor = None
        if data.token_type_ids:
            type_id_tensor = torch.tensor(data.token_type_ids,
                                          dtype=torch.long,
                                          device="cpu")

        # For multi-modal models
        batched_mm_kwargs = None
        if data.multi_modal_inputs_list:
            batched_mm_kwargs = MultiModalKwargs.batch(
                data.multi_modal_inputs_list)

        attn_metadata = self.att_metadata_builder.build(
            data.seq_lens, data.query_lens, -1, -1)

        seq_groups = self.seq_group_metadata_list
        is_prompt = seq_groups[0].is_prompt if seq_groups else None

        # LoRA data.
        lora_requests = set()
        lora_mapping = None
        if self.enable_lora:
            lora_requests = {
                group.lora_request
                for group in seq_groups if group.lora_request is not None
            }
            lora_mapping = self._prepare_lora_input(seq_groups, is_prompt)

        return self.model_input_cls(input_tokens=token_tensor,
                                    input_positions=position_tensor,
                                    token_type_ids=type_id_tensor,
                                    seq_lens=data.seq_lens,
                                    query_lens=data.query_lens,
                                    attn_metadata=attn_metadata,
                                    multi_modal_kwargs=batched_mm_kwargs,
                                    lora_mapping=lora_mapping,
                                    lora_requests=lora_requests)

    def _build_input_data(self):
        """Dispatch every queued sequence to the prompt or decode path."""
        for group in self.seq_group_metadata_list:
            for seq_id, seq_data in group.seq_data.items():
                if not group.is_prompt:
                    self._compute_decode_input_tokens(self.input_data, group,
                                                      seq_data, seq_id)
                    continue

                self._compute_prompt_input_tokens(self.input_data, group,
                                                  seq_data, seq_id)
                if group.multi_modal_data:
                    self._compute_multi_modal_input(group, seq_data)

    def _compute_decode_input_tokens(self, data: ModelInputData,
                                     seq_group_metadata: SequenceGroupMetadata,
                                     seq_data: SequenceData, seq_id: int):
        """
        Record token, position, block table and slot for one decode step.
        """
        block_size = self.runner.block_size

        block_table = seq_group_metadata.block_tables[seq_id]
        seq_len = seq_data.get_len()
        context_len = seq_data.get_num_computed_tokens()

        last_token = seq_data.get_last_token_id()
        last_position = seq_len - 1
        block_index, block_offset = divmod(last_position, block_size)
        slot = block_table[block_index] * block_size + block_offset

        # For paged_attention kernel
        window = self.runner.sliding_window
        if window:
            # Drop whole blocks that lie entirely before the window.
            first_block = max(0, seq_len - window) // block_size
            seq_len = seq_len - first_block * block_size
            block_table = block_table[first_block:]

        # For MRotaryEmbedding
        if seq_data.mrope_position_delta is None:
            data.input_positions.append(last_position)  # type: ignore
        else:
            next_positions = MRotaryEmbedding.get_next_input_positions(
                seq_data.mrope_position_delta,
                context_len,
                seq_len,
            )
            for component in range(3):
                data.input_mrope_positions[component].extend(  # type: ignore
                    next_positions[component])

        # Update fields
        data.input_tokens.append(last_token)
        data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len)
        data.num_decode_tokens += 1
        data.slot_mapping.append(slot)
        data.decode_block_tables.append(block_table)
        data.query_lens.append(1)
        data.seq_lens.append(seq_len)

    def _compute_prompt_input_tokens(self, data: ModelInputData,
                                     seq_group_metadata: SequenceGroupMetadata,
                                     seq_data: SequenceData, seq_id: int):
        """
        Record tokens, positions, block table and slots for one prompt chunk.
        """
        chunk_size = seq_group_metadata.token_chunk_size
        block_size = self.runner.block_size

        block_table = seq_group_metadata.block_tables[seq_id]
        context_len = seq_data.get_num_computed_tokens()
        seq_len = min(seq_data.get_len(), context_len + chunk_size)

        # For prefix caching
        cached_blocks = len(seq_group_metadata.computed_block_nums)
        if cached_blocks > 0:
            cached_len = cached_blocks * self.runner.block_size
            # When cached_len <= context_len we already passed the cache hit
            # region, so do normal computation.
            if cached_len > context_len:
                if cached_len < seq_len:
                    # Partial hit. Compute the missing part.
                    context_len = cached_len
                    chunk_size = seq_len - context_len
                else:
                    # Full hit. Only compute the last token to avoid
                    # erroneous behavior. FIXME: Ideally we should directly
                    # mark all tokens as computed in the scheduler and do not
                    # schedule this sequence, so this case should not happen.
                    context_len = seq_len - 1
                    chunk_size = 1

        tokens = seq_data.get_token_ids()[context_len:seq_len]
        positions = range(context_len, seq_len)
        token_types = seq_group_metadata.token_type_ids

        # For encoder-only models, the block_table is None,
        # and there is no need to initialize the slot_mapping.
        if block_table is not None:
            slots = [_PAD_SLOT_ID] * len(positions)
            for index, position in enumerate(positions):
                block_index, block_offset = divmod(position, block_size)
                slots[index] = (block_table[block_index] * block_size +
                                block_offset)
            data.slot_mapping.extend(slots)

        # The MROPE positions are prepared in _compute_multi_modal_input
        data.input_positions.extend(positions)

        if data.token_type_ids is not None:
            data.token_type_ids.extend(token_types or [])

        # Update fields
        num_tokens = len(tokens)
        data.input_tokens.extend(tokens)
        data.num_prefills += 1
        data.num_prefill_tokens += num_tokens
        data.query_lens.append(num_tokens)
        data.prefill_block_tables.append(block_table)
        data.seq_lens.append(seq_len)

    def _compute_multi_modal_input(self,
                                   seq_group_metadata: SequenceGroupMetadata,
                                   seq_data: SequenceData):
        """Collect multi-modal kwargs (and mrope positions) for the sequence
        most recently added by the prompt path."""
        computed_len = seq_data.get_num_computed_tokens()
        seq_len = self.input_data.seq_lens[-1]

        # NOTE: mm_kwargs only includes the subset of multi-modal items that
        # intersect with the current prefill positions.
        mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
            seq_group_metadata, range(computed_len, seq_len))

        if not mm_kwargs:
            return

        # special processing for mrope position deltas.
        if self.runner.model_config.uses_mrope:
            assert not self.chunked_prefill, \
                "MROPE on CPU does not support chunked-prefill."

            image_grid_thw = mm_kwargs.get("image_grid_thw")
            video_grid_thw = mm_kwargs.get("video_grid_thw")
            audio_feature_lengths = mm_kwargs.get("audio_feature_lengths")
            has_modal_grid = any(
                item is not None for item in (image_grid_thw, video_grid_thw,
                                              audio_feature_lengths))
            assert has_modal_grid, (
                "mrope embedding type requires multi-modal input mapper "
                "returns 'image_grid_thw' or 'video_grid_thw' or "
                "'audio_feature_lengths'.")

            mrope_positions, mrope_position_delta = \
                MRotaryEmbedding.get_input_positions(
                    seq_data.get_token_ids(),
                    hf_config=self.runner.model_config.hf_config,
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    second_per_grid_ts=mm_kwargs.get("second_per_grid_ts"),
                    context_len=computed_len,
                    audio_feature_lengths=audio_feature_lengths,
                    use_audio_in_video=mm_kwargs.get("use_audio_in_video",
                                                     False),
                )
            seq_data.mrope_position_delta = mrope_position_delta

            for component in range(3):
                self.input_data.input_mrope_positions[  # type: ignore
                    component].extend(mrope_positions[component])

        self.input_data.multi_modal_inputs_list.append(mm_kwargs)
        merged_maps = self.input_data.multi_modal_placeholder_maps
        for modality, placeholder_map in placeholder_maps.items():
            merged_maps[modality].extend(placeholder_map)

    def _prepare_lora_input(
            self, seq_group_metadata_list: List[SequenceGroupMetadata],
            is_prefill: bool) -> LoRAMapping:
        """Build the LoRA mapping for the given sequence groups.

        Each group contributes its LoRA id once per scheduled token to the
        index mapping. It contributes the same number of entries to the
        prompt mapping when prompt logprobs are requested, otherwise one.
        """
        index_mapping = []
        prompt_mapping = []
        for group in seq_group_metadata_list:
            lora_id = group.lora_int_id
            query_len = group.token_chunk_size

            params = group.sampling_params
            if params and params.prompt_logprobs is not None:
                prompt_entries = query_len
            else:
                prompt_entries = 1

            index_mapping.extend([lora_id] * query_len)
            prompt_mapping.extend([lora_id] * prompt_entries)

        return LoRAMapping(index_mapping=tuple(index_mapping),
                           prompt_mapping=tuple(prompt_mapping),
                           is_prefill=is_prefill)
class
CPUModelRunnerBase
(
ModelRunnerBase
[
TModelInputForCPU
]):
"""
Helper class for shared methods between CPU model runners.
"""
_model_input_cls
:
Type
[
TModelInputForCPU
]
_builder_cls
:
Type
[
ModelInputForCPUBuilder
]
builder
:
ModelInputForCPUBuilder
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
is_driver_worker
:
bool
=
False
,
return_hidden_states
:
bool
=
False
,
*
args
,
**
kwargs
,
):
ModelRunnerBase
.
__init__
(
self
,
vllm_config
)
model_config
=
self
.
model_config
cache_config
=
self
.
cache_config
self
.
is_driver_worker
=
is_driver_worker
self
.
return_hidden_states
=
return_hidden_states
self
.
device
=
self
.
device_config
.
device
self
.
pin_memory
=
False
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
block_size
=
cache_config
.
block_size
num_attn_heads
=
self
.
model_config
.
get_num_attention_heads
(
self
.
parallel_config
)
needs_attn_backend
=
(
num_attn_heads
!=
0
or
self
.
model_config
.
is_attention_free
)
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
dtype
,
self
.
kv_cache_dtype
,
self
.
block_size
,
self
.
model_config
.
is_attention_free
,
use_mla
=
self
.
model_config
.
use_mla
,
)
if
needs_attn_backend
else
None
# Lazy initialization.
self
.
model
:
nn
.
Module
# Set after init_Model
# Set after load_model.
self
.
lora_manager
:
Optional
[
LRUCacheWorkerLoRAManager
]
=
None
self
.
sampler
=
get_sampler
()
if
hasattr
(
self
,
"_builder_cls"
):
# multi-step model runner does not have `_builder_cls`
self
.
builder
=
self
.
_builder_cls
(
weakref
.
proxy
(
self
))
def
load_model
(
self
)
->
None
:
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
if
self
.
lora_config
:
assert
supports_lora
(
self
.
model
),
f
"
{
self
.
model
.
__class__
.
__name__
}
does not support LoRA yet."
if
supports_multimodal
(
self
.
model
):
logger
.
warning
(
"Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model."
)
# Use get_text_config() in case of multimodal models
text_config
=
self
.
model_config
.
hf_config
.
get_text_config
()
self
.
lora_manager
=
LRUCacheWorkerLoRAManager
(
self
.
scheduler_config
.
max_num_seqs
,
self
.
scheduler_config
.
max_num_batched_tokens
,
self
.
vocab_size
,
self
.
lora_config
,
self
.
device
,
self
.
model
.
embedding_modules
,
self
.
model
.
embedding_padding_modules
,
max_position_embeddings
=
text_config
.
max_position_embeddings
,
)
self
.
model
=
self
.
lora_manager
.
create_lora_manager
(
self
.
model
)
def
get_model
(
self
)
->
nn
.
Module
:
return
self
.
model
def
_prepare_model_input_tensors
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
TModelInputForCPU
:
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
metadata for possible additional steps, e.g., sampling.
"""
self
.
builder
.
prepare
(
finished_requests_ids
)
self
.
builder
.
set_seq_group_list
(
seq_group_metadata_list
)
return
self
.
builder
.
build
()
# type: ignore
@property
def vocab_size(self) -> int:
    """Vocabulary size as reported by ``self.model_config``."""
    model_config = self.model_config
    return model_config.get_vocab_size()
def remove_all_loras(self):
    """Remove every adapter held by the LoRA manager.

    Raises:
        RuntimeError: if no LoRA manager is set on this runner.
    """
    manager = self.lora_manager
    if not manager:
        raise RuntimeError("LoRA is not enabled.")
    manager.remove_all_adapters()
def set_active_loras(self, lora_requests: Set[LoRARequest],
                     lora_mapping: LoRAMapping) -> None:
    """Forward the active LoRA requests and mapping to the LoRA manager.

    Raises:
        RuntimeError: if no LoRA manager is set on this runner.
    """
    manager = self.lora_manager
    if not manager:
        raise RuntimeError("LoRA is not enabled.")
    manager.set_active_adapters(lora_requests, lora_mapping)
def add_lora(self, lora_request: LoRARequest) -> bool:
    """Add the requested adapter through the LoRA manager.

    Returns:
        Whatever the manager's ``add_adapter`` call returns.

    Raises:
        RuntimeError: if no LoRA manager is set on this runner.
    """
    manager = self.lora_manager
    if not manager:
        raise RuntimeError("LoRA is not enabled.")
    return manager.add_adapter(lora_request)
def remove_lora(self, lora_id: int) -> bool:
    """Remove the adapter with the given id through the LoRA manager.

    Returns:
        Whatever the manager's ``remove_adapter`` call returns.

    Raises:
        RuntimeError: if no LoRA manager is set on this runner.
    """
    manager = self.lora_manager
    if not manager:
        raise RuntimeError("LoRA is not enabled.")
    return manager.remove_adapter(lora_id)
def pin_lora(self, lora_id: int) -> bool:
    """Pin the adapter with the given id through the LoRA manager.

    Returns:
        Whatever the manager's ``pin_adapter`` call returns.

    Raises:
        RuntimeError: if no LoRA manager is set on this runner.
    """
    manager = self.lora_manager
    if not manager:
        raise RuntimeError("LoRA is not enabled.")
    return manager.pin_adapter(lora_id)
def list_loras(self) -> Set[int]:
    """Return the adapter ids reported by the LoRA manager.

    Raises:
        RuntimeError: if no LoRA manager is set on this runner.
    """
    manager = self.lora_manager
    if not manager:
        raise RuntimeError("LoRA is not enabled.")
    return manager.list_adapters()
class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
    """CPU model runner that runs the forward pass, computes logits and
    samples the next token."""

    _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
        ModelInputForCPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForCPUWithSamplingMetadata:
        """Rebuild a model input from a broadcasted tensor dict, using this
        runner's attention backend."""
        input_cls = ModelInputForCPUWithSamplingMetadata
        return input_cls.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> ModelInputForCPUWithSamplingMetadata:
        """Prepare the model input for the given sequence groups, including
        the metadata needed by the sampling step.

        ``is_prompt`` on the result is taken from the first sequence group,
        or is ``None`` when the list is empty.
        """
        base_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)

        # Sampling metadata is only required for the final pp group
        generators = self.get_generators(finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            base_input.seq_lens,
            base_input.query_lens,
            self.device,
            pin_memory=False,
            generators=generators,
        )

        if seq_group_metadata_list:
            is_prompt = seq_group_metadata_list[0].is_prompt
        else:
            is_prompt = None

        return dataclasses.replace(
            base_input,
            sampling_metadata=sampling_metadata,
            virtual_engine=virtual_engine,
            is_prompt=is_prompt,
        )

    @torch.no_grad()
    def execute_model(
        self,
        model_input: ModelInputForCPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
        previous_hidden_states: Optional[torch.Tensor] = None,
    ) -> Optional[List[SamplerOutput]]:
        """Run one forward pass, compute logits and sample.

        ``kv_caches`` is accepted but not referenced in this method.

        Returns:
            An empty list on non-driver workers, otherwise a one-element
            list holding the sampler output.

        Raises:
            ValueError: if ``num_steps`` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError(
                "CPU worker does not support multi-step execution.")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        model_executable = self.model

        if model_input.multi_modal_kwargs is not None:
            multimodal_kwargs = MultiModalKwargs.as_kwargs(
                model_input.multi_modal_kwargs,
                device=self.device,
            )
        else:
            multimodal_kwargs = {}

        # Only forwarded to the model when the caller supplied it.
        if previous_hidden_states is not None:
            execute_model_kwargs = {
                "previous_hidden_states": previous_hidden_states
            }
        else:
            execute_model_kwargs = {}

        with set_forward_context(model_input.attn_metadata, self.vllm_config,
                                 model_input.virtual_engine):
            hidden_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **execute_model_kwargs,
                **multimodal_kwargs,
            )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output = self.sampler(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        if self.return_hidden_states:
            # we only need to pass hidden states of most recent token
            if model_input.is_prompt:
                output.prefill_hidden_states = hidden_states
            output.hidden_states = hidden_states
        return [output]

    def generate_proposals(self, *args, **kwargs):
        """Delegate to the model's ``generate_proposals``."""
        return self.model.generate_proposals(*args, **kwargs)
vllm/worker/cpu_pooling_model_runner.py
deleted
100644 → 0
View file @
3b2aefb1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
import
torch
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.pooling_params
import
PoolingParams
from
vllm.sequence
import
(
IntermediateTensors
,
PoolerOutput
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.worker.cpu_model_runner
import
(
CPUModelRunnerBase
,
ModelInputForCPU
,
ModelInputForCPUBuilder
)
@dataclasses.dataclass(frozen=True)
class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU):
    """
    Model input consumed by the CPUPoolingModelRunner.
    """
    # Pooling metadata attached by the runner; None until it is set.
    pooling_metadata: Optional["PoolingMetadata"] = dataclasses.field(
        default=None)
class CPUPoolingModelRunner(
        CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]):
    """CPU model runner that runs the forward pass and applies the model's
    pooler to the hidden states."""

    _model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = (
        ModelInputForCPUWithPoolingMetadata)
    _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForCPUWithPoolingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
        """Run one forward pass and pool the resulting hidden states.

        ``kv_caches`` is accepted but not referenced in this method.

        Returns:
            An empty list on non-driver workers, otherwise a one-element
            list holding the pooler output.

        Raises:
            ValueError: if ``num_steps`` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError(
                "CPU worker does not support multi-step execution.")

        model_executable = self.model

        # Keys are inserted in the order the model call receives them:
        # core inputs, multimodal kwargs, optional token type ids and
        # finally the intermediate tensors.
        forward_kwargs = {
            "input_ids": model_input.input_tokens,
            "positions": model_input.input_positions,
        }
        forward_kwargs.update(
            MultiModalKwargs.as_kwargs(
                model_input.multi_modal_kwargs or {},
                device=self.device,
            ))
        if model_input.token_type_ids is not None:
            forward_kwargs["token_type_ids"] = model_input.token_type_ids
        forward_kwargs["intermediate_tensors"] = intermediate_tensors

        with set_forward_context(model_input.attn_metadata, self.vllm_config,
                                 model_input.virtual_engine):
            hidden_states = model_executable(**forward_kwargs)

        # Only perform pooling in the driver worker.
        if not self.is_driver_worker:
            return []

        pooled = self.model.pooler(
            hidden_states=hidden_states,
            pooling_metadata=model_input.pooling_metadata)
        return [pooled]

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForCPUWithPoolingMetadata:
        """Rebuild a model input from a broadcasted tensor dict, using this
        runner's attention backend."""
        input_cls = ModelInputForCPUWithPoolingMetadata
        return input_cls.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForCPUWithPoolingMetadata:
        """Prepare the model input and attach pooling metadata to it.

        ``seq_group_metadata_list`` must not be ``None`` (asserted).
        """
        assert seq_group_metadata_list is not None
        base_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)

        # Prepare PoolingMetadata.
        assert base_input.seq_lens is not None
        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                 base_input.seq_lens)

        return dataclasses.replace(base_input,
                                   virtual_engine=virtual_engine,
                                   pooling_metadata=pooling_metadata)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list."""
        # One (sequence ids, pooling params) pair per sequence group.
        seq_groups: List[Tuple[List[int], PoolingParams]] = [
            (list(metadata.seq_data.keys()), metadata.pooling_params)
            for metadata in seq_group_metadata_list
        ]

        # Merge the per-group sequence data into a single mapping.
        seq_data: Dict[int, SequenceData] = {}
        for metadata in seq_group_metadata_list:
            seq_data.update(metadata.seq_data)

        return PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
        )
vllm/worker/cpu_worker.py
deleted
100644 → 0
View file @
3b2aefb1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A CPU worker class."""
import
os
from
importlib
import
util
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Type
import
torch
import
torch.distributed
import
vllm.envs
as
envs
from
vllm.attention
import
get_attn_backend
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
ModelConfig
,
ParallelConfig
,
VllmConfig
)
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
bind_kv_cache
from
vllm.worker.cpu_enc_dec_model_runner
import
CPUEncoderDecoderModelRunner
from
vllm.worker.cpu_model_runner
import
CPUModelRunner
,
CPUModelRunnerBase
from
vllm.worker.cpu_pooling_model_runner
import
CPUPoolingModelRunner
from
vllm.worker.worker_base
import
(
LocalOrDistributedWorkerBase
,
WorkerBase
,
WorkerInput
)
# Module-level logger, created through vLLM's `init_logger` helper and keyed
# by this module's name.
logger = init_logger(__name__)
class CPUCacheEngine:
    """Owns the per-layer KV cache tensors of the CPU backend.

    The cache is allocated on the CPU at construction time.  Of the cache
    operations exposed here only block copying is implemented; both swap
    directions raise ``NotImplementedError``.
    """

    def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
                 parallel_config: ParallelConfig,
                 device_config: DeviceConfig) -> None:
        """Record the configs, pick the attention backend and allocate the
        cache.  ``device_config.device_type`` must be ``"cpu"`` (asserted).
        """
        assert device_config.device_type == "cpu"

        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_heads = model_config.get_num_kv_heads(parallel_config)
        self.block_size = cache_config.block_size

        # Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks
        # for CPU backend, because we want to reuse KV cache management
        # in the scheduler.
        self.num_cpu_blocks = cache_config.num_gpu_blocks

        self.dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config,
                                                       model_config)

        # Get attention backend.
        backend_args = (
            self.model_config.get_head_size(),
            self.model_config.dtype,
            cache_config.cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
        )
        self.attn_backend = get_attn_backend(
            *backend_args,
            use_mla=self.model_config.use_mla,
        )

        # Initialize the cache.
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)

    def _allocate_kv_cache(
        self,
        num_blocks: int,
    ) -> List[torch.Tensor]:
        """Allocates KV cache on CPU.

        Returns one uninitialised tensor per layer; the shape comes from the
        attention backend's ``get_kv_cache_shape``.
        """
        shape = self.attn_backend.get_kv_cache_shape(num_blocks,
                                                     self.block_size,
                                                     self.num_heads,
                                                     self.head_size)
        return [
            torch.empty(shape, dtype=self.dtype, device="cpu")
            for _ in range(self.num_layers)
        ]

    def swap_in(self, src_to_dst: torch.Tensor) -> None:
        """Unsupported on the CPU backend; always raises."""
        raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

    def swap_out(self, src_to_dst: torch.Tensor) -> None:
        """Unsupported on the CPU backend; always raises."""
        raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

    def copy(self, src_to_dsts: torch.Tensor) -> None:
        """Copy cache blocks via the attention backend's ``copy_blocks``."""
        self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)

    @staticmethod
    def get_kv_cache_dtype(cache_config: CacheConfig,
                           model_config: ModelConfig):
        """Map ``cache_config.cache_dtype`` to a torch dtype.

        ``"auto"`` selects the model dtype; ``"fp8"`` and ``"fp8_e5m2"``
        select ``torch.float8_e5m2``.

        Raises:
            NotImplementedError: for any other cache dtype string.
        """
        requested = cache_config.cache_dtype
        if requested == "auto":
            return model_config.dtype
        if requested in ("fp8", "fp8_e5m2"):
            return torch.float8_e5m2
        raise NotImplementedError(
            f"Unsupported KV cache type {cache_config.cache_dtype}.")

    @staticmethod
    def get_cache_block_size(
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        """Return the size in bytes of one KV cache block over all layers.

        The value part is counted as zero when ``model_config.use_mla`` is
        set.
        """
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        key_elems = cache_config.block_size * num_heads * head_size
        value_elems = 0 if model_config.use_mla else key_elems
        elems_per_block = num_layers * (key_elems + value_elems)

        cache_dtype = CPUCacheEngine.get_kv_cache_dtype(
            cache_config, model_config)
        bytes_per_elem = torch.tensor([], dtype=cache_dtype).element_size()
        return bytes_per_elem * elems_per_block
class CPUWorker(LocalOrDistributedWorkerBase):
    """A worker class that executes (a partition of) the model on a CPU socket.

    Each worker is associated with a single CPU socket. The worker is
    responsible for maintaining the KV cache and executing the model on the
    CPU. In case of distributed inference, each worker is assigned a partition
    of the model.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        model_runner_cls: Optional[Type[CPUModelRunner]] = None,
    ) -> None:
        """Set up thread binding, the model runner and the optional profiler.

        Args:
            vllm_config: Engine configuration; its ``parallel_config.rank``
                is overwritten with ``rank``.
            local_rank: Stored on the worker as ``self.local_rank``.
            rank: Rank of this worker; must be 0 for the driver worker.
            distributed_init_method: Init method passed on when the
                distributed environment is initialised.
            kv_cache_dtype: Forwarded to the model runner.
            is_driver_worker: Whether this worker is the driver.
            model_runner_cls: Optional wrapper class that is called with the
                freshly built model runner and replaces it.
        """
        WorkerBase.__init__(self, vllm_config=vllm_config)

        self.local_rank = local_rank
        self.rank = rank
        vllm_config.parallel_config.rank = rank

        self.distributed_init_method = distributed_init_method

        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Setup OpenMP threads affinity.
        omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
        # "all" must be in place before the auto-binding helper runs: the
        # helper returns the current value when it cannot compute a binding.
        self.local_omp_cpuid = "all"
        if omp_cpuids == "auto":
            self.local_omp_cpuid = (
                self.get_cpus_id_binding_based_on_numa_nodes())
        else:
            # One "|"-separated entry per rank.
            self.local_omp_cpuid = omp_cpuids.split("|")[rank]

        # Return hidden states from target model if the draft model is an
        # mlp_speculator
        speculative_config = self.speculative_config
        model_config = self.model_config
        needs_hidden_states = not (
            speculative_config is None
            or (speculative_config.draft_model_config.model
                == model_config.model)
            or (speculative_config.draft_model_config.hf_config.model_type
                not in ["medusa", "mlp_speculator", "eagle"]))
        speculative_args = ({
            "return_hidden_states": True
        } if needs_hidden_states else {})

        ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
        if self.model_config.runner_type == "pooling":
            ModelRunnerClass = CPUPoolingModelRunner
        elif self.model_config.is_encoder_decoder:
            ModelRunnerClass = CPUEncoderDecoderModelRunner
        self.model_runner: CPUModelRunnerBase = ModelRunnerClass(
            vllm_config=vllm_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
            **speculative_args,
        )
        if model_runner_cls is not None:
            self.model_runner = model_runner_cls(self.model_runner)

        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CPUCacheEngine]
        # Initialize cpu_cache as pooling models don't initialize kv_caches
        self.cpu_cache: Optional[List[List[torch.Tensor]]] = None

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True))
        else:
            self.profiler = None

    def start_profile(self):
        """Start the torch profiler.

        Raises:
            RuntimeError: if profiling was not enabled at construction.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.start()

    def stop_profile(self):
        """Stop the torch profiler.

        Raises:
            RuntimeError: if profiling was not enabled at construction.
        """
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()

    def init_device(self) -> None:
        """Apply the thread binding, initialise the distributed environment
        and seed the random number generators."""
        if self.local_omp_cpuid != "all":
            ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
            if ret:
                logger.info(ret)

        # Note: unique identifier for creating allreduce shared memory
        os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(
            ":")[-1]
        self.device = torch.device("cpu")
        self.init_distributed_environment()
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        """Load the model through the model runner."""
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of blocks available for the KV cache.

        This determines how many KV blocks can fit into the configured CPU
        KV cache space.

        Note that since vLLM assumes a block resides on GPU if it can be
        modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
        This allows us to reuse the scheduler of vLLM without generalizing it
        to different devices.
        """
        # For CPU device, the block number will be calculated based on the
        # cpu_kvcache_space.
        cache_block_size = self.get_cache_block_size_bytes()
        num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes //
                             cache_block_size)
        num_cpu_blocks = max(num_cpu_blocks, 0)

        # Note: To reuse the cache management procedure,
        # use cpu cache as 'gpu cache'.
        num_gpu_blocks = num_cpu_blocks
        num_cpu_blocks = 0
        return num_gpu_blocks, num_cpu_blocks

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache. Currently, swappable CPU memory is not
        supported.

        Since this worker does not support GPUs, we use the num_gpu_blocks to
        determine how many non-swappable CPU blocks to allocate.
        """
        assert (num_cpu_blocks == 0
                ), f"{type(self)} does not support swappable cache"

        # Note: To reuse the cache management procedure,
        # use cpu cache as 'gpu cache'.
        num_cpu_blocks = num_gpu_blocks

        self._validate_num_cpu_blocks(num_cpu_blocks)
        self.cache_config.num_gpu_blocks = num_cpu_blocks
        self.cache_config.num_cpu_blocks = 0

        # Initialize the cache.
        self._init_cache_engine()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate to the model runner's ``add_lora``."""
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate to the model runner's ``remove_lora``."""
        return self.model_runner.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        """Delegate to the model runner's ``pin_lora``."""
        return self.model_runner.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Delegate to the model runner's ``list_loras``."""
        return self.model_runner.list_loras()

    def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
        """Raise errors if the num_cpu_blocks is invalid.

        Raises:
            ValueError: if there are no blocks at all, or if the blocks
                cannot hold ``max_model_len`` tokens.
        """
        if num_cpu_blocks <= 0:
            raise ValueError("No available memory for the cache blocks. "
                             "Try increasing `VLLM_CPU_KVCACHE_SPACE` when "
                             "initializing the engine.")

        max_seq_len = self.cache_config.block_size * num_cpu_blocks
        if self.model_config.max_model_len > max_seq_len:
            raise ValueError(
                f"The model's max seq len ({self.model_config.max_model_len}) "
                "is larger than the maximum number of tokens that can be "
                f"stored in KV cache ({max_seq_len}). Try increasing "
                "`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when "
                "initializing the engine.")

    def _init_cache_engine(self) -> None:
        """Create one cache engine per pipeline stage, bind the caches to
        the forward context and zero-fill them."""
        self.cache_engine = [
            CPUCacheEngine(self.cache_config, self.model_config,
                           self.parallel_config, self.device_config)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.cpu_cache = [
            self.cache_engine[ve].cpu_cache
            for ve in range(self.parallel_config.pipeline_parallel_size)
        ]
        bind_kv_cache(self.compilation_config.static_forward_context,
                      self.cpu_cache)
        self.model_runner.block_size = self.cache_engine[0].block_size

        assert all(
            self.cpu_cache[ve] is not None
            for ve in range(self.parallel_config.pipeline_parallel_size))

        # Populate the cache to warmup the memory
        for ve in range(self.parallel_config.pipeline_parallel_size):
            for layer_cache in self.cpu_cache[ve]:
                layer_cache.fill_(0)

    @property
    def do_metadata_broadcast(self) -> bool:
        """True when the tensor parallel size is greater than 1."""
        return self.parallel_config.tensor_parallel_size > 1

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """The CPU KV caches, or ``None`` before they are initialised."""
        return self.cpu_cache

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the model runner."""
        return self.model_runner.vocab_size

    @property
    def max_model_len(self) -> int:
        """Maximum model length taken from the model config."""
        return self.model_config.max_model_len

    def execute_worker(
        self,
        worker_input: WorkerInput,
    ) -> None:
        """Perform the pending block copies, if there are any."""
        if (worker_input.blocks_to_copy is not None
                and worker_input.blocks_to_copy.numel() > 0):
            self.cache_engine[worker_input.virtual_engine].copy(
                worker_input.blocks_to_copy)

    @torch.inference_mode()
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """Build the worker input for a request.

        The request must not ask for any swap-in or swap-out (asserted).
        """
        assert execute_model_req is not None
        virtual_engine: int = execute_model_req.virtual_engine
        num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
        # Reshaped into rows of two columns, one row per block copy.
        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                      device="cpu",
                                      dtype=torch.int64).view(-1, 2)
        assert len(execute_model_req.blocks_to_swap_in) == 0
        assert len(execute_model_req.blocks_to_swap_out) == 0
        return WorkerInput(
            num_seq_groups=num_seq_groups,
            blocks_to_copy=blocks_to_copy,
            virtual_engine=virtual_engine,
        )

    def init_distributed_environment(self) -> None:
        """Initialize the distributed environment."""

        parallel_config = self.parallel_config
        rank = self.rank
        distributed_init_method = self.distributed_init_method
        init_distributed_environment(
            world_size=parallel_config.world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            backend="gloo",
        )

        # A small all_reduce for warmup.
        torch.distributed.all_reduce(torch.zeros(1).cpu())

        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)

    def get_cache_block_size_bytes(self) -> int:
        """Return the size in bytes of a single KV cache block.
        """
        return CPUCacheEngine.get_cache_block_size(self.cache_config,
                                                   self.model_config,
                                                   self.parallel_config)

    def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
        """Return CPUs id binding based on NUMA nodes.

        Returns:
            A comma-separated list of CPU ids for this worker's rank.  When
            the ``numa`` or ``psutil`` package is missing, or there are more
            workers than usable NUMA nodes, the current value of
            ``self.local_omp_cpuid`` is returned unchanged.
        """
        rank_to_cpus = self.local_omp_cpuid
        # Setup OpenMP thread affinity based on NUMA nodes automatically
        world_size = self.vllm_config.parallel_config.world_size
        libnuma_found = util.find_spec("numa") is not None
        psutil_found = util.find_spec("psutil") is not None
        if libnuma_found and psutil_found:
            import psutil
            from numa import info
            cpu_count = psutil.cpu_count(logical=False)
            cpus_allow_list = psutil.Process().cpu_affinity()
            numa_size = info.get_num_configured_nodes()
            cpu_count_per_numa = cpu_count // numa_size
            num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
                                      cpu_count_per_numa // 2)

            # check allow node_to_cpus list
            node_to_cpus = []
            for i in range(numa_size):
                node_intersect = set(
                    info.node_to_cpus(i)).intersection(cpus_allow_list)
                if bool(node_intersect):
                    # Sorted so that the slice below keeps a deterministic
                    # subset instead of relying on set iteration order.
                    node_to_cpus.append(sorted(node_intersect))

            if world_size > len(node_to_cpus):
                logger.error(
                    "Auto thread-binding failed due to "
                    "world size: %d is larger than "
                    "allowed NUMA nodes number: %d. "
                    "Please try to bind threads manually.", world_size,
                    len(node_to_cpus))
            else:
                end = cpu_count_per_numa - num_of_reserved_cpu
                rank_to_cpus_list = node_to_cpus[self.rank][:end]
                rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
                logger.info("auto thread-binding list: %s", rank_to_cpus)
        else:
            logger.warning(
                "Auto thread-binding is not supported due to "
                "the lack of package numa and psutil, "
                "fallback to no thread-binding. To get better performance, "
                "please try to manually bind threads.")
        return rank_to_cpus
vllm/worker/multi_step_tpu_worker.py
deleted
100644 → 0
View file @
3b2aefb1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
from
typing
import
Dict
,
Optional
,
Tuple
import
torch
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.worker.tpu_model_runner
import
ModelInputForTPU
from
vllm.worker.tpu_worker
import
TPUWorker
from
vllm.worker.worker_base
import
WorkerInput
class MultiStepTPUWorker(TPUWorker):
    """TPU worker that caches the model input of a multi-step run's first
    step and reuses it on the following steps."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Filled in on the first step of a multi-step run.
        self.cached_model_input: Optional[ModelInputForTPU] = None

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[ModelInputForTPU, WorkerInput, Dict[str, torch.Tensor]]:
        """Build the driver's inputs and broadcast them when required.

        On the first step the inputs are prepared from the request; on later
        steps the cached model input is reused with an empty worker input.
        Must only run on the driver worker with virtual engine 0 (asserted).
        """
        assert self.is_driver_worker
        assert execute_model_req.virtual_engine == 0

        first_step = execute_model_req.is_first_multi_step
        last_step = execute_model_req.is_last_step

        if not first_step:
            assert self.cached_model_input is not None
            model_input = self.cached_model_input
            worker_input = WorkerInput()
        else:
            worker_input: WorkerInput = dataclasses.replace(
                self.prepare_worker_input(
                    execute_model_req=execute_model_req),
                num_steps=execute_model_req.num_lookahead_slots + 1)
            model_input: ModelInputForTPU = (
                self.model_runner.prepare_model_input(
                    execute_model_req.seq_group_metadata_list,
                    execute_model_req.virtual_engine,
                    execute_model_req.finished_requests_ids))

            callback = execute_model_req.async_callback
            if callback:
                model_input = dataclasses.replace(model_input,
                                                  async_callback=callback)

        model_input = dataclasses.replace(model_input,
                                          is_first_multi_step=first_step,
                                          is_last_step=last_step)

        if self.do_metadata_broadcast:
            if first_step:
                # Full payload: worker input plus model input.
                payload = worker_input.as_broadcastable_tensor_dict()
                payload.update(model_input.as_broadcastable_tensor_dict())
            else:
                # Later steps only announce the two step flags.
                payload = {
                    "is_first_multi_step": first_step,
                    "is_last_step": last_step,
                }
            broadcast_tensor_dict(payload, src=0)

        # Returning empty dict here to keep this compatible with
        # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
        return model_input, worker_input, {}

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[Tuple[ModelInputForTPU, WorkerInput, Dict[str,
                                                            torch.Tensor]]]:
        """Return the inputs for the next step, or ``None`` when there is
        nothing to execute.

        The driver builds (and possibly broadcasts) the inputs; other
        workers receive them through ``broadcast_tensor_dict``.
        """
        if not self.is_driver_worker:
            received = broadcast_tensor_dict(src=0)
            if not received:
                return None

            # A two-entry payload carries only the step flags, so the
            # cached model input is updated and reused.
            if len(received) == 2:
                assert self.cached_model_input is not None
                self.cached_model_input = dataclasses.replace(
                    self.cached_model_input,
                    is_first_multi_step=received["is_first_multi_step"],
                    is_last_step=received["is_last_step"])
                empty_worker_input = WorkerInput()
                return self.cached_model_input, empty_worker_input, {}

            worker_input = WorkerInput.from_broadcasted_tensor_dict(received)
            model_input = (
                self.model_runner.
                make_model_input_from_broadcasted_tensor_dict(received))
            self.cached_model_input = model_input
            return model_input, worker_input, {}

        if execute_model_req is None:
            # An empty broadcast is sent when there is no request.
            if self.do_metadata_broadcast:
                broadcast_tensor_dict({}, src=0)
            return None

        model_input, worker_input, _ = self._get_driver_input_and_broadcast(
            execute_model_req)
        if model_input.is_first_multi_step:
            self.cached_model_input = model_input
        return model_input, worker_input, {}
vllm/worker/tpu_model_runner.py
deleted
100644 → 0
View file @
3b2aefb1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
enum
import
time
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
)
from
unittest.mock
import
patch
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch_xla.core.xla_model
as
xm
import
torch_xla.runtime
as
xr
from
vllm.attention
import
AttentionMetadata
,
get_attn_backend
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
get_forward_context
,
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
IntermediateTensors
,
Logprob
,
SequenceGroupMetadata
,
SequenceOutput
)
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
_add_attn_metadata_broadcastable_dict
,
_init_attn_metadata_from_tensor_dict
)
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
# Module-level logger, created through vLLM's `init_logger` helper and keyed
# by this module's name.
logger = init_logger(__name__)

# Slot id used for padding.
# Here we utilize the behavior that out-of-bound index is ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
# Switch for top-p sampling.
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False
# NOTE(review): presumably an upper bound on samples per request; confirm
# against the code that uses it.
# FIXME(woosuk): A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES = 128
class ExecutionMode(enum.Enum):
    """Execution modes of the TPU model runner."""

    PREFILL = enum.auto()
    DECODE = enum.auto()
    PREFIX_PREFILL = enum.auto()

    def is_prefill(self) -> bool:
        """Return True for ``PREFILL`` and ``PREFIX_PREFILL``."""
        return (self is ExecutionMode.PREFILL
                or self is ExecutionMode.PREFIX_PREFILL)
@dataclass(frozen=True)
class ModelInputForTPU(ModelRunnerInputBase):
    """Immutable bundle of the inputs for one TPU model execution."""

    token_ids: torch.Tensor
    position_ids: torch.Tensor
    attn_metadata: AttentionMetadata
    input_lens: torch.Tensor
    # NOTE(review): `t`, `p` and `n` look like sampling temperature, top-p
    # and number of samples per request; confirm against the runner.
    t: torch.Tensor
    p: torch.Tensor
    num_samples: int
    n: List[int]
    seq_groups: List[List[int]]
    is_first_multi_step: bool = True
    is_last_step: bool = True
    virtual_engine: int = 0
    async_callback: Optional[Callable] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """Return the fields to broadcast, plus the attention metadata
        entries added by ``_add_attn_metadata_broadcastable_dict``.

        ``async_callback`` is not part of the dict.
        """
        broadcast_fields = (
            "token_ids",
            "position_ids",
            "input_lens",
            "t",
            "p",
            "num_samples",
            "n",
            "seq_groups",
            "is_first_multi_step",
            "is_last_step",
            "virtual_engine",
        )
        tensor_dict = {
            field_name: getattr(self, field_name)
            for field_name in broadcast_fields
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["ModelInputForTPU"],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForTPU":
        """Build an instance from a broadcasted dict.

        When ``attn_backend`` is given, the dict is first passed through
        ``_init_attn_metadata_from_tensor_dict``.
        """
        fields = tensor_dict
        if attn_backend is not None:
            fields = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**fields)
class
TPUModelRunner
(
ModelRunnerBase
[
ModelInputForTPU
]):
def __init__(
    self,
    vllm_config: VllmConfig,
    is_driver_worker: bool = False,
):
    """Set up block tables and the attention backend for the TPU runner.

    A warning is logged when the block table would occupy at least
    ``smem_size`` (512 KiB) at four bytes per entry.
    """
    ModelRunnerBase.__init__(self, vllm_config=vllm_config)
    self.is_driver_worker = is_driver_worker

    model_cfg = self.model_config
    cache_cfg = self.cache_config

    self.block_size = cache_cfg.block_size
    self.max_num_blocks_per_seq = model_cfg.max_model_len // self.block_size

    # One row per sequence, one column per block.
    table_shape = (self.scheduler_config.max_num_seqs,
                   self.max_num_blocks_per_seq)
    self.block_tables = np.zeros(table_shape, dtype=np.int32)

    self.attn_backend = get_attn_backend(
        model_cfg.get_head_size(),
        model_cfg.dtype,
        cache_cfg.cache_dtype,
        self.block_size,
        model_cfg.is_attention_free,
        False,
    )
    self.cached_step_outputs: List[torch.Tensor] = []

    smem_size = 512 * 1024
    # Four bytes per int32 entry of the block table.
    block_table_size = 4 * self.block_tables.size
    if block_table_size < smem_size:
        return

    logger.warning(
        "The max_model_len (%d) is too large. This may degrade the "
        "performance due to the insufficient smem size. Consider "
        "setting --max-model-len to a smaller value, like %d.",
        self.model_config.max_model_len,
        self.model_config.max_model_len / (block_table_size / smem_size))
def load_model(self) -> None:
    """Load the model, wrap it in ModelWrapper and compile it for XLA."""
    self.device = self.device_config.device

    # NOTE(woosuk): While the executor assigns the TP ranks to the worker
    # process, the ranks can be different from the ranks internally assigned
    # by the xm runtime. Therefore, there is a mismatch in the rank
    # assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
    # This is not a problem in linear layers because all-reduce is
    # rank-agnostic. However, it matters for all-gather as the ranks
    # determine the order of concatenating the output tensors.
    # As a workaround, we use the xm's rank assignment only when loading
    # the embedding weights.
    rank_patch = patch(
        "vllm.model_executor.layers.vocab_parallel_embedding."
        "get_tensor_model_parallel_rank",
        return_value=xr.global_ordinal())
    with rank_patch:
        loaded = get_model(vllm_config=self.vllm_config)
    loaded = loaded.eval()
    # Wait for outstanding device work before wrapping and compiling.
    xm.wait_device_ops()
    wrapped = ModelWrapper(loaded)
    self.model = torch.compile(wrapped,
                               backend="openxla",
                               fullgraph=True,
                               dynamic=False)
def get_model(self) -> nn.Module:
    """Return the `model` attribute of the object stored in `self.model`."""
    wrapper = self.model
    return wrapper.model
def _dummy_run(
    self,
    batch_size: int,
    seq_len: int,
    kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    exec_mode: ExecutionMode,
) -> None:
    """Run the compiled model once on placeholder inputs.

    Args:
        batch_size: Number of rows in the dummy batch.
        seq_len: Tokens per row. Rounded up to a multiple of 16 in the
            prefill modes; must be 1 in decode mode.
        kv_caches: Per-layer (key, value) cache tensors passed to the model.
        exec_mode: Which execution mode the dummy inputs should mimic.
    """
    exec_mode = ExecutionMode(exec_mode)
    prefill = exec_mode.is_prefill()
    device = self.device

    if prefill:
        # Round the sequence length up to the next multiple of 16.
        seq_len = (seq_len + 15) // 16 * 16
    else:
        assert seq_len == 1

    def _zeros(shape, dtype):
        return torch.zeros(shape, dtype=dtype, device=device)

    def _ones(shape, dtype):
        return torch.ones(shape, dtype=dtype, device=device)

    token_ids = _zeros((batch_size, seq_len), torch.int32)
    position_ids = _zeros((batch_size, seq_len), torch.int32)
    slot_mapping = _zeros((batch_size, seq_len), torch.int64)
    input_lens = _ones((batch_size, ), torch.int32)

    # Arguments shared by every mode; the mode-specific ones are added below.
    metadata_kwargs = dict(
        slot_mapping=slot_mapping,
        multi_modal_placeholder_index_maps=None,
        enable_kv_scales_calculation=False,
    )
    if exec_mode == ExecutionMode.PREFILL:
        metadata_kwargs.update(
            num_prefills=batch_size,
            num_prefill_tokens=batch_size * seq_len,
            num_decode_tokens=0,
            block_tables=None,
            context_lens=None,
            effective_query_lens=None,
        )
    elif prefill:
        # Prefix prefill: block tables and context lengths are supplied too.
        context_lens = _ones((batch_size, ), torch.int32)
        block_tables = torch.tensor(self.block_tables[:batch_size],
                                    dtype=torch.int32,
                                    device=device)
        metadata_kwargs.update(
            num_prefills=batch_size,
            num_prefill_tokens=batch_size * seq_len,
            num_decode_tokens=0,
            block_tables=block_tables,
            context_lens=context_lens,
            effective_query_lens=torch.ones_like(context_lens),
        )
    else:
        metadata_kwargs.update(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size * seq_len,
            block_tables=_zeros((batch_size, self.max_num_blocks_per_seq),
                                torch.int32),
            context_lens=_ones((batch_size, ), torch.int32),
        )
    attn_metadata = self.attn_backend.make_metadata(**metadata_kwargs)

    t = _ones((batch_size, ), torch.float32)
    p = _ones((batch_size, ), torch.float32)
    num_samples = _MAX_NUM_SAMPLES if prefill else 1

    # NOTE(woosuk): There are two stages of compilation: torch.compile and
    # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
    # overhead by reusing the FX graph for different shapes.
    # However, the XLA graph will still require static shapes and needs to
    # be re-compiled for every different shapes. This overhead is inevitable
    # in the first run, but can be skipped afterwards as we cache the XLA
    # graphs in the disk (VLLM_XLA_CACHE_PATH).
    if prefill:
        # Prefill: the sequence-length dimension is marked dynamic.
        dynamic_dims = [
            (token_ids, 1),
            (position_ids, 1),
            (attn_metadata.slot_mapping, 1),
        ]
    else:
        # Decode: the batch dimension is marked dynamic.
        dynamic_dims = [
            (token_ids, 0),
            (position_ids, 0),
            (input_lens, 0),
            (attn_metadata.slot_mapping, 0),
            (attn_metadata.context_lens, 0),
            (attn_metadata.block_tables, 0),
        ]
    dynamic_dims += [(t, 0), (p, 0)]
    for tensor, dim in dynamic_dims:
        torch._dynamo.mark_dynamic(tensor, dim)

    # Dummy run.
    with set_forward_context(attn_metadata, self.vllm_config, 0):
        self.model(token_ids, position_ids, input_lens, t, p, num_samples,
                   kv_caches)
def warmup_model(
    self,
    kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> None:
    """Run dummy batches over the input shapes expected at serving time.

    Prefill shapes use batch size 1 with sequence lengths 16, 32, 64, ...
    until either max_model_len or max_num_batched_tokens is reached. The same
    sweep is repeated in prefix-prefill mode when prefix caching is enabled.
    Decode shapes use sequence length 1 with batch sizes 8, 16, 32, 48, ...
    up to max_num_seqs.
    """

    def _run_shape(batch_size: int, seq_len: int, mode) -> None:
        self._dummy_run(batch_size, seq_len, kv_caches, exec_mode=mode)
        xm.wait_device_ops()
        logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)

    def _sweep_prefill_shapes(mode) -> None:
        for batch_size in [1]:
            seq_len = 16
            while seq_len <= self.model_config.max_model_len:
                _run_shape(batch_size, seq_len, mode)
                num_tokens = batch_size * seq_len
                if (num_tokens
                        >= self.scheduler_config.max_num_batched_tokens):
                    break
                seq_len = seq_len * 2

    # Prefill
    logger.info("Compiling the model with different input shapes...")
    started = time.time()
    _sweep_prefill_shapes(ExecutionMode.PREFILL)
    logger.info("Compilation for prefill done in %.2f s.",
                time.time() - started)

    # Prefix prefill
    if self.cache_config.enable_prefix_caching:
        logger.info("Compiling the model with different input shapes for "
                    "prefix prefill...")
        started = time.time()
        _sweep_prefill_shapes(ExecutionMode.PREFIX_PREFILL)
        logger.info("Compilation for prefix prefill done in %.2f s.",
                    time.time() - started)

    # Decode
    started = time.time()
    seq_len = 1
    batch_size = 8  # Must be in sync with _get_padded_batch_size()
    while True:
        _run_shape(batch_size, seq_len, ExecutionMode.DECODE)
        if batch_size >= self.scheduler_config.max_num_seqs:
            break
        # 8 -> 16, then steps of 16.
        if batch_size >= 16:
            batch_size = batch_size + 16
        else:
            batch_size = batch_size * 2
    logger.info("Compilation for decode done in %.2f s.",
                time.time() - started)
def _prepare_prompt(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
    """Build flattened, per-prompt padded CPU inputs for a prefill batch.

    Each sequence group must be a prompt with exactly one sequence. Tokens
    covered by already-computed blocks are dropped from the input; positions
    and slots are generated only for the remaining tokens.

    Returns:
        (input_tokens, input_positions, attn_metadata, prompt_lens), where the
        first two are 1-D tensors of all padded prompts concatenated and
        prompt_lens holds the unpadded number of new tokens per prompt.
    """
    assert len(seq_group_metadata_list) > 0
    block_size = self.block_size

    flat_tokens: List[int] = []
    flat_positions: List[int] = []
    flat_slots: List[int] = []
    new_token_counts: List[int] = []
    cached_context_lens: List[int] = []

    for row, group in enumerate(seq_group_metadata_list):
        assert group.is_prompt
        seq_ids = list(group.seq_data.keys())
        assert len(seq_ids) == 1
        seq_id = seq_ids[0]

        # Could include output tokens when a request is preempted.
        tokens = group.seq_data[seq_id].get_token_ids()
        total_len = len(tokens)

        num_cached = len(group.computed_block_nums) * block_size
        has_cached_prefix = num_cached > 0
        if has_cached_prefix:
            tokens = tokens[num_cached:]
        cached_context_lens.append(total_len if has_cached_prefix else 0)

        num_new = len(tokens)
        new_token_counts.append(num_new)
        flat_tokens.extend(tokens)
        flat_positions.extend(range(num_cached, total_len))

        assert group.block_tables is not None
        block_table = group.block_tables[seq_id]
        # slot = physical block number * block_size + offset inside block.
        for pos in range(num_cached, total_len):
            block_idx, offset = divmod(pos, block_size)
            flat_slots.append(block_table[block_idx] * block_size + offset)
        if has_cached_prefix:
            self.block_tables[row, :len(block_table)] = block_table

        # Each prompt is padded on its own to the length returned by
        # _get_padded_prefill_len().
        # We pad the seq_len to reduce the compilation overhead.
        # We execute each prompt individually (i.e., with batch_size 1)
        # because the FlashAttention kernel does not support ragged inputs.
        # TODO(woosuk): Use SplashAttention to support ragged inputs.
        pad = _get_padded_prefill_len(num_new) - num_new
        flat_tokens += [0] * pad
        flat_positions += [0] * pad
        flat_slots += [_PAD_SLOT_ID] * pad

    assert len(new_token_counts) > 0
    num_prefills = len(new_token_counts)

    def _cpu(values, dtype):
        return torch.tensor(values, dtype=dtype, device="cpu")

    input_tokens = _cpu(flat_tokens, torch.int32)
    input_positions = _cpu(flat_positions, torch.int32)
    slot_mapping = _cpu(flat_slots, torch.int64)
    prompt_lens = _cpu(new_token_counts, torch.int32)
    context_lens = _cpu(cached_context_lens, torch.int32)
    block_tables = _cpu(self.block_tables[:num_prefills], torch.int32)

    attn_metadata = self.attn_backend.make_metadata(
        num_prefills=num_prefills,
        num_prefill_tokens=0,  # NOTE: This is not used.
        num_decode_tokens=0,
        slot_mapping=slot_mapping,
        multi_modal_placeholder_index_maps=None,
        enable_kv_scales_calculation=False,
        block_tables=block_tables,
        context_lens=context_lens,
        effective_query_lens=prompt_lens,
    )
    return input_tokens, input_positions, attn_metadata, prompt_lens
def _prepare_decode(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
    """Build padded CPU inputs for a decode batch (one token per sequence).

    Every sequence of every group contributes one row. Rows are then padded
    up to the size returned by _get_padded_batch_size().

    Returns:
        (input_tokens, input_positions, attn_metadata, input_lens) where the
        first two have shape [padded_batch, 1] and input_lens is all ones.
    """
    assert len(seq_group_metadata_list) > 0
    block_size = self.block_size

    token_rows: List[List[int]] = []
    position_rows: List[List[int]] = []
    slot_rows: List[List[int]] = []
    seq_lens: List[int] = []

    num_seqs = 0
    for group in seq_group_metadata_list:
        assert not group.is_prompt
        for seq_id in list(group.seq_data.keys()):
            seq_data = group.seq_data[seq_id]
            token_rows.append([seq_data.get_last_token_id()])

            seq_len = seq_data.get_len()
            last_pos = seq_len - 1
            position_rows.append([last_pos])
            seq_lens.append(seq_len)

            assert group.block_tables is not None
            block_table = group.block_tables[seq_id]
            self.block_tables[num_seqs, :len(block_table)] = block_table
            num_seqs += 1

            # Slot of the token being decoded.
            block_idx, offset = divmod(last_pos, block_size)
            slot_rows.append([block_table[block_idx] * block_size + offset])

    batch_size = _get_padded_batch_size(num_seqs)
    pad = batch_size - num_seqs
    token_rows = token_rows + [[0]] * pad
    position_rows = position_rows + [[0]] * pad
    slot_rows = slot_rows + [[_PAD_SLOT_ID]] * pad
    seq_lens = seq_lens + [0] * pad

    def _cpu(values, dtype):
        return torch.tensor(values, dtype=dtype, device="cpu")

    input_tokens = _cpu(token_rows, torch.int32)
    input_positions = _cpu(position_rows, torch.int32)
    slot_mapping = _cpu(slot_rows, torch.int64)
    context_lens = _cpu(seq_lens, torch.int32)
    block_tables = _cpu(self.block_tables[:batch_size], torch.int32)
    input_lens = _cpu([1] * batch_size, torch.int32)

    attn_metadata = self.attn_backend.make_metadata(
        num_prefills=0,
        num_prefill_tokens=0,
        num_decode_tokens=batch_size,
        slot_mapping=slot_mapping,
        multi_modal_placeholder_index_maps=None,
        enable_kv_scales_calculation=False,
        block_tables=block_tables,
        context_lens=context_lens,
    )
    return input_tokens, input_positions, attn_metadata, input_lens
def _prepare_sample(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    padded_batch_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
    """Collect per-sequence sampling parameters.

    Returns:
        (t, p, n): temperature and top-p as float32 CPU tensors padded with
        1.0 up to `padded_batch_size`, and the list of `n` values (one entry
        per sequence, not padded).

    Raises:
        NotImplementedError: for top-p != 1 while top-p is disabled, top-k > 0,
            n > _MAX_NUM_SAMPLES, or any logprobs / prompt_logprobs request.
    """
    assert len(seq_group_metadata_list) > 0
    temperatures = []
    top_ps = []
    sample_counts = []

    for group in seq_group_metadata_list:
        params = group.sampling_params

        temperatures.append(params.temperature)
        if params.top_p != 1 and not _ENABLE_TOP_P:
            raise NotImplementedError(
                "Top-p sampling is currently disabled for the TPU backend "
                "due to performance issues.")
        top_ps.append(params.top_p)
        if params.top_k > 0:
            raise NotImplementedError(
                "Top-k sampling is currently disabled for the TPU backend "
                "due to performance issues.")
        if params.n > _MAX_NUM_SAMPLES:
            raise NotImplementedError(
                f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
                "backend.")
        sample_counts.append(params.n)
        if params.logprobs is not None:
            raise NotImplementedError(
                "logprobs is not currently supported by the TPU backend.")
        if params.prompt_logprobs is not None:
            raise NotImplementedError(
                "prompt_logprobs is not currently supported by the TPU "
                "backend.")

        # Repeat the sampling params if the seq group has multiple seqs.
        extra = len(group.seq_data) - 1
        for values in (temperatures, top_ps, sample_counts):
            values += [values[-1]] * extra

    pad = padded_batch_size - len(temperatures)
    temperatures += [1.0] * pad
    top_ps += [1.0] * pad

    t = torch.tensor(temperatures, dtype=torch.float32, device="cpu")
    p = torch.tensor(top_ps, dtype=torch.float32, device="cpu")
    return t, p, sample_counts
def prepare_model_input(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    virtual_engine: int = 0,
    finished_requests_ids: Optional[List[str]] = None,
) -> ModelInputForTPU:
    """Assemble a ModelInputForTPU for a batch of sequence groups.

    The batch is treated as all-prompt or all-decode based on its first
    element. `virtual_engine` must be 0; `finished_requests_ids` is ignored.
    """
    del finished_requests_ids  # Unused.
    assert virtual_engine == 0
    assert len(seq_group_metadata_list) > 0

    # NOTE: We assume that all sequences in the group are all prompts or
    # all decodes.
    is_prompt = seq_group_metadata_list[0].is_prompt
    prepare = self._prepare_prompt if is_prompt else self._prepare_decode
    (input_tokens, input_positions, attn_metadata,
     input_lens) = prepare(seq_group_metadata_list)

    # Sampling parameters are padded to the leading dimension of the tokens.
    t, p, n = self._prepare_sample(seq_group_metadata_list,
                                   input_tokens.shape[0])

    seq_groups = []
    for metadata in seq_group_metadata_list:
        seq_groups.append(list(metadata.seq_data.keys()))

    return ModelInputForTPU(
        token_ids=input_tokens,
        position_ids=input_positions,
        attn_metadata=attn_metadata,
        input_lens=input_lens,
        t=t,
        p=p,
        num_samples=_MAX_NUM_SAMPLES if is_prompt else 1,
        n=n,
        seq_groups=seq_groups,
    )
def make_model_input_from_broadcasted_tensor_dict(
        self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
    """Rebuild a ModelInputForTPU from a broadcast dict using this runner's
    attention backend."""
    return ModelInputForTPU.from_broadcasted_tensor_dict(
        tensor_dict, attn_backend=self.attn_backend)
@torch.no_grad()
def execute_model(
    self,
    model_input: ModelInputForTPU,
    kv_caches: Optional[List[Any]],
    intermediate_tensors: Optional[IntermediateTensors] = None,
    num_steps: int = 1,
) -> List[SamplerOutput]:
    """Run the model for one scheduling step and build sampler outputs.

    Three paths exist:
      * Not the first multi-step call: nothing is executed. On the last step
        the token tensors cached by the first call are converted to outputs;
        otherwise an empty list is returned.
      * Prefill: every prompt is run on its own (batch size 1).
      * Decode: `num_steps` forward passes are run back to back; with more
        than one step the outputs stay cached and an empty list is returned.

    `intermediate_tensors` must be None.
    """
    assert intermediate_tensors is None

    # ---- Follow-up call of a multi-step schedule -------------------------
    if not model_input.is_first_multi_step:
        if not model_input.is_last_step:
            return []

        use_async_out_proc = model_input.async_callback is not None
        drained = []
        pending = len(self.cached_step_outputs)
        for step in range(pending):
            step_tokens = self.cached_step_outputs.pop(0).cpu().tolist()
            step_output = _make_decode_output(step_tokens,
                                              model_input.seq_groups)
            drained.append(step_output)
            # All but the last output are handed to the async processor here.
            if step < pending - 1 and use_async_out_proc:
                assert model_input.async_callback is not None
                ctx = model_input.async_callback.keywords[  # type: ignore
                    "ctx"]
                ctx.append_output(
                    outputs=[step_output],
                    seq_group_metadata_list=ctx.seq_group_metadata_list,
                    scheduler_outputs=ctx.scheduler_outputs,
                    is_async=False,
                    is_last_step=False,
                    is_first_step_output=step == 0)
                model_input.async_callback()
        if use_async_out_proc:
            return [drained[-1]]
        return drained

    device = self.device
    attn_metadata = model_input.attn_metadata

    # ---- Prefill ---------------------------------------------------------
    if attn_metadata.num_prefills > 0:
        assert num_steps == 1
        # NOTE(woosuk): Since the FlashAttention kernel does not support
        # ragged inputs, we split the prompts into different batches and
        # process them separately. This is a temporary hack that should be
        # optimized by using SplashAttention.
        all_slot_mapping = attn_metadata.slot_mapping
        all_block_tables = attn_metadata.block_tables
        all_context_lens = attn_metadata.context_lens
        all_query_lens = attn_metadata.effective_query_lens

        num_prompts = model_input.input_lens.shape[0]
        begin = 0
        per_prompt_tokens = []
        for i in range(num_prompts):
            row = slice(i, i + 1)
            # Recover the padded length from the actual prefill length.
            actual_len = model_input.input_lens[row].item()
            end = begin + _get_padded_prefill_len(actual_len)
            span = slice(begin, end)

            token_ids = model_input.token_ids[None, span].to(device)
            position_ids = model_input.position_ids[None, span].to(device)

            # The shared metadata object is rewritten for this one prompt.
            attn_metadata.num_prefills = 1
            attn_metadata.slot_mapping = all_slot_mapping[None,
                                                          span].to(device)
            if all_context_lens[i].item() > 0:
                attn_metadata.context_lens = all_context_lens[row].to(device)
                attn_metadata.block_tables = all_block_tables[i].unsqueeze(
                    0).to(device)
                attn_metadata.effective_query_lens = all_query_lens[row].to(
                    device)
            else:
                attn_metadata.context_lens = None
                attn_metadata.block_tables = None
                attn_metadata.effective_query_lens = None

            input_lens = model_input.input_lens[row].to(device)
            t = model_input.t[row].to(device)
            p = model_input.p[row].to(device)
            with set_forward_context(attn_metadata, self.vllm_config,
                                     model_input.virtual_engine):
                output_token_ids = self.model(token_ids, position_ids,
                                              input_lens, t, p,
                                              model_input.num_samples,
                                              kv_caches)
            per_prompt_tokens.append(output_token_ids[0])
            begin = end

        if model_input.async_callback is not None:
            model_input.async_callback()
        # Retrieve the outputs to CPU.
        per_prompt_tokens = [
            tokens.cpu().tolist() for tokens in per_prompt_tokens
        ]

        # NOTE(woosuk): Minimal code to construct the sampler outputs.
        # The TPU backend does not reuse the sampler, since the TPU backend
        # does not support advanced sampling parameters such as logprobs.
        zero_logprob = Logprob(0.0)
        group_outputs = []
        for i, seq_ids in enumerate(model_input.seq_groups):
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]
            samples = []
            for j in range(model_input.n[i]):
                token_id = per_prompt_tokens[i][j]
                samples.append(
                    SequenceOutput(seq_id, token_id,
                                   {token_id: zero_logprob}))
            group_outputs.append(CompletionSequenceGroupOutput(samples, None))
        return [SamplerOutput(group_outputs)]

    # ---- Decode ----------------------------------------------------------
    token_ids = model_input.token_ids.to(device)
    position_ids = model_input.position_ids.to(device)
    attn_metadata.slot_mapping = attn_metadata.slot_mapping.to(device)
    attn_metadata.block_tables = attn_metadata.block_tables.to(device)
    attn_metadata.context_lens = attn_metadata.context_lens.to(device)
    t = model_input.t.to(device)
    p = model_input.p.to(device)
    input_lens = model_input.input_lens.to(device)

    last_step = num_steps - 1
    for step in range(num_steps):
        # Captured before the forward pass; used below to find padding rows.
        slot_mapping = attn_metadata.slot_mapping
        with set_forward_context(attn_metadata, self.vllm_config,
                                 model_input.virtual_engine):
            output_token_ids = self.model(token_ids, position_ids,
                                          input_lens, t, p,
                                          model_input.num_samples, kv_caches)
        self.cached_step_outputs.append(output_token_ids)

        if step < last_step:
            # Prepare the inputs for the next step.
            token_ids = output_token_ids.unsqueeze(dim=1).int()
            position_ids = position_ids + 1
            attn_metadata.context_lens = attn_metadata.context_lens + 1

            block_number = attn_metadata.block_tables.gather(
                1,
                position_ids.long() // self.block_size)
            block_offset = position_ids % self.block_size

            is_padding = slot_mapping == _PAD_SLOT_ID
            next_slots = block_number * self.block_size + block_offset
            next_slots = next_slots.long()
            # Padding rows keep the pad slot id.
            attn_metadata.slot_mapping = torch.where(is_padding,
                                                     _PAD_SLOT_ID,
                                                     next_slots)

    if model_input.async_callback is not None:
        model_input.async_callback()

    if num_steps > 1:
        return []
    # Retrieve the outputs to CPU.
    step_tokens = self.cached_step_outputs.pop(0).cpu().tolist()
    return [_make_decode_output(step_tokens, model_input.seq_groups)]
class ModelWrapper(nn.Module):
    """Wraps a model so that one forward call also samples the next tokens."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(
        self,
        token_ids: torch.Tensor,
        position_ids: torch.Tensor,
        input_lens: torch.Tensor,
        t: torch.Tensor,
        p: torch.Tensor,
        num_samples: int,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> torch.Tensor:
        """Executes the forward pass of the model and samples the next token.

        Args:
            token_ids: The input token IDs of shape [batch_size, seq_len].
            position_ids: The input position IDs of shape [batch_size, seq_len].
            input_lens: The actual input lengths of shape [batch_size].
            t: The sampling temperature of shape [batch_size].
            p: The top-p probability of shape [batch_size].
            num_samples: Number of samples to draw from each logits vector.
            kv_caches: The key and value caches. They can be None during the
                memory profiling at initialization.

        Returns:
            The chosen token ids: the argmax where the temperature is zero,
            a multinomial sample otherwise. The trailing sample dimension is
            squeezed away when `num_samples` is 1.
        """
        batch_size, seq_len = token_ids.shape

        # Calculate the positions to sample from: the last real token of each
        # row in the flattened [batch_size * seq_len] layout.
        row_starts = torch.arange(
            batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
        logits_indices = row_starts + input_lens - 1

        attn_metadata = get_forward_context().attn_metadata

        # FIXME(woosuk): This is a temporary hack to avoid using the existing
        # sampler and sampling metadata.
        sampling_metadata = SamplingMetadata(
            seq_groups=[],
            selected_token_indices=logits_indices,
            categorized_sample_indices={},
            num_prompts=attn_metadata.num_prefills,
        )

        # Skip this in memory profiling at initialization.
        first_key_cache = kv_caches[0][0]
        if first_key_cache.numel() > 0:
            # index_copy_(slot_mapping) only works when the inserted dimension
            # is 0. However, the KV cache in the Pallas backend has the shape
            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
            # work, we need to flatten the first three dimensions and modify
            # the slot_mapping accordingly.
            num_kv_heads, num_blocks, block_size, _ = first_key_cache.shape
            flat_slots = attn_metadata.slot_mapping.flatten()
            head_offsets = torch.arange(0,
                                        num_kv_heads,
                                        device=flat_slots.device,
                                        dtype=flat_slots.dtype)
            head_offsets *= block_size * num_blocks
            per_head_slots = flat_slots.repeat_interleave(num_kv_heads).view(
                -1, num_kv_heads)
            per_head_slots = per_head_slots + head_offsets.view(1, -1)
            attn_metadata.slot_mapping = per_head_slots.flatten()

        hidden_states = self.model(token_ids, position_ids)
        hidden_states = hidden_states.flatten(0, 1)
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Argmax sampling.
        greedy_ids = torch.argmax(logits, dim=-1, keepdim=True)
        greedy_ids = greedy_ids.repeat(1, num_samples)

        # Zero temperature means greedy decoding. Avoid division by zero.
        safe_t = torch.where(t != 0, t, 1.0)
        logits = logits / safe_t.unsqueeze(dim=1)
        if _ENABLE_TOP_P:
            logits = _apply_top_p(logits, p.unsqueeze(dim=1))

        # Random sampling.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        sampled_ids = torch.multinomial(probs, num_samples, replacement=True)
        if num_samples == 1:
            greedy_ids = greedy_ids.squeeze(dim=-1)
            sampled_ids = sampled_ids.squeeze(dim=-1)
        return torch.where(t != 0, sampled_ids, greedy_ids)
def _get_padded_prefill_len(x: int) -> int:
    """Return the padded prefill length for a prompt of `x` tokens.

    The result is 16 for any `x` up to 16, otherwise the smallest power of
    two that is greater than or equal to `x`.
    """
    # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
    # length to be a multiple of 16. Every value returned here (16 or a larger
    # power of two) satisfies that. This is also good for performance.
    if x > 16:
        # (x - 1).bit_length() is the exponent of the next power of two >= x.
        return 1 << (x - 1).bit_length()
    return 16
def _get_padded_batch_size(batch_size: int) -> int:
    """Return the padded decode batch size.

    The result is 8 for any `batch_size` up to 8, otherwise `batch_size`
    rounded up to the next multiple of 16.
    """
    # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
    # To meet this requirement in the simplest way, we set the minimal batch
    # size to 8.
    if batch_size > 8:
        num_chunks = (batch_size + 15) // 16
        return num_chunks * 16
    return 8
def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Mask, in place, the logits that fall outside the top-p nucleus.

    Args:
        logits: Logits whose last dimension is the vocabulary. Modified in
            place.
        p: Top-p thresholds, broadcastable against the cumulative
            probabilities over the last dimension of `logits`.

    Returns:
        The same `logits` tensor, with every entry strictly below the cutoff
        logit replaced by -inf.
    """
    logits_sorted = torch.sort(logits, dim=-1, descending=True).values
    sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
    # Number of leading sorted entries whose cumulative mass is still below p;
    # this is the index of the entry that crosses the threshold.
    cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
    # When every cumulative probability is below p (e.g. p == 1.0 and the
    # total rounds to slightly less than 1), the count equals the vocabulary
    # size and would index one past the end. Clamp to the last valid position
    # so that nothing is masked in that case.
    cutoff_index = cutoff_index.clamp(max=logits.shape[-1] - 1)
    cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
    logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
    return logits
def _make_decode_output(
    next_token_ids: List[int],
    seq_groups: List[List[int]],
) -> SamplerOutput:
    """Pair flat decode-step token ids with their sequences.

    `next_token_ids` is consumed in order: the k-th sequence id encountered
    while walking `seq_groups` receives `next_token_ids[k]`. Entries beyond
    the total number of sequences are ignored. Every sample is given a
    placeholder logprob of 0.0.
    """
    zero_logprob = Logprob(0.0)
    group_outputs = []
    offset = 0
    for seq_ids in seq_groups:
        samples = []
        for rank, seq_id in enumerate(seq_ids):
            token_id = next_token_ids[offset + rank]
            samples.append(
                SequenceOutput(seq_id, token_id, {token_id: zero_logprob}))
        offset += len(samples)
        group_outputs.append(CompletionSequenceGroupOutput(samples, None))
    return SamplerOutput(group_outputs)
vllm/worker/tpu_worker.py
deleted
100644 → 0
View file @
3b2aefb1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch_xla.core.xla_model
as
xm
import
torch_xla.debug.profiler
as
xp
import
torch_xla.runtime
as
xr
import
vllm.envs
as
envs
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
bind_kv_cache
,
get_dtype_size
from
vllm.worker.tpu_model_runner
import
ExecutionMode
,
TPUModelRunner
from
vllm.worker.worker_base
import
(
LocalOrDistributedWorkerBase
,
LoRANotSupportedWorkerBase
,
WorkerBase
,
WorkerInput
)
logger
=
init_logger
(
__name__
)
class
TPUWorker
(
LoRANotSupportedWorkerBase
,
LocalOrDistributedWorkerBase
):
def __init__(
    self,
    vllm_config: VllmConfig,
    local_rank: int,
    rank: int,
    distributed_init_method: str,
    is_driver_worker: bool,
) -> None:
    """Store rank information, resolve the cache dtype and build the runner.

    Raises:
        NotImplementedError: if `vllm_config.lora_config` is set.
    """
    WorkerBase.__init__(self, vllm_config=vllm_config)
    self.parallel_config.rank = rank
    self.local_rank = local_rank
    self.rank = rank
    self.distributed_init_method = distributed_init_method
    self.is_driver_worker = is_driver_worker

    assert self.device_config.device_type == "tpu"

    # "auto" means: use the model's dtype for the KV cache.
    requested_dtype = self.cache_config.cache_dtype
    if requested_dtype != "auto":
        self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[requested_dtype]
    else:
        self.cache_dtype = self.model_config.dtype

    self.model_runner: TPUModelRunner = TPUModelRunner(
        vllm_config=vllm_config, is_driver_worker=is_driver_worker)

    # Fall back to seed 0 when none was configured.
    if self.model_config.seed is None:
        self.model_config.seed = 0

    if vllm_config.lora_config is not None:
        raise NotImplementedError(
            "The V0 TPU backend doesn't support LoRA serving")
def init_device(self) -> None:
    """Initialize the distributed runtime, the XLA device, RNG state, the
    compilation cache and (on rank 0 only) the profiler server."""
    os.environ["PJRT_DEVICE"] = "TPU"
    torch.set_grad_enabled(False)
    torch.set_default_dtype(self.model_config.dtype)

    parallel_config = self.parallel_config

    # NOTE(woosuk): This is just to initialize the TP group and broadcast
    # the input objects on CPU. The all-reduce and all-gather ops on TPU
    # are invoked by `xm.all_reduce` and `xm.all_gather` which use their
    # own context.
    init_distributed_environment(
        world_size=parallel_config.world_size,
        rank=self.rank,
        local_rank=self.local_rank,
        distributed_init_method=self.distributed_init_method,
        backend="gloo",
    )
    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size)

    # Device initialization should happen after initializing the distributed
    # runtime.
    self.device = xm.xla_device()
    self.device_config.device = self.device

    # Set random seed.
    seed = self.model_config.seed
    set_random_seed(seed)
    xm.set_rng_state(seed, self.device)

    # Increase the cache size limit, which is the maximum number of
    # dynamo graphs that can be compiled.
    # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
    # 30-40 graphs for decode. 128 is an arbitrary safe number.
    torch._dynamo.config.cache_size_limit = 128

    # Use persistent cache to avoid XLA recompilation.
    # NOTE(woosuk): Set per-rank cache path since different ranks
    # can have slightly different XLA graphs.
    world_size = parallel_config.world_size
    rank = xr.global_ordinal()
    # The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
    # Consequently, changes in optimization flags, which affect compilation
    # results, don't change the cache key. This can result in the wrong
    # compilation being used. To prevent this, disabling the XLA compilation
    # cache during development is recommended. We can disable it by
    # `export VLLM_XLA_CACHE_PATH=`
    cache_root = envs.VLLM_XLA_CACHE_PATH
    if cache_root:
        per_rank_path = os.path.join(cache_root,
                                     f"tp{world_size}_rank{rank}")
        xr.initialize_cache(per_rank_path, readonly=False)

    self.profiler = None
    if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
        # For TPU, we can only have 1 active profiler session for 1 profiler
        # server. So we only profile on rank0.
        self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
        logger.info("Profiling enabled. Traces will be saved to: %s",
                    self.profile_dir)
        self.profiler = xp.start_server(9012)
def start_profile(self):
    """Start a profiler trace into `self.profile_dir`; no-op unless rank < 1.

    Raises:
        RuntimeError: if this is rank 0 and no profiler server was started.
    """
    if not self.rank < 1:
        return
    if self.profiler is None:
        raise RuntimeError("Profiler is not enabled.")
    xp.start_trace(self.profile_dir)
def stop_profile(self):
    """Stop the active profiler trace; no-op unless rank < 1.

    Raises:
        RuntimeError: if this is rank 0 and no profiler server was started.
    """
    if not self.rank < 1:
        return
    if self.profiler is None:
        raise RuntimeError("Profiler is not enabled.")
    xp.stop_trace()
def load_model(self):
    """Delegate model loading to the model runner."""
    runner = self.model_runner
    runner.load_model()
def determine_num_available_blocks(self) -> Tuple[int, int]:
    """Profile one prefill pass and derive the KV cache block counts.

    Returns:
        (num_tpu_blocks, num_cpu_blocks), each rounded down to a multiple
        of 8. The TPU count comes from the device memory limit scaled by
        `gpu_memory_utilization` minus the profiled peak usage; the CPU count
        comes from `swap_space_bytes`.
    """
    model_config = self.model_config
    cache_config = self.cache_config

    num_layers = model_config.get_num_layers(self.parallel_config)
    head_size = model_config.get_head_size()
    num_kv_heads = model_config.get_num_kv_heads(self.parallel_config)

    def _empty_cache() -> torch.Tensor:
        return torch.tensor([], dtype=torch.float32, device=self.device)

    # use an empty tensor instead of `None` to force Dynamo to pass
    # it by reference, rather by specializing on the value `None`.
    # the `dtype` argument does not matter, and we use `float32` as
    # a placeholder (it has wide hardware support).
    kv_caches = [(_empty_cache(), _empty_cache()) for _ in range(num_layers)]
    bind_kv_cache(self.compilation_config.static_forward_context,
                  [kv_caches])
    self.model_runner._dummy_run(
        batch_size=1,
        seq_len=self.scheduler_config.max_num_batched_tokens,
        kv_caches=kv_caches,
        exec_mode=ExecutionMode.PREFILL,
    )
    # Synchronize before measuring the memory usage.
    xm.wait_device_ops()

    # Get the maximum amount of memory used by the model weights and
    # intermediate activations.
    memory_info = xm.get_memory_info(self.device)
    total_memory_size = memory_info["bytes_limit"]
    profiled = memory_info["peak_bytes_used"]  # Weights + activations.

    # Calculate the TPU KV cache size based on profiling.
    usable_memory_size = int(total_memory_size *
                             cache_config.gpu_memory_utilization)
    tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)

    # Bytes of one block across all layers; the factor 2 covers the key and
    # the value cache.
    block_size_bytes = (get_dtype_size(self.cache_dtype) *
                        cache_config.block_size * num_layers * 2 * head_size *
                        num_kv_heads)

    num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
    num_tpu_blocks -= num_tpu_blocks % 8  # Round down to 8.

    # Calculate the CPU KV cache size based on the config.
    num_cpu_blocks = int(cache_config.swap_space_bytes // block_size_bytes)
    num_cpu_blocks -= num_cpu_blocks % 8  # Round down to 8.
    return num_tpu_blocks, num_cpu_blocks
def initialize_cache(
    self,
    num_gpu_blocks: int,
    num_cpu_blocks: int,
) -> None:
    """Allocate zero-filled per-layer KV caches on the device and on CPU.

    Records the block counts on ``self.cache_config``, builds
    ``self.tpu_cache`` and ``self.cpu_cache`` as lists of ``(key, value)``
    tensor pairs (one pair per layer), binds the device cache to the static
    forward context, and finally calls ``self._warmup_model()``.

    Args:
        num_gpu_blocks: Number of blocks for the device-side cache.
        num_cpu_blocks: Number of blocks for the CPU-side cache.
    """
    cache_config = self.cache_config
    cache_config.num_gpu_blocks = num_gpu_blocks
    cache_config.num_cpu_blocks = num_cpu_blocks
    self.block_size = cache_config.block_size

    cache_dtype = self.cache_dtype
    layer_count = self.model_config.get_num_layers(self.parallel_config)
    kv_head_count = self.model_config.get_num_kv_heads(self.parallel_config)
    head_dim = self.model_config.get_head_size()

    self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
    self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []

    # The attention backend decides the tensor layout of a cache.
    device_shape = self.model_runner.attn_backend.get_kv_cache_shape(
        num_gpu_blocks, self.block_size, kv_head_count, head_dim)
    host_shape = self.model_runner.attn_backend.get_kv_cache_shape(
        num_cpu_blocks, self.block_size, kv_head_count, head_dim)

    for _ in range(layer_count):
        device_k = torch.zeros(device_shape,
                               dtype=cache_dtype,
                               device=self.device)
        device_v = torch.zeros_like(device_k)
        self.tpu_cache.append((device_k, device_v))

        host_k = torch.zeros(host_shape, dtype=cache_dtype, device="cpu")
        host_v = torch.zeros_like(host_k)
        self.cpu_cache.append((host_k, host_v))

    bind_kv_cache(self.compilation_config.static_forward_context,
                  [self.tpu_cache])
    self._warmup_model()
def
_warmup_model
(
self
)
->
None
:
# FIXME(woosuk): Here we are abusing `enforce_eager` which is defined
# for CUDA graphs. We should refactor this part.
if
not
self
.
model_config
.
enforce_eager
:
# Warm up the model with all possible input shapes so that
# compilation never happens during the actual execution.
# This may take ~30 mins for the first run and ~20 mins for the
# subsequent runs.
# If `enforce_eager` is True, the ahead-of-time compilation is
# skipped and the compilation happens during the actual execution,
# which is bad for performance but useful for development.
self
.
model_runner
.
warmup_model
(
self
.
tpu_cache
)
def get_cache_block_size_bytes(self) -> int:
    """Return the size in bytes of one cache block summed over all layers.

    A block holds ``block_size * num_kv_heads * head_size`` elements for
    the keys and the same number for the values, in every layer.
    """
    model_config = self.model_config
    head_dim = model_config.get_head_size()
    kv_head_count = model_config.get_num_kv_heads(self.parallel_config)
    layer_count = model_config.get_num_layers(self.parallel_config)

    key_elements = self.cache_config.block_size * kv_head_count * head_dim
    # The value cache has the same element count as the key cache.
    elements_total = layer_count * (key_elements + key_elements)
    return get_dtype_size(self.cache_dtype) * elements_total
@property
def do_metadata_broadcast(self) -> bool:
    """Whether the tensor-parallel size is greater than one."""
    tensor_parallel_size = self.parallel_config.tensor_parallel_size
    return tensor_parallel_size > 1
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
    """Return the device KV cache wrapped in a one-element list."""
    # NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
    # parallelism, so there is exactly one entry.
    engine_zero_cache = self.tpu_cache
    return [engine_zero_cache]
def prepare_worker_input(
    self,
    execute_model_req: ExecuteModelRequest,
) -> WorkerInput:
    """Translate an execute-model request into a ``WorkerInput``.

    Each block mapping of the request is converted by ``_make_src_to_dst``
    with the source and destination devices appropriate for that
    operation: CPU to device for swap-in, device to CPU for swap-out, and
    device to device for copies.
    """
    request = execute_model_req
    device = self.device
    engine_index = request.virtual_engine
    group_count = len(request.seq_group_metadata_list)
    return WorkerInput(
        num_seq_groups=group_count,
        blocks_to_swap_in=_make_src_to_dst(request.blocks_to_swap_in, "cpu",
                                           device),
        blocks_to_swap_out=_make_src_to_dst(request.blocks_to_swap_out,
                                            device, "cpu"),
        blocks_to_copy=_make_src_to_dst(request.blocks_to_copy, device,
                                        device),
        virtual_engine=engine_index,
    )
def execute_worker(self, worker_input: WorkerInput) -> None:
    """Carry out the swap-in, swap-out and copy operations of a step.

    Each operation is skipped when its entry on ``worker_input`` is
    ``None`` or when its source index tensor is empty. Only virtual engine
    0 is accepted.
    """
    assert worker_input.virtual_engine == 0
    backend = self.model_runner.attn_backend
    layer_count = self.model_config.get_num_layers(self.parallel_config)

    # Issue cache operations.
    swap_in = worker_input.blocks_to_swap_in
    if swap_in is not None:
        sources, targets = swap_in
        if sources.numel() > 0:
            # Swap from CPU to TPU.
            for layer in range(layer_count):
                device_k, device_v = self.tpu_cache[layer]
                host_k, host_v = self.cpu_cache[layer]
                moved_k = host_k[:, sources].to(self.device)
                moved_v = host_v[:, sources].to(self.device)
                _insert_kv(moved_k, moved_v, targets, device_k, device_v)

    swap_out = worker_input.blocks_to_swap_out
    if swap_out is not None:
        sources, targets = swap_out
        if sources.numel() > 0:
            # Swap from TPU to CPU.
            for layer in range(layer_count):
                device_k, device_v = self.tpu_cache[layer]
                host_k, host_v = self.cpu_cache[layer]
                host_k[:, targets] = device_k[:, sources]
                host_v[:, targets] = device_v[:, sources]

    copies = worker_input.blocks_to_copy
    if copies is not None:
        sources, targets = copies
        if sources.numel() > 0:
            backend.copy_blocks(self.tpu_cache, (sources, targets))
def
_make_src_to_dst
(
mapping
:
List
[
Tuple
[
int
,
int
]],
src_device
:
Union
[
torch
.
device
,
str
],
dst_device
:
Union
[
torch
.
device
,
str
],
)
->
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
if
not
mapping
:
return
None
src_indices
=
[
i
for
i
,
_
in
mapping
]
dst_indices
=
[
i
for
_
,
i
in
mapping
]
src_indices
=
torch
.
tensor
(
src_indices
,
device
=
src_device
,
dtype
=
torch
.
int64
)
dst_indices
=
torch
.
tensor
(
dst_indices
,
device
=
dst_device
,
dtype
=
torch
.
int64
)
return
src_indices
,
dst_indices
# Compiled with the "openxla" backend.
@torch.compile(backend="openxla")
def _insert_kv(
    k: torch.Tensor,
    v: torch.Tensor,
    indices: torch.Tensor,
    tpu_k_cache: torch.Tensor,
    tpu_v_cache: torch.Tensor,
) -> None:
    """Write ``k`` and ``v`` into the caches at ``indices`` of dimension 1.

    Both cache tensors are modified in place; nothing is returned.
    """
    # Flag both caches through `dynamo_set_buffer_donor_` before writing.
    # NOTE(review): presumably this lets XLA reuse the cache buffers for the
    # updated result instead of allocating new ones -- confirm against the
    # PyTorch/XLA documentation.
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True)
    tpu_k_cache[:, indices] = k
    tpu_v_cache[:, indices] = v
vllm/worker/xpu_model_runner.py
deleted
100644 → 0
View file @
3b2aefb1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
import
time
import
weakref
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
TypeVar
)
import
torch
import
torch.nn
as
nn
from
vllm.attention
import
get_attn_backend
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_pp_group
from
vllm.forward_context
import
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadataCache
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
MultiModalKwargs
,
MultiModalPlaceholderMap
,
MultiModalRegistry
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.utils
import
DeviceMemoryProfiler
,
GiB_bytes
,
make_tensor_with_pad
from
vllm.worker.model_runner
import
AttentionMetadata
,
SamplingMetadata
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
_add_attn_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
,
_init_attn_metadata_from_tensor_dict
,
_init_sampling_metadata_from_tensor_dict
)
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
# Logger for this module, obtained from `init_logger` with the module name.
logger = init_logger(__name__)

# Placeholder slot id written into slot mappings for tokens that have no
# real cache slot: sequences without block tables and prompt tokens masked
# out by the sliding window.
_PAD_SLOT_ID = -1

# Type variable bound to `ModelInputForXPU`; it annotates the `cls` parameter
# and return type of `ModelInputForXPU.from_broadcasted_tensor_dict`.
TModelInputForXPU = TypeVar('TModelInputForXPU', bound="ModelInputForXPU")
@dataclass(frozen=True)
class ModelInputForXPU(ModelRunnerInputBase):
    """Immutable bundle of the inputs for one XPU model forward pass.

    Produced by ``ModelInputForXPUBuilder`` and extended with sampling
    metadata by ``ModelInputForXPUWithSamplingMetadata``.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    virtual_engine: Optional[int] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None
    async_callback: Optional[Callable] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return a dict with the tokens, positions and attention metadata.

        The other fields of the dataclass are not included.
        """
        payload = dict(
            input_tokens=self.input_tokens,
            input_positions=self.input_positions,
        )
        _add_attn_metadata_broadcastable_dict(payload, self.attn_metadata)
        return payload

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type[TModelInputForXPU],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> TModelInputForXPU:
        """Build an instance from a broadcast dict.

        When ``attn_backend`` is given, the dict is first passed through
        ``_init_attn_metadata_from_tensor_dict``; the resulting entries are
        then used as keyword arguments of the constructor.
        """
        fields = tensor_dict
        if attn_backend is not None:
            fields = _init_attn_metadata_from_tensor_dict(
                attn_backend, fields)
        return cls(**fields)
@dataclass(frozen=True)
class ModelInputForXPUWithSamplingMetadata(ModelInputForXPU):
    """``ModelInputForXPU`` plus the metadata needed for sampling.

    This is the model input type used by ``XPUModelRunner``.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return a dict with tokens, positions, attention and sampling data."""
        payload = dict(
            input_tokens=self.input_tokens,
            input_positions=self.input_positions,
        )
        _add_attn_metadata_broadcastable_dict(payload, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(payload,
                                                  self.sampling_metadata)
        return payload

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForXPUWithSamplingMetadata":
        """Build an instance from a broadcast dict.

        Sampling metadata is initialised first; attention metadata is
        initialised afterwards, and only when ``attn_backend`` is given.
        """
        fields = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            fields = _init_attn_metadata_from_tensor_dict(
                attn_backend, fields)
        return cls(**fields)
class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
    """Collect sequence groups and turn them into an XPU model input.

    Usage is ``prepare()``, then ``add_seq_group()`` once per group, then
    ``build()``.
    """

    def __init__(self,
                 runner: "XPUModelRunner",
                 finished_requests_ids: Optional[List[str]] = None) -> None:
        """Remember the runner and copy the settings used while building.

        ``finished_requests_ids`` is accepted but not used.
        """
        super().__init__()
        self.runner = runner
        self.model_input_cls = runner._model_input_cls
        self.attn_backend = runner.attn_backend
        self.sliding_window = runner.sliding_window
        self.block_size = runner.block_size
        self.device = runner.device

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        """Reset the builder to an empty list of sequence groups.

        ``finished_requests_ids`` is accepted but not used.
        """
        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Queue one sequence group for the next ``build()``."""
        self.seq_group_metadata_list.append(seq_group_metadata)

    def build(self) -> ModelInputForXPU:
        """Build the model input from the queued sequence groups.

        The first group decides whether the whole batch is handled as a
        prompt batch or as a decode batch. For decode batches ``seq_lens``,
        ``query_lens`` and ``multi_modal_kwargs`` are ``None``.
        """
        groups = self.seq_group_metadata_list
        seq_lens = None
        multi_modal_kwargs = None

        # Prepare input tensors.
        if groups[0].is_prompt:
            (input_tokens, input_positions, attn_metadata, seq_lens,
             multi_modal_kwargs) = self._prepare_prompt(groups)
        else:
            (input_tokens, input_positions,
             attn_metadata) = self._prepare_decode(groups)

        return self.model_input_cls(
            input_tokens=input_tokens,
            input_positions=input_positions,
            attn_metadata=attn_metadata,
            multi_modal_kwargs=multi_modal_kwargs,
            seq_lens=seq_lens,
            query_lens=seq_lens,
        )

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               BatchedTensorInputs]:
        """Build tokens, positions and attention metadata for a prompt batch.

        Every group must be a prompt group holding exactly one sequence.

        Returns:
            ``(input_tokens, input_positions, attn_metadata, seq_lens,
            multi_modal_kwargs)``.
        """
        assert len(seq_group_metadata_list) > 0
        token_ids: List[int] = []
        position_ids: List[int] = []
        slots: List[int] = []
        seq_lens: List[int] = []
        mm_kwargs_per_group: List[MultiModalKwargs] = []
        placeholders_by_modality: Dict[
            str, MultiModalPlaceholderMap] = defaultdict(
                MultiModalPlaceholderMap)

        for group in seq_group_metadata_list:
            assert group.is_prompt
            seq_ids = list(group.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = group.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            computed_len = seq_data.get_num_computed_tokens()
            seq_len = len(prompt_tokens)

            seq_lens.append(seq_len)  # Number of prompt tokens.
            token_ids.extend(prompt_tokens)  # Token ids.

            # Token position ids.
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            positions = range(computed_len, seq_len)
            position_ids.extend(positions)

            if group.multi_modal_data:
                # NOTE: mm_kwargs only includes the subset of multi-modal
                # items that intersect with the current prefill positions.
                mm_kwargs, group_placeholders = (
                    MultiModalPlaceholderMap.from_seq_group(group, positions))
                mm_kwargs_per_group.append(mm_kwargs)
                for modality, placeholder_map in group_placeholders.items():
                    placeholders_by_modality[modality].extend(placeholder_map)

            if group.block_tables is None:
                # During memory profiling, the block tables are not
                # initialized yet, so a dummy slot mapping is used.
                slots.extend([_PAD_SLOT_ID] * seq_len)
                continue

            # Compute the slot mapping.
            block_table = group.block_tables[seq_id]
            # Tokens in [0, first_kept) are masked with _PAD_SLOT_ID, where
            # first_kept is max(0, seq_len - sliding_window). For example,
            # if the prompt len is 10, sliding window is 8, and block size is
            # 4, the first two tokens are masked and the slot mapping will be
            # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            first_kept = 0
            if self.sliding_window is not None:
                first_kept = max(0, seq_len - self.sliding_window)

            for pos in positions:
                if pos < first_kept:
                    slots.append(_PAD_SLOT_ID)
                else:
                    block_index, offset = divmod(
                        pos, self.block_size)  # type: ignore
                    block_number = block_table[block_index]
                    slots.append(block_number * self.block_size + offset)

        num_prompt_tokens = len(token_ids)

        input_tokens = torch.tensor(token_ids,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(position_ids,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore
        slot_mapping = torch.tensor(slots,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in placeholders_by_modality.items()
        }

        # Cumulative sequence lengths with a leading zero.
        max_seqlen = max(seq_lens)
        lengths_with_zero = torch.tensor([0, *seq_lens])
        seqlen_q = torch.cumsum(lengths_with_zero,
                                dim=0).to(device=self.device)

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
            seq_lens=seq_lens,
            seqlen_q=seqlen_q,
            max_seqlen=max_seqlen,
            seq_lens_tensor=torch.tensor([]),
            max_decode_seq_len=0,
            num_prefills=len(seq_lens),
            num_prefill_tokens=num_prompt_tokens,
            num_decode_tokens=0,
            block_tables=torch.tensor([], device=self.device,
                                      dtype=torch.int),
        )

        multi_modal_kwargs = MultiModalKwargs.batch(mm_kwargs_per_group)

        return (input_tokens, input_positions, attn_metadata, seq_lens,
                multi_modal_kwargs)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
        """Build tokens, positions and attention metadata for a decode batch.

        Every group must be a non-prompt group with a token chunk size of 1;
        one token (the last one) is emitted per sequence.

        Returns:
            ``(input_tokens, input_positions, attn_metadata)``.
        """
        assert len(seq_group_metadata_list) > 0
        token_ids: List[int] = []
        position_ids: List[int] = []
        slots: List[int] = []
        seq_lens: List[int] = []
        tables: List[List[int]] = []

        for group in seq_group_metadata_list:
            assert not group.is_prompt
            assert group.token_chunk_size == 1

            for seq_id in list(group.seq_data.keys()):
                seq_data = group.seq_data[seq_id]
                token_ids.append(seq_data.get_last_token_id())

                full_len = seq_data.get_len()
                position = full_len - 1
                position_ids.append(position)

                # The reported length is capped by the sliding window.
                if self.sliding_window is None:
                    seq_lens.append(full_len)
                else:
                    seq_lens.append(min(full_len, self.sliding_window))

                block_table = group.block_tables[seq_id]
                block_index, offset = divmod(position, self.block_size)
                slots.append(block_table[block_index] * self.block_size +
                             offset)

                if self.sliding_window is not None:
                    # Keep only the trailing blocks covered by the window.
                    window_blocks = self.sliding_window // self.block_size
                    block_table = block_table[-window_blocks:]
                tables.append(block_table)

        max_decode_seq_len = max(seq_lens)

        input_tokens = torch.tensor(token_ids,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(position_ids,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slots,
                                    dtype=torch.long,
                                    device=self.device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)

        block_tables = make_tensor_with_pad(
            tables,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
            seq_lens=seq_lens,
            seqlen_q=torch.tensor([]),
            max_seqlen=0,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_seq_len=max_decode_seq_len,
            num_prefill_tokens=0,
            num_decode_tokens=len(input_tokens),
            num_prefills=0,
            block_tables=block_tables,
        )
        return (
            input_tokens,
            input_positions,
            attn_metadata,
        )
class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
    """Model runner for XPU devices: loads, profiles and executes the model."""

    _model_input_cls: Type[ModelInputForXPUWithSamplingMetadata] = (
        ModelInputForXPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForXPUBuilder] = ModelInputForXPUBuilder

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        return_hidden_states: bool = False,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        """Set up the runner state; the model itself is loaded later.

        Args:
            vllm_config: Configuration forwarded to ``ModelRunnerBase``.
            kv_cache_dtype: KV cache dtype handed to the attention backend.
            is_driver_worker: Whether this runner performs the sampling.
            return_hidden_states: Stored on the instance as given.
            input_registry: Registry used to create dummy profiling data.
            mm_registry: Registry used for multi-modal token limits.
        """
        ModelRunnerBase.__init__(self, vllm_config=vllm_config)
        self.is_driver_worker = is_driver_worker
        self.return_hidden_states = return_hidden_states

        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = self.model_config.get_sliding_window()
        self.block_size = self.cache_config.block_size

        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
        )

        # Multi-modal data support.
        self.input_registry = input_registry
        self.mm_registry = mm_registry

        # Declared here, assigned by load_model().
        self.model: nn.Module

        self.sampler = get_sampler()

        # The sampling metadata cache is only created without pipeline
        # parallelism; otherwise the attribute is None.
        if self.parallel_config.pipeline_parallel_size == 1:
            self.sampling_metadata_cache: Optional[
                SamplingMetadataCache] = SamplingMetadataCache()
        else:
            self.sampling_metadata_cache = None

        # The builder only gets a weak proxy to this runner.
        self.builder = self._builder_cls(weakref.proxy(self))

    def load_model(self) -> None:
        """Load the model and record the memory consumed while loading."""
        with DeviceMemoryProfiler() as profiler:
            self.model = get_model(vllm_config=self.vllm_config)

        self.model_memory_usage = profiler.consumed_memory
        logger.info("Loading model weights took %.4f GiB",
                    self.model_memory_usage / GiB_bytes)

    def get_model(self) -> nn.Module:
        """Return the loaded model."""
        return self.model

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the model config."""
        return self.model_config.get_vocab_size()

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Execute the model once on dummy prompts for memory profiling.

        The dummy batch holds ``max_num_seqs`` sequences whose lengths add up
        to ``max_num_batched_tokens``; for multi-modal models the number of
        sequences is reduced so each one can carry the maximum number of
        multi-modal tokens.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99,
                                         top_k=self.vocab_size - 1)
        token_budget = self.scheduler_config.max_num_batched_tokens
        seq_count = self.scheduler_config.max_num_seqs

        # Additional GPU memory may be needed for multi-modal encoding, which
        # needs to be accounted for when calculating the GPU blocks for the
        # vLLM block manager. To exercise the worst scenario for GPU memory
        # consumption, the number of seqs (batch_size) is chosen to maximize
        # the number of images processed.
        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
            self.model_config)
        if max_mm_tokens > 0:
            configured_seq_count = seq_count
            seq_count = min(seq_count, token_budget // max_mm_tokens)
            if seq_count < 1:
                expr = (f"min({configured_seq_count}, "
                        f"{token_budget} // {max_mm_tokens})")
                logger.warning(
                    "Computed max_num_seqs (%s) to be less than 1. "
                    "Setting it to the minimum value of 1.", expr)
                seq_count = 1

        # Profile memory usage with `seq_count` sequences and a total number
        # of tokens equal to the token budget: every sequence gets the base
        # share, and the first `leftover` sequences get one extra token.
        seqs: List[SequenceGroupMetadata] = []
        batch_size = 0
        base_len, leftover = divmod(token_budget, seq_count)
        for group_id in range(seq_count):
            seq_len = base_len + (group_id < leftover)
            batch_size += seq_len

            dummy_data = self.input_registry.dummy_data_for_profiling(
                self.model_config, seq_len, self.mm_registry)

            seqs.append(
                SequenceGroupMetadata(
                    request_id=str(group_id),
                    is_prompt=True,
                    seq_data={group_id: dummy_data.seq_data},
                    sampling_params=sampling_params,
                    block_tables=None,
                    lora_request=None,
                    multi_modal_data=dummy_data.multi_modal_data,
                    multi_modal_placeholders=dummy_data.
                    multi_modal_placeholders))

        finished_requests_ids = [seq.request_id for seq in seqs]
        model_input = self.prepare_model_input(
            seqs, finished_requests_ids=finished_requests_ids)

        # Ranks other than the first pipeline rank need placeholder inputs.
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = (
                self.model.make_empty_intermediate_tensors(
                    batch_size=batch_size,
                    dtype=self.model_config.dtype,
                    device=self.device))
        self.execute_model(model_input, None, intermediate_tensors)
        torch.xpu.synchronize()

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForXPUWithSamplingMetadata:
        """Rebuild a model input from a broadcast dict using this backend."""
        input_cls = ModelInputForXPUWithSamplingMetadata
        return input_cls.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def _prepare_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForXPUWithSamplingMetadata:
        """Build the forward-pass inputs for the given sequence groups.

        Covers only what the base model forward pass needs; metadata for
        later steps such as sampling is not prepared here.
        """
        input_builder = self.builder
        input_builder.prepare(finished_requests_ids)
        for group in seq_group_metadata_list:
            input_builder.add_seq_group(group)
        return input_builder.build()  # type: ignore

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForXPUWithSamplingMetadata:
        """Build the model input, including the sampling metadata.

        Returns a copy of the forward-pass input with ``sampling_metadata``
        and ``virtual_engine`` filled in.
        """
        base_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Sampling metadata is only required for the final pp group.
        generators = self.get_generators(finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            base_input.seq_lens,
            base_input.query_lens,
            self.device,
            pin_memory=False,
            generators=generators,
            cache=self.sampling_metadata_cache)

        return dataclasses.replace(base_input,
                                   sampling_metadata=sampling_metadata,
                                   virtual_engine=virtual_engine)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForXPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run one forward pass and, on the driver worker, sample tokens.

        Returns:
            The hidden or intermediate states on pipeline ranks other than
            the last one, an empty list on non-driver workers, and otherwise
            a one-element list holding the sampler output.

        Raises:
            ValueError: If ``num_steps`` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError(
                "XPUModelRunner does not support multi-step execution.")

        collect_forward_time = (
            self.observability_config is not None
            and self.observability_config.collect_model_forward_time)

        if collect_forward_time:
            model_forward_start_time = time.time()

        mm_kwargs = MultiModalKwargs.as_kwargs(
            model_input.multi_modal_kwargs or {},
            device=self.device,
        )
        with set_forward_context(model_input.attn_metadata, self.vllm_config,
                                 model_input.virtual_engine):
            hidden_or_intermediate_states = self.model(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **mm_kwargs,
            )

        # Logits are only computed in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            return hidden_or_intermediate_states

        if collect_forward_time:
            model_forward_end_time = time.time()

        # Compute the logits.
        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

        if model_input.async_callback is not None:
            model_input.async_callback()

        # Sample the next token.
        output: SamplerOutput = self.sampler(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        if collect_forward_time and output is not None:
            # If there are multiple workers, we are still tracking the
            # latency from the start time of the driver worker to the end
            # time of the driver worker. The model forward time will then
            # end up covering the communication time as well.
            output.model_forward_time = (model_forward_end_time -
                                         model_forward_start_time)

        return [output]
vllm/worker/xpu_worker.py
deleted
100644 → 0
View file @
3b2aefb1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A XPU worker class."""
import
gc
import
os
from
typing
import
List
,
Optional
,
Tuple
import
intel_extension_for_pytorch
# noqa: F401
import
oneccl_bindings_for_pytorch
# noqa: F401
import
torch
import
torch.distributed
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.platforms
import
current_platform
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker_base
import
LoRANotSupportedWorkerBase
,
WorkerBase
from
vllm.worker.xpu_model_runner
import
XPUModelRunner
logger
=
init_logger
(
__name__
)
class
XPUWorker
(
LoRANotSupportedWorkerBase
,
Worker
):
"""A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single XPU device. The worker is
responsible for maintaining the KV cache and executing the model on the
XPU. In case of distributed inference, each worker is assigned a partition
of the model.
"""
def __init__(
    self,
    vllm_config: VllmConfig,
    local_rank: int,
    rank: int,
    distributed_init_method: str,
    is_driver_worker: bool = False,
) -> None:
    """Store the rank information and create the XPU model runner.

    Args:
        vllm_config: Configuration forwarded to ``WorkerBase`` and to the
            model runner.
        local_rank: Stored as ``self.local_rank``.
        rank: Stored on the worker and on the parallel config.
        distributed_init_method: Stored as
            ``self.distributed_init_method``.
        is_driver_worker: Whether this worker is the driver; a driver must
            have a rank divisible by the tensor-parallel size.
    """
    WorkerBase.__init__(self, vllm_config=vllm_config)
    parallel_config = self.parallel_config
    assert self.device_config.device_type == "xpu"
    assert current_platform.is_xpu()

    parallel_config.rank = rank

    self.local_rank = local_rank
    self.rank = rank
    self.distributed_init_method = distributed_init_method
    self.is_driver_worker = is_driver_worker
    if parallel_config and is_driver_worker:
        assert rank % parallel_config.tensor_parallel_size == 0, \
            "Driver worker should be rank 0 of tensor parallel group."

    runner_kwargs = dict(
        vllm_config=vllm_config,
        kv_cache_dtype=self.cache_config.cache_dtype,
        is_driver_worker=is_driver_worker,
    )
    self.model_runner = XPUModelRunner(**runner_kwargs)  # type: ignore

    # Declared only: the cache engine and GPU cache are left uninitialized
    # here and will be initialized by initialize_cache.
    self.cache_engine: List[CacheEngine]
    self.gpu_cache: Optional[List[List[torch.Tensor]]]
def init_device(self) -> None:
    """Select this worker's XPU device and initialise its environment.

    Sets the current XPU device from ``self.local_rank``, records the
    device's total memory as ``self.init_gpu_memory``, initialises the
    distributed environment and seeds the random number generators.

    Raises:
        RuntimeError: If the configured device is not an XPU or the
            current platform is not XPU.
    """
    on_xpu = (self.device_config.device.type == "xpu"
              and current_platform.is_xpu())
    if not on_xpu:
        raise RuntimeError(
            f"Not support device type: {self.device_config.device}")

    self.device = torch.device(f"xpu:{self.local_rank}")
    torch.xpu.set_device(self.device)
    torch.xpu.empty_cache()
    device_properties = torch.xpu.get_device_properties(self.local_rank)
    self.init_gpu_memory = device_properties.total_memory

    # Initialize the distributed environment.
    self.init_worker_distributed_environment()
    # Seed the random number generators from the model config.
    set_random_seed(self.model_config.seed)
# keep this method for `empty_cache` and `synchronize` api
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
    """Profile memory usage and derive the GPU and CPU KV block counts.

    A profiling forward pass is run with dummy inputs. The memory in use
    afterwards is subtracted from the share of device memory allowed by
    ``gpu_memory_utilization``, and the remainder is divided by the cache
    block size. The CPU block count comes from the configured swap space.

    Tip:
        GPU memory usage can be limited by adjusting the
        `gpu_memory_utilization` parameter.

    Returns:
        ``(num_gpu_blocks, num_cpu_blocks)``, both clamped to be
        non-negative.
    """
    # Start from a clean cache, then execute a forward pass with dummy
    # inputs to profile the memory usage of the model.
    torch.xpu.empty_cache()
    self.model_runner.profile_run()

    # Measure only after all queued device work has completed.
    torch.xpu.synchronize()
    allocated = torch.xpu.memory_allocated()
    device_total = torch.xpu.get_device_properties(
        self.local_rank).total_memory
    free_gpu_memory = device_total - allocated

    # NOTE(woosuk): Here we assume that the other processes using the same
    # GPU did not change their memory usage during the profiling.
    peak_memory = self.init_gpu_memory - free_gpu_memory
    assert peak_memory > 0, (
        "Error in memory profiling. "
        f"Initial free memory {self.init_gpu_memory}, current free memory"
        f" {free_gpu_memory}. This happens when the GPU memory was "
        "not properly cleaned up before initializing the vLLM instance.")

    cache_block_size = self.get_cache_block_size_bytes()
    gpu_budget = (device_total * self.cache_config.gpu_memory_utilization -
                  peak_memory)
    num_gpu_blocks = max(int(gpu_budget // cache_block_size), 0)
    num_cpu_blocks = max(
        int(self.cache_config.swap_space_bytes // cache_block_size), 0)

    gc.collect()
    torch.xpu.empty_cache()
    return num_gpu_blocks, num_cpu_blocks
def _warm_up_model(self) -> None:
    """Intentionally do nothing.

    IPEX does not support graph capture yet, so there is no warm-up work
    to perform here.
    """
    return None
def init_worker_distributed_environment(self) -> None:
    """Set up torch.distributed and the model-parallel groups.

    If torch.distributed is not yet initialized, this exports the
    environment variables used for the "ccl" backend and calls
    ``init_distributed_environment``. If it is already initialized, only
    the world size is validated. In both cases the model-parallel groups
    are ensured afterwards and warm-up all-reduce calls are issued.

    Raises:
        RuntimeError: torch.distributed is already initialized with a
            world size different from ``parallel_config.world_size``.
        ValueError: torch.distributed is not initialized and
            ``self.distributed_init_method`` is empty.
    """
    cfg = self.parallel_config
    init_method = self.distributed_init_method

    if not torch.distributed.is_initialized():
        if not init_method:
            raise ValueError(
                "distributed_init_method must be set if torch.distributed "
                "is not already initialized")

        # Keep any CCL_ATL_TRANSPORT / LOCAL_WORLD_SIZE already present in
        # the environment; otherwise fall back to "ofi" and the configured
        # world size. LOCAL_RANK is always overwritten.
        # NOTE(review): presumably read by oneCCL (backend is "ccl") --
        # confirm.
        os.environ.setdefault("CCL_ATL_TRANSPORT", "ofi")
        os.environ.setdefault("LOCAL_WORLD_SIZE", str(cfg.world_size))
        os.environ["LOCAL_RANK"] = str(self.local_rank)

        init_distributed_environment(
            world_size=cfg.world_size,
            rank=self.rank,
            distributed_init_method=init_method,
            local_rank=self.local_rank,
            backend="ccl",
        )
    else:
        current_world_size = torch.distributed.get_world_size()
        if current_world_size != cfg.world_size:
            raise RuntimeError(
                "torch.distributed is already initialized but the torch "
                "world size does not match parallel_config.world_size "
                f"({current_world_size} vs. {cfg.world_size}).")

    ensure_model_parallel_initialized(
        cfg.tensor_parallel_size,
        cfg.pipeline_parallel_size,
    )

    # global all_reduce needed for overall oneccl warm up
    warmup_tensor = torch.zeros(1).xpu()
    torch.distributed.all_reduce(warmup_tensor)

    if cfg.pipeline_parallel_size > 1:
        # Run a collective on the pipeline-parallel group so that p2p
        # communication is not the first call made on it.
        pp_warmup_tensor = torch.zeros(1).xpu()
        get_pp_group().all_reduce(pp_warmup_tensor)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment