Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fc5ebbd1
Unverified
Commit
fc5ebbd1
authored
Aug 23, 2024
by
Kunshang Ji
Committed by
GitHub
Aug 22, 2024
Browse files
[Hardware][Intel GPU] refactor xpu_model_runner for tp (#7712)
parent
c01a6cb2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
370 additions
and
650 deletions
+370
-650
vllm/executor/ray_xpu_executor.py
vllm/executor/ray_xpu_executor.py
+17
-366
vllm/worker/xpu_model_runner.py
vllm/worker/xpu_model_runner.py
+350
-278
vllm/worker/xpu_worker.py
vllm/worker/xpu_worker.py
+3
-6
No files found.
vllm/executor/ray_xpu_executor.py
View file @
fc5ebbd1
import
asyncio
import
os
from
collections
import
defaultdict
from
itertools
import
islice
,
repeat
from
typing
import
(
TYPE_CHECKING
,
Any
,
Awaitable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
)
from
typing
import
List
,
Optional
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
)
from
vllm.executor.distributed_gpu_executor
import
(
# yapf: disable
DistributedGPUExecutor
,
DistributedGPUExecutorAsync
)
from
vllm.executor.ray_utils
import
RayWorkerWrapper
,
ray
from
vllm.executor.ray_gpu_executor
import
RayGPUExecutor
,
RayGPUExecutorAsync
from
vllm.executor.xpu_executor
import
XPUExecutor
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
)
if
ray
is
not
None
:
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
from
vllm.utils
import
get_vllm_instance_id
,
make_async
logger
=
init_logger
(
__name__
)
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG
=
envs
.
VLLM_USE_RAY_COMPILED_DAG
class
RayXPUExecutor
(
DistributedGPUExecutor
):
uses_ray
:
bool
=
True
def __init__(
    self,
    model_config: ModelConfig,
    cache_config: CacheConfig,
    parallel_config: ParallelConfig,
    scheduler_config: SchedulerConfig,
    device_config: DeviceConfig,
    load_config: LoadConfig,
    lora_config: Optional[LoRAConfig],
    prompt_adapter_config: Optional[PromptAdapterConfig],
    speculative_config: Optional[SpeculativeConfig],
) -> None:
    """Store the configs, spawn the Ray workers and optionally build the DAG.

    The device type must be ``"xpu"`` and ``speculative_config`` must be
    falsy; both are enforced with assertions. Ray usage-stats collection
    is switched off unless ``RAY_USAGE_STATS_ENABLED`` is already ``"1"``.
    """
    assert device_config.device_type == "xpu"
    assert (not speculative_config
            ), "Speculative decoding not yet supported for XPU backend"

    # Keep a reference to every config object on the executor.
    self.model_config = model_config
    self.cache_config = cache_config
    self.load_config = load_config
    self.lora_config = lora_config
    self.parallel_config = parallel_config
    self.scheduler_config = scheduler_config
    self.device_config = device_config
    self.prompt_adapter_config = prompt_adapter_config

    pg = self.parallel_config.placement_group

    # Disable Ray usage stats collection (an explicit "1" is respected).
    if os.environ.get("RAY_USAGE_STATS_ENABLED", "0") != "1":
        os.environ["RAY_USAGE_STATS_ENABLED"] = "0"

    # Create the parallel GPU workers.
    self._init_workers_ray(pg)

    # The compiled DAG is only built when the env switch asks for it.
    self.forward_dag = None
    if USE_RAY_COMPILED_DAG:
        self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

    # This is non-None when the execute model loop is running
    # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
    self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
    # Updated by implementations that require additional args to be passed
    # to the _run_workers execute_model call
    self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
def
_init_executor
(
self
)
->
None
:
pass
def determine_num_available_blocks(self) -> Tuple[int, int]:
    """Determine the number of available KV blocks.

    Calls `determine_num_available_blocks` on every worker through
    `_run_workers` and reduces the per-worker answers with `min`, so the
    chosen cache sizes fit on all of them.

    Returns:
        - Tuple[num_gpu_blocks, num_cpu_blocks]
    """
    # One (gpu_blocks, cpu_blocks) pair per worker.
    per_worker = self._run_workers("determine_num_available_blocks", )

    # A shared centralized controller drives all workers, so only the
    # smallest count on each device kind is usable everywhere.
    gpu_counts = [pair[0] for pair in per_worker]
    cpu_counts = [pair[1] for pair in per_worker]
    return min(gpu_counts), min(cpu_counts)
def
_get_worker_wrapper_args
(
self
)
->
Dict
[
str
,
Any
]:
return
dict
(
worker_module_name
=
"vllm.worker.xpu_worker"
,
worker_class_name
=
"XPUWorker"
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
def _init_workers_ray(self, placement_group: "PlacementGroup",
                      **ray_remote_kwargs):
    """Create one Ray actor per GPU bundle and pick the driver's holder.

    Bundles without a "GPU" entry are skipped. The first actor whose node
    IP equals the driver's IP becomes `driver_dummy_worker`, and a local
    `RayWorkerWrapper` is created as `driver_worker`; every other actor is
    appended to `self.workers`.

    Raises:
        ValueError: if no actor was placed on the driver's node.
    """
    # A single-GPU run uses a ray worker with constrained memory;
    # otherwise each ray worker is allocated a full GPU.
    single_gpu = self.parallel_config.tensor_parallel_size == 1
    num_gpus = self.cache_config.gpu_memory_utilization if single_gpu else 1

    # The driver dummy worker does not actually use any resources.
    # It holds the resource for the driver worker.
    self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
    # The remaining workers are the actual ray actors.
    self.workers: List[RayWorkerWrapper] = []

    # Create the workers.
    driver_ip = get_ip()
    wrapper_kwargs = self._get_worker_wrapper_args()
    for bundle_index, bundle in enumerate(placement_group.bundle_specs):
        if not bundle.get("GPU", 0):
            continue

        strategy = PlacementGroupSchedulingStrategy(
            placement_group=placement_group,
            placement_group_capture_child_tasks=True,
            placement_group_bundle_index=bundle_index,
        )
        actor_cls = ray.remote(
            num_cpus=0,
            num_gpus=num_gpus,
            scheduling_strategy=strategy,
            **ray_remote_kwargs,
        )(RayWorkerWrapper)
        actor = actor_cls.remote(**wrapper_kwargs)

        on_driver_node = ray.get(actor.get_node_ip.remote()) == driver_ip
        if on_driver_node and self.driver_dummy_worker is None:
            # If the worker is on the same node as the driver, we use it
            # as the resource holder for the driver process.
            self.driver_dummy_worker = actor
            self.driver_worker = RayWorkerWrapper(**wrapper_kwargs)
        else:
            # Else, added to the list of workers.
            self.workers.append(actor)

    if self.driver_dummy_worker is None:
        raise ValueError(
            "Ray does not allocate any GPUs on the driver node. Consider "
            "adjusting the Ray placement group or running the driver on a "
            "GPU node.")
class
RayXPUExecutor
(
RayGPUExecutor
,
XPUExecutor
):
def
_get_env_vars_to_be_updated
(
self
):
# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids
=
self
.
_run_workers
(
"get_node_and_gpu_ids"
,
use_dummy_driver
=
True
)
node_workers
=
defaultdict
(
list
)
node_gpus
=
defaultdict
(
list
)
for
i
,
(
node_id
,
gpu_ids
)
in
enumerate
(
worker_node_and_gpu_ids
):
node_workers
[
node_id
].
append
(
i
)
node_gpus
[
node_id
].
extend
(
gpu_ids
)
for
node_id
,
gpu_ids
in
node_gpus
.
items
():
node_gpus
[
node_id
]
=
sorted
(
gpu_ids
)
# TODO: add env var for xpu
distributed_init_method
=
get_distributed_init_method
(
driver_ip
,
get_open_port
())
def
collect_arg_helper_func
(
**
kwargs
):
# avoid writing `{"name": value}` manually
return
kwargs
init_worker_all_kwargs
=
[]
# Initialize the actual workers inside worker wrapper.
for
rank
,
(
node_id
,
_
)
in
enumerate
(
worker_node_and_gpu_ids
,
):
local_rank
=
node_workers
[
node_id
].
index
(
rank
)
init_worker_all_kwargs
.
append
(
collect_arg_helper_func
(
model_config
=
self
.
model_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
,
device_config
=
self
.
device_config
,
cache_config
=
self
.
cache_config
,
load_config
=
self
.
load_config
,
local_rank
=
local_rank
,
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
lora_config
=
self
.
lora_config
,
is_driver_worker
=
rank
==
0
,
))
self
.
_run_workers
(
"init_worker"
,
all_kwargs
=
init_worker_all_kwargs
)
VLLM_INSTANCE_ID
=
get_vllm_instance_id
()
self
.
_run_workers
(
"init_device"
)
self
.
_run_workers
(
"load_model"
,
max_concurrent_workers
=
self
.
parallel_config
.
max_parallel_loading_workers
,
)
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables
=
[({
"VLLM_INSTANCE_ID"
:
VLLM_INSTANCE_ID
,
"VLLM_TRACE_FUNCTION"
:
str
(
envs
.
VLLM_TRACE_FUNCTION
),
},
)
for
(
_
,
_
)
in
worker_node_and_gpu_ids
]
return
all_args_to_update_environment_variables
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
    """Initialize the KV cache in all workers.

    Logs the block counts once, records them on `self.cache_config`, and
    forwards them to every worker via `_run_workers`.
    """
    # NOTE: We log here to avoid multiple logs when number of workers is
    # greater than one. We could log in the engine, but not all executors
    # have GPUs.
    logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
                num_cpu_blocks)

    cache_cfg = self.cache_config
    cache_cfg.num_gpu_blocks = num_gpu_blocks
    cache_cfg.num_cpu_blocks = num_cpu_blocks

    self._run_workers(
        "initialize_cache",
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=num_cpu_blocks,
    )
def _driver_execute_model(
    self,
    execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
    """Run execute_model in the driver worker.

    Passing None will cause the driver to stop the model execution
    loop running in each of the remote workers.
    """
    driver = self.driver_worker
    return driver.execute_method("execute_model", execute_model_req)
def add_lora(self, lora_request: LoRARequest) -> bool:
    """Forward an `add_lora` call to all workers.

    Asserts that `lora_request.lora_int_id` is positive, then returns
    whatever `_run_workers` returns.
    """
    assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
    outcome = self._run_workers(
        "add_lora",
        lora_request=lora_request,
    )
    return outcome
def remove_lora(self, lora_id: int) -> bool:
    """Forward a `remove_lora` call to all workers.

    Asserts that `lora_id` is positive, then returns whatever
    `_run_workers` returns.
    """
    assert lora_id > 0, "lora_id must be greater than 0."
    outcome = self._run_workers(
        "remove_lora",
        lora_id=lora_id,
    )
    return outcome
def list_loras(self) -> Set[int]:
    """Forward a `list_loras` call to all workers and return the result."""
    outcome = self._run_workers("list_loras")
    return outcome
def
_run_workers
(
self
,
method
:
str
,
*
args
,
async_run_remote_workers_only
:
bool
=
False
,
all_args
:
Optional
[
List
[
Tuple
[
Any
,
...]]]
=
None
,
all_kwargs
:
Optional
[
List
[
Dict
[
str
,
Any
]]]
=
None
,
use_dummy_driver
:
bool
=
False
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
**
kwargs
,
)
->
Any
:
"""Runs the given method on all workers. Can be used in the following
ways:
- args/kwargs: All workers share the same args/kwargs
- args/kwargs and driver_args/driver_kwargs: Driver worker has
different args
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
"""
if
max_concurrent_workers
:
raise
NotImplementedError
(
"max_concurrent_workers is not supported yet."
)
count
=
len
(
self
.
workers
)
all_worker_args
=
repeat
(
args
,
count
)
if
all_args
is
None
\
else
islice
(
all_args
,
1
,
None
)
all_worker_kwargs
=
repeat
(
kwargs
,
count
)
if
all_kwargs
is
None
\
else
islice
(
all_kwargs
,
1
,
None
)
# Start the ray workers first.
ray_worker_outputs
=
[
worker
.
execute_method
.
remote
(
method
,
*
worker_args
,
**
worker_kwargs
)
for
(
worker
,
worker_args
,
worker_kwargs
)
in
zip
(
self
.
workers
,
all_worker_args
,
all_worker_kwargs
)
]
if
async_run_remote_workers_only
:
# Just return futures
return
ray_worker_outputs
driver_worker_output
=
[]
driver_args
=
args
if
all_args
is
None
else
all_args
[
0
]
driver_kwargs
=
kwargs
if
all_kwargs
is
None
else
all_kwargs
[
0
]
# Start the driver worker after all the ray workers.
if
not
use_dummy_driver
:
driver_worker_output
=
self
.
driver_worker
.
execute_method
(
method
,
*
driver_args
,
**
driver_kwargs
)
else
:
assert
self
.
driver_dummy_worker
is
not
None
driver_worker_output
=
ray
.
get
(
self
.
driver_dummy_worker
.
execute_method
.
remote
(
method
,
*
driver_args
,
**
driver_kwargs
))
# Get the results of the ray workers.
if
self
.
workers
:
ray_worker_outputs
=
ray
.
get
(
ray_worker_outputs
)
return
driver_worker_output
+
ray_worker_outputs
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
    """Block until the given worker futures have finished.

    The futures are the ones `_run_workers()` returns when called with
    `async_run_remote_workers_only`; the results are discarded.
    """
    pending = parallel_worker_tasks
    ray.get(pending)
def _compiled_ray_dag(self, enable_asyncio: bool):
    """Build and compile a Ray DAG that calls every remote worker.

    Args:
        enable_asyncio: forwarded to `experimental_compile`.

    Returns:
        The object returned by `experimental_compile` on a
        `MultiOutputNode` with one node per entry of `self.workers`.

    Raises:
        ValueError: if the installed Ray is older than 2.32.
    """
    # importlib.metadata is the standard-library replacement for the
    # deprecated pkg_resources API; it reads the same installed
    # distribution version.
    import importlib.metadata

    from packaging import version

    required_version = version.parse("2.32")
    current_version = version.parse(importlib.metadata.version("ray"))
    if current_version < required_version:
        raise ValueError(f"Ray version {required_version} or greater is "
                         f"required, but found {current_version}")

    from ray.dag import InputNode, MultiOutputNode
    assert self.parallel_config.use_ray

    # Right now, compiled DAG requires at least 1 arg. We send
    # a dummy value for now. It will be fixed soon.
    with InputNode() as input_data:
        forward_dag = MultiOutputNode([
            worker.execute_model_compiled_dag_remote.
            bind(  # type: ignore[attr-defined]
                input_data) for worker in self.workers
        ])
    return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
def check_health(self) -> None:
    """Raises an error if engine is unhealthy.

    Delegates to `_check_if_any_actor_is_dead`.
    """
    self._check_if_any_actor_is_dead()
    return None
def
_check_if_any_actor_is_dead
(
self
):
if
not
self
.
workers
:
return
dead_actors
=
[]
for
actor
in
self
.
workers
:
actor_state
=
ray
.
state
.
actors
(
actor
.
_ray_actor_id
.
hex
())
# pylint: disable=protected-access
if
actor_state
[
"State"
]
==
"DEAD"
:
dead_actors
.
append
(
actor
)
if
dead_actors
:
raise
RuntimeError
(
"At least one Worker is dead. "
f
"Dead Workers:
{
dead_actors
}
. "
)
class
RayXPUExecutorAsync
(
RayXPUExecutor
,
DistributedGPUExecutorAsync
):
class
RayXPUExecutorAsync
(
RayXPUExecutor
,
RayGPUExecutorAsync
):
def __init__(self, *args, **kwargs):
    """Run the parent initialisation, then add an async driver entry point.

    `driver_exec_method` is the driver worker's `execute_method` passed
    through `make_async`.
    """
    super().__init__(*args, **kwargs)
    sync_execute = self.driver_worker.execute_method
    self.driver_exec_method = make_async(sync_execute)
async def _driver_execute_model_async(
    self,
    execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
    """Await `execute_model` on the driver through `driver_exec_method`."""
    result = await self.driver_exec_method("execute_model",
                                           execute_model_req)
    return result
async
def
_start_worker_execution_loop
(
self
):
coros
=
[
worker
.
execute_method
.
remote
(
"start_worker_execution_loop"
)
for
worker
in
self
.
workers
]
return
await
asyncio
.
gather
(
*
coros
)
self
.
pp_locks
:
Optional
[
List
[
asyncio
.
Lock
]]
=
None
vllm/worker/xpu_model_runner.py
View file @
fc5ebbd1
import
dataclasses
import
time
import
weakref
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
TypeVar
)
import
torch
import
torch.nn
as
nn
from
vllm.attention
import
get_attn_backend
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
MultiModal
Config
,
ParallelConfig
,
ModelConfig
,
Observability
Config
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
...
...
@@ -20,7 +23,7 @@ from vllm.sequence import (IntermediateTensors, SamplerOutput,
from
vllm.utils
import
CudaMemoryProfiler
,
make_tensor_with_pad
from
vllm.worker.model_runner
import
AttentionMetadata
,
SamplingMetadata
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
_add_attn_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
,
_init_attn_metadata_from_tensor_dict
,
...
...
@@ -37,6 +40,8 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
_BATCH_SIZE_ALIGNMENT
*
i
for
i
in
range
(
1
,
33
)
]
TModelInputForXPU
=
TypeVar
(
'TModelInputForXPU'
,
bound
=
"ModelInputForXPU"
)
@
dataclass
(
frozen
=
True
)
class
ModelInputForXPU
(
ModelRunnerInputBase
):
...
...
@@ -46,11 +51,40 @@ class ModelInputForXPU(ModelRunnerInputBase):
input_tokens
:
Optional
[
torch
.
Tensor
]
=
None
input_positions
:
Optional
[
torch
.
Tensor
]
=
None
attn_metadata
:
Optional
[
"AttentionMetadata"
]
=
None
sampling_metadata
:
Optional
[
"SamplingMetadata"
]
=
None
multi_modal_kwargs
:
Optional
[
BatchedTensorInputs
]
=
None
virtual_engine
:
Optional
[
int
]
=
None
seq_lens
:
Optional
[
List
[
int
]]
=
None
query_lens
:
Optional
[
List
[
int
]]
=
None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
    """Return the fields that are sent to other workers as a dict.

    Contains `input_tokens` and `input_positions`, plus whatever
    `_add_attn_metadata_broadcastable_dict` adds for `attn_metadata`.
    """
    payload = dict(
        input_tokens=self.input_tokens,
        input_positions=self.input_positions,
    )
    _add_attn_metadata_broadcastable_dict(payload, self.attn_metadata)
    return payload
@classmethod
def from_broadcasted_tensor_dict(
    cls: Type[TModelInputForXPU],
    tensor_dict: Dict[str, Any],
    attn_backend: Optional["AttentionBackend"] = None,
) -> TModelInputForXPU:
    """Build an instance of `cls` from a received tensor dict.

    When `attn_backend` is given, the dict is first passed through
    `_init_attn_metadata_from_tensor_dict`; the resulting entries are
    used as keyword arguments for `cls`.
    """
    if attn_backend is None:
        return cls(**tensor_dict)
    fields = _init_attn_metadata_from_tensor_dict(attn_backend, tensor_dict)
    return cls(**fields)
@
dataclass
(
frozen
=
True
)
class
ModelInputForXPUWithSamplingMetadata
(
ModelInputForXPU
):
"""
Used by the ModelRunner.
"""
sampling_metadata
:
Optional
[
"SamplingMetadata"
]
=
None
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Union
[
int
,
torch
.
Tensor
]]:
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
tensor_dict
=
{
"input_tokens"
:
self
.
input_tokens
,
"input_positions"
:
self
.
input_positions
,
...
...
@@ -62,10 +96,10 @@ class ModelInputForXPU(ModelRunnerInputBase):
@
classmethod
def
from_broadcasted_tensor_dict
(
cls
:
Type
[
"ModelInputForXPU"
]
,
cls
,
tensor_dict
:
Dict
[
str
,
Any
],
attn_backend
:
Optional
[
"AttentionBackend"
]
=
None
,
)
->
"ModelInputForXPU"
:
)
->
"ModelInputForXPU
WithSamplingMetadata
"
:
tensor_dict
=
_init_sampling_metadata_from_tensor_dict
(
tensor_dict
)
if
attn_backend
is
not
None
:
tensor_dict
=
_init_attn_metadata_from_tensor_dict
(
...
...
@@ -73,7 +107,230 @@ class ModelInputForXPU(ModelRunnerInputBase):
return
cls
(
**
tensor_dict
)
class
XPUModelRunner
(
ModelRunnerBase
[
ModelInputForXPU
]):
class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):
    """Collects sequence groups and turns them into one XPU model input.

    Groups are added with `add_seq_group`; `build` then prepares either a
    prompt batch or a decode batch, chosen by the `is_prompt` flag of the
    first group.
    """

    def __init__(self,
                 runner: "XPUModelRunner",
                 finished_requests_ids: Optional[List[str]] = None) -> None:
        """Copy the settings needed for building from `runner`.

        `finished_requests_ids` is accepted but not used here.
        """
        super().__init__()
        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
        self.runner = runner
        self.model_input_cls = self.runner._model_input_cls
        self.attn_backend = self.runner.attn_backend
        self.sliding_window = self.runner.sliding_window
        self.block_size = self.runner.block_size
        self.device = self.runner.device

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Queue one sequence group for the next `build` call."""
        self.seq_group_metadata_list.append(seq_group_metadata)

    def build(self) -> ModelInputForXPU:
        """Create the model input for all queued sequence groups.

        For a decode batch `seq_lens` is an empty list and
        `multi_modal_kwargs` is None; `query_lens` always mirrors
        `seq_lens`.
        """
        groups = self.seq_group_metadata_list
        # Prepare input tensors.
        if groups[0].is_prompt:
            (input_tokens, input_positions, attn_metadata, seq_lens,
             multi_modal_kwargs) = self._prepare_prompt(groups)
        else:
            (input_tokens, input_positions,
             attn_metadata) = self._prepare_decode(groups)
            seq_lens = []
            multi_modal_kwargs = None

        return self.model_input_cls(
            input_tokens=input_tokens,
            input_positions=input_positions,
            attn_metadata=attn_metadata,
            multi_modal_kwargs=multi_modal_kwargs,
            seq_lens=seq_lens,
            query_lens=seq_lens,
        )

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               BatchedTensorInputs]:
        """Build tokens, positions and attention metadata for prompts.

        Every group must be a prompt holding exactly one sequence.

        Returns:
            (input_tokens, input_positions, attn_metadata, seq_lens,
            multi_modal_kwargs)
        """
        assert len(seq_group_metadata_list) > 0
        token_ids: List[int] = []
        position_ids: List[int] = []
        slots: List[int] = []
        seq_lens: List[int] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []

        for group in seq_group_metadata_list:
            assert group.is_prompt
            seq_ids = list(group.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = group.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            computed_len = seq_data.get_num_computed_tokens()
            seq_len = len(prompt_tokens)

            seq_lens.append(seq_len)  # Prompt token num
            token_ids.extend(prompt_tokens)  # Token ids

            # Token position ids
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            position_ids.extend(range(computed_len, seq_len))

            if group.block_tables is None:
                # During memory profiling, the block tables are not
                # initialized yet. In this case, we just use a dummy slot
                # mapping.
                slots.extend([_PAD_SLOT_ID] * seq_len)
                continue

            # Compute the slot mapping.
            block_table = group.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, seq_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            if self.sliding_window is None:
                start_idx = 0
            else:
                start_idx = max(0, seq_len - self.sliding_window)

            for pos in range(computed_len, seq_len):
                if pos < start_idx:
                    slots.append(_PAD_SLOT_ID)
                else:
                    block_number = block_table[
                        pos // self.block_size]  # type: ignore
                    block_offset = pos % self.block_size  # type: ignore
                    slots.append(block_number * self.block_size +
                                 block_offset)

        num_prompt_tokens = len(token_ids)

        input_tokens = torch.tensor(token_ids,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(position_ids,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore
        slot_mapping = torch.tensor(slots,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore

        max_seqlen = max(seq_lens)
        # Running sum over [0, len_0, len_1, ...].
        padded_lens = [0]
        padded_lens.extend(seq_lens)
        seqlen_q = torch.cumsum(torch.tensor(padded_lens),
                                dim=0).to(device=self.device)

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seqlen_q=seqlen_q,
            max_seqlen=max_seqlen,
            seq_lens_tensor=torch.tensor([]),
            max_decode_seq_len=0,
            num_prefills=len(seq_lens),
            num_prefill_tokens=num_prompt_tokens,
            num_decode_tokens=0,
            block_tables=torch.tensor([], device=self.device, dtype=torch.int),
        )

        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

        return (input_tokens, input_positions, attn_metadata, seq_lens,
                multi_modal_kwargs)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
        """Build tokens, positions and attention metadata for decoding.

        Every group must be a non-prompt group with a token chunk size of
        one; each of its sequences contributes its last token.

        Returns:
            (input_tokens, input_positions, attn_metadata)
        """
        assert len(seq_group_metadata_list) > 0
        token_ids: List[int] = []
        position_ids: List[int] = []
        slots: List[int] = []
        seq_lens: List[int] = []
        tables: List[List[int]] = []

        for group in seq_group_metadata_list:
            assert not group.is_prompt
            assert group.token_chunk_size == 1

            for seq_id in list(group.seq_data.keys()):
                seq_data = group.seq_data[seq_id]
                token_ids.append(seq_data.get_last_token_id())

                full_len = seq_data.get_len()
                position = full_len - 1
                position_ids.append(position)

                # With a sliding window the reported length is capped.
                if self.sliding_window is None:
                    seq_lens.append(full_len)
                else:
                    seq_lens.append(min(full_len, self.sliding_window))

                block_table = group.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slots.append(block_number * self.block_size + block_offset)

                if self.sliding_window is not None:
                    # Keep only the trailing blocks covered by the window.
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                tables.append(block_table)

        max_decode_seq_len = max(seq_lens)

        input_tokens = torch.tensor(token_ids,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(position_ids,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slots,
                                    dtype=torch.long,
                                    device=self.device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)

        block_tables = make_tensor_with_pad(
            tables,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )
        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            seq_lens=seq_lens,
            seqlen_q=torch.tensor([]),
            max_seqlen=0,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_seq_len=max_decode_seq_len,
            num_prefill_tokens=0,
            num_decode_tokens=len(input_tokens),
            num_prefills=0,
            block_tables=block_tables,
        )
        return (
            input_tokens,
            input_positions,
            attn_metadata,
        )
class
XPUModelRunner
(
ModelRunnerBase
[
ModelInputForXPUWithSamplingMetadata
]):
_model_input_cls
:
Type
[
ModelInputForXPUWithSamplingMetadata
]
=
(
ModelInputForXPUWithSamplingMetadata
)
_builder_cls
:
Type
[
ModelInputForXPUBuilder
]
=
ModelInputForXPUBuilder
def
__init__
(
self
,
...
...
@@ -84,30 +341,32 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
cache_config
:
CacheConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
,
is_driver_worker
:
bool
=
False
,
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
,
return_hidden_states
:
bool
=
False
,
observability_config
:
Optional
[
ObservabilityConfig
]
=
None
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
*
args
,
**
kwargs
,
):
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
load_config
=
load_config
self
.
cache_config
=
cache_config
self
.
prompt_adapter_config
=
prompt_adapter_config
self
.
multimodal_config
=
multimodal_config
self
.
is_driver_worker
=
is_driver_worker
self
.
prompt_adapter_config
=
prompt_adapter_config
self
.
observability_config
=
observability_config
if
self
.
observability_config
is
not
None
:
print
(
f
"observability_config is
{
self
.
observability_config
}
"
)
self
.
return_hidden_states
=
return_hidden_states
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
device_config
=
device_config
self
.
device
=
self
.
device_config
.
device
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
block_size
=
cache_config
.
block_size
self
.
attn_backend
=
get_attn_backend
(
...
...
@@ -203,166 +462,68 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
# Run the model with the dummy inputs.
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
kv_caches
=
[
None
]
*
num_layers
model_input
=
self
.
prepare_model_input
(
seqs
)
finished_requests_ids
=
[
seq
.
request_id
for
seq
in
seqs
]
model_input
=
self
.
prepare_model_input
(
seqs
,
finished_requests_ids
=
finished_requests_ids
)
self
.
execute_model
(
model_input
,
kv_caches
)
torch
.
xpu
.
synchronize
()
return
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
tensor_dict
:
Dict
[
str
,
Any
])
->
ModelInputForXPU
:
return
(
ModelInputForXPU
.
from_broadcasted_tensor_dict
(
tensor_dict
,
attn_backend
=
self
.
attn_backend
,
))
def
prepare_model_input
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
ModelInputForXPU
:
multi_modal_kwargs
=
None
if
self
.
is_driver_worker
:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt
=
seq_group_metadata_list
[
0
].
is_prompt
# Prepare input tensors.
if
is_prompt
:
(
input_tokens
,
input_positions
,
attn_metadata
,
seq_lens
,
multi_modal_kwargs
)
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
else
:
(
input_tokens
,
input_positions
,
attn_metadata
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
seq_lens
=
[]
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
# subquery_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill
# just use seq_lens instead.
seq_lens
,
self
.
device
,
pin_memory
=
False
,
generators
=
self
.
get_generators
(
finished_requests_ids
))
# Broadcast the metadata.
metadata_dict
=
{
"input_tokens"
:
input_tokens
,
"input_positions"
:
input_positions
,
"selected_token_indices"
:
sampling_metadata
.
selected_token_indices
,
"multi_modal_kwargs"
:
multi_modal_kwargs
,
}
metadata_dict
.
update
(
attn_metadata
.
asdict_zerocopy
())
broadcast_tensor_dict
(
metadata_dict
,
src
=
0
)
else
:
metadata_dict
=
broadcast_tensor_dict
(
src
=
0
)
input_tokens
=
metadata_dict
.
pop
(
"input_tokens"
)
input_positions
=
metadata_dict
.
pop
(
"input_positions"
)
selected_token_indices
=
metadata_dict
.
pop
(
"selected_token_indices"
)
multi_modal_kwargs
=
metadata_dict
.
pop
(
"multi_modal_kwargs"
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
**
metadata_dict
)
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
None
,
selected_token_indices
=
selected_token_indices
,
categorized_sample_indices
=
None
,
num_prompts
=
0
,
)
return
ModelInputForXPU
(
input_tokens
=
input_tokens
,
input_positions
=
input_positions
,
attn_metadata
=
attn_metadata
,
sampling_metadata
=
sampling_metadata
,
multi_modal_kwargs
=
multi_modal_kwargs
)
tensor_dict
:
Dict
[
str
,
Any
])
->
ModelInputForXPUWithSamplingMetadata
:
return
(
ModelInputForXPUWithSamplingMetadata
.
from_broadcasted_tensor_dict
(
tensor_dict
,
attn_backend
=
self
.
attn_backend
,
))
def
_prepare_
decode
(
def
_prepare_
model_input_tensors
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
]:
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
seq_lens
:
List
[
int
]
=
[]
block_tables
:
List
[
List
[
int
]]
=
[]
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
ModelInputForXPUWithSamplingMetadata
:
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
metadata for possible additional steps, e.g., sampling.
"""
builder
=
self
.
_builder_cls
(
weakref
.
proxy
(
self
),
finished_requests_ids
)
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
not
seq_group_metadata
.
is_prompt
assert
seq_group_metadata
.
token_chunk_size
==
1
builder
.
add_seq_group
(
seq_group_metadata
)
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
return
builder
.
build
()
# type: ignore
for
seq_id
in
seq_ids
:
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
generation_token
=
seq_data
.
get_last_token_id
()
input_tokens
.
append
(
generation_token
)
seq_len
=
seq_data
.
get_len
()
position
=
seq_len
-
1
input_positions
.
append
(
position
)
seq_len
=
seq_len
if
self
.
sliding_window
is
None
else
min
(
seq_len
,
self
.
sliding_window
)
seq_lens
.
append
(
seq_len
)
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_number
=
block_table
[
position
//
self
.
block_size
]
block_offset
=
position
%
self
.
block_size
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
if
self
.
sliding_window
is
not
None
:
sliding_window_blocks
=
(
self
.
sliding_window
//
self
.
block_size
)
block_table
=
block_table
[
-
sliding_window_blocks
:]
block_tables
.
append
(
block_table
)
max_decode_seq_len
=
max
(
seq_lens
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
block_tables
=
make_tensor_with_pad
(
block_tables
,
pad
=
0
,
dtype
=
torch
.
int
,
device
=
self
.
device
,
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
False
,
slot_mapping
=
slot_mapping
,
seq_lens
=
seq_lens
,
seqlen_q
=
None
,
max_seqlen
=
None
,
seq_lens_tensor
=
seq_lens_tensor
,
max_decode_seq_len
=
max_decode_seq_len
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
len
(
input_tokens
),
num_prefills
=
0
,
block_tables
=
block_tables
,
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
)
def
prepare_model_input
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
)
->
ModelInputForXPUWithSamplingMetadata
:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
"""
model_input
=
self
.
_prepare_model_input_tensors
(
seq_group_metadata_list
,
finished_requests_ids
)
# Sampling metadata is only required for the final pp group
generators
=
self
.
get_generators
(
finished_requests_ids
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
model_input
.
seq_lens
,
model_input
.
query_lens
,
self
.
device
,
pin_memory
=
False
,
generators
=
generators
)
return
dataclasses
.
replace
(
model_input
,
sampling_metadata
=
sampling_metadata
,
virtual_engine
=
virtual_engine
)
@
torch
.
inference_mode
()
def
execute_model
(
self
,
model_input
:
ModelInputForXPU
,
model_input
:
ModelInputForXPU
WithSamplingMetadata
,
kv_caches
:
List
[
torch
.
Tensor
],
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
...
...
@@ -372,20 +533,21 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
"XPUModelRunner does not support multi-step execution."
)
model_executable
=
self
.
model
execute_model_kwargs
=
{
"input_ids"
:
model_input
.
input_tokens
,
"positions"
:
model_input
.
input_positions
,
"kv_caches"
:
kv_caches
,
"attn_metadata"
:
model_input
.
attn_metadata
,
if
(
self
.
observability_config
is
not
None
and
self
.
observability_config
.
collect_model_forward_time
):
model_forward_start_time
=
time
.
time
()
hidden_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
model_input
.
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalInputs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
device
=
self
.
device
)
,
}
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
device
=
self
.
device
)
)
if
(
self
.
observability_config
is
not
None
and
self
.
observability_config
.
collect_model_forward_time
):
model_forward_end_time
=
time
.
time
(
)
# Compute the logits.
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
...
...
@@ -396,109 +558,19 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
return
[]
# Sample the next token.
output
=
self
.
model
.
sample
(
output
:
SamplerOutput
=
self
.
model
.
sample
(
logits
=
logits
,
sampling_metadata
=
model_input
.
sampling_metadata
,
)
return
[
output
]
def
_prepare_prompt
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
AttentionMetadata
,
List
[
int
],
BatchedTensorInputs
]:
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
int
]
=
[]
input_positions
:
List
[
int
]
=
[]
slot_mapping
:
List
[
int
]
=
[]
seq_lens
:
List
[
int
]
=
[]
multi_modal_inputs_list
:
List
[
MultiModalInputs
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
seq_group_metadata
.
is_prompt
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
assert
len
(
seq_ids
)
==
1
seq_id
=
seq_ids
[
0
]
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
prompt_tokens
=
seq_data
.
get_token_ids
()
computed_len
=
seq_data
.
get_num_computed_tokens
()
seq_len
=
len
(
prompt_tokens
)
if
(
self
.
observability_config
is
not
None
and
self
.
observability_config
.
collect_model_forward_time
and
output
is
not
None
):
model_forward_time
=
(
model_forward_end_time
-
model_forward_start_time
)
# If there are multiple workers, we are still tracking the latency
# from the start time of the driver worker to the end time of the
# driver worker. The model forward time will then end up covering
# the communication time as well.
output
.
model_forward_time
=
model_forward_time
seq_lens
.
append
(
seq_len
)
# Prompt token num
input_tokens
.
extend
(
prompt_tokens
)
# Token ids
# Token position ids
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions
.
extend
(
list
(
range
(
computed_len
,
seq_len
)))
mm_data
=
seq_group_metadata
.
multi_modal_data
if
mm_data
:
mm_kwargs
=
self
.
multi_modal_input_mapper
(
mm_data
)
multi_modal_inputs_list
.
append
(
mm_kwargs
)
if
seq_group_metadata
.
block_tables
is
None
:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping
.
extend
([
_PAD_SLOT_ID
]
*
seq_len
)
continue
# Compute the slot mapping.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx
=
0
if
self
.
sliding_window
is
not
None
:
start_idx
=
max
(
0
,
seq_len
-
self
.
sliding_window
)
for
i
in
range
(
computed_len
,
seq_len
):
if
i
<
start_idx
:
slot_mapping
.
append
(
_PAD_SLOT_ID
)
continue
block_number
=
block_table
[
i
//
self
.
block_size
]
# type: ignore
block_offset
=
i
%
self
.
block_size
# type: ignore
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
(
slot
)
num_prompt_tokens
=
len
(
input_tokens
)
input_tokens
=
torch
.
tensor
(
input_tokens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
# type: ignore
input_positions
=
torch
.
tensor
(
input_positions
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
# type: ignore
slot_mapping
=
torch
.
tensor
(
slot_mapping
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
# type: ignore
max_seqlen
=
max
(
seq_lens
)
tmp
=
[
0
]
tmp
.
extend
(
seq_lens
)
seqlen
=
torch
.
tensor
(
tmp
)
seqlen_q
=
torch
.
cumsum
(
seqlen
,
dim
=
0
).
to
(
device
=
self
.
device
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
True
,
slot_mapping
=
slot_mapping
,
seq_lens
=
seq_lens
,
seqlen_q
=
seqlen_q
,
max_seqlen
=
max_seqlen
,
seq_lens_tensor
=
None
,
max_decode_seq_len
=
None
,
num_prefills
=
len
(
seq_lens
),
num_prefill_tokens
=
num_prompt_tokens
,
num_decode_tokens
=
0
,
block_tables
=
torch
.
tensor
([],
device
=
self
.
device
,
dtype
=
torch
.
int
),
)
multi_modal_kwargs
=
MultiModalInputs
.
batch
(
multi_modal_inputs_list
)
return
(
input_tokens
,
input_positions
,
attn_metadata
,
seq_lens
,
multi_modal_kwargs
)
return
[
output
]
vllm/worker/xpu_worker.py
View file @
fc5ebbd1
...
...
@@ -9,8 +9,8 @@ import torch
import
torch.distributed
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
)
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
...
...
@@ -46,7 +46,6 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
rank
:
int
,
distributed_init_method
:
str
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
multimodal_config
:
Optional
[
MultiModalConfig
]
=
None
,
speculative_config
:
Optional
[
SpeculativeConfig
]
=
None
,
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
,
is_driver_worker
:
bool
=
False
,
...
...
@@ -73,8 +72,6 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
assert
rank
%
parallel_config
.
tensor_parallel_size
==
0
,
\
"Driver worker should be rank 0 of tensor parallel group."
self
.
multimodal_config
=
multimodal_config
self
.
model_runner
=
XPUModelRunner
(
# type: ignore
model_config
,
parallel_config
,
...
...
@@ -85,7 +82,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
lora_config
=
self
.
lora_config
,
kv_cache_dtype
=
self
.
cache_config
.
cache_dtype
,
is_driver_worker
=
is_driver_worker
,
multimodal_config
=
multimodal
_config
,
observability_config
=
self
.
observability
_config
,
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment