Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
076169f6
Unverified
Commit
076169f6
authored
Aug 28, 2024
by
Kunshang Ji
Committed by
GitHub
Aug 27, 2024
Browse files
[Hardware][Intel GPU] Add intel GPU pipeline parallel support. (#7810)
parent
9db64213
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
82 additions
and
19 deletions
+82
-19
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+5
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+7
-0
vllm/executor/multiproc_gpu_executor.py
vllm/executor/multiproc_gpu_executor.py
+22
-16
vllm/executor/multiproc_xpu_executor.py
vllm/executor/multiproc_xpu_executor.py
+26
-0
vllm/worker/xpu_model_runner.py
vllm/worker/xpu_model_runner.py
+16
-3
vllm/worker/xpu_worker.py
vllm/worker/xpu_worker.py
+6
-0
No files found.
vllm/engine/async_llm_engine.py
View file @
076169f6
...
...
@@ -666,6 +666,11 @@ class AsyncLLMEngine:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_xpu_executor
import
RayXPUExecutorAsync
executor_class
=
RayXPUExecutorAsync
elif
distributed_executor_backend
==
"mp"
:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.multiproc_xpu_executor
import
(
MultiprocessingXPUExecutorAsync
)
executor_class
=
MultiprocessingXPUExecutorAsync
else
:
raise
RuntimeError
(
"Not supported distributed execution model on XPU device."
)
...
...
vllm/engine/llm_engine.py
View file @
076169f6
...
...
@@ -472,6 +472,13 @@ class LLMEngine:
initialize_ray_cluster
(
engine_config
.
parallel_config
)
from
vllm.executor.ray_xpu_executor
import
RayXPUExecutor
executor_class
=
RayXPUExecutor
elif
distributed_executor_backend
==
"mp"
:
# FIXME(kunshang):
# spawn needs calling `if __name__ == '__main__':``
# fork is not supported for xpu start new process.
logger
.
error
(
"Both start methods (spawn and fork) have issue "
"on XPU if you use mp backend, Please try ray instead."
)
else
:
from
vllm.executor.xpu_executor
import
XPUExecutor
executor_class
=
XPUExecutor
...
...
vllm/executor/multiproc_gpu_executor.py
View file @
076169f6
...
...
@@ -30,16 +30,12 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
uses_ray
:
bool
=
False
def
_init_executor
(
self
)
->
None
:
self
.
_check_executor_parameters
()
# Create the parallel GPU workers.
world_size
=
self
.
parallel_config
.
world_size
tensor_parallel_size
=
self
.
parallel_config
.
tensor_parallel_size
# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if
"CUDA_VISIBLE_DEVICES"
not
in
os
.
environ
:
update_environment_variables
({
"CUDA_VISIBLE_DEVICES"
:
(
","
.
join
(
map
(
str
,
range
(
world_size
))))
})
# Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
os
.
environ
[
"VLLM_INSTANCE_ID"
]
=
get_vllm_instance_id
()
...
...
@@ -68,16 +64,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
if
world_size
>
1
:
maybe_set_triton_cache_manager
()
cuda_device_count
=
cuda_device_count_stateless
()
# Use confusing message for more common TP-only case.
assert
tensor_parallel_size
<=
cuda_device_count
,
(
f
"please set tensor_parallel_size (
{
tensor_parallel_size
}
) "
f
"to less than max local gpu count (
{
cuda_device_count
}
)"
)
assert
world_size
<=
cuda_device_count
,
(
f
"please ensure that world_size (
{
world_size
}
) "
f
"is less than than max local gpu count (
{
cuda_device_count
}
)"
)
# Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address
# 127.0.0.1 for communication.
...
...
@@ -139,6 +125,26 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
max_concurrent_workers
=
self
.
parallel_config
.
max_parallel_loading_workers
)
def
_check_executor_parameters
(
self
):
world_size
=
self
.
parallel_config
.
tensor_parallel_size
tensor_parallel_size
=
self
.
parallel_config
.
tensor_parallel_size
# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
if
"CUDA_VISIBLE_DEVICES"
not
in
os
.
environ
:
update_environment_variables
({
"CUDA_VISIBLE_DEVICES"
:
(
","
.
join
(
map
(
str
,
range
(
world_size
))))
})
cuda_device_count
=
cuda_device_count_stateless
()
# Use confusing message for more common TP-only case.
assert
tensor_parallel_size
<=
cuda_device_count
,
(
f
"please set tensor_parallel_size (
{
tensor_parallel_size
}
) "
f
"to less than max local gpu count (
{
cuda_device_count
}
)"
)
assert
world_size
<=
cuda_device_count
,
(
f
"please ensure that world_size (
{
world_size
}
) "
f
"is less than than max local gpu count (
{
cuda_device_count
}
)"
)
def
shutdown
(
self
):
if
(
worker_monitor
:
=
getattr
(
self
,
"worker_monitor"
,
None
))
is
not
None
:
...
...
vllm/executor/multiproc_xpu_executor.py
0 → 100644
View file @
076169f6
import
vllm.envs
as
envs
from
vllm.executor.multiproc_gpu_executor
import
(
MultiprocessingGPUExecutor
,
MultiprocessingGPUExecutorAsync
)
from
vllm.executor.xpu_executor
import
XPUExecutor
from
vllm.logger
import
init_logger
from
vllm.utils
import
make_async
logger
=
init_logger
(
__name__
)
class
MultiprocessingXPUExecutor
(
MultiprocessingGPUExecutor
,
XPUExecutor
):
"""Python multiprocessing-based multi-XPU executor"""
def
_check_executor_parameters
(
self
):
mp_method
=
envs
.
VLLM_WORKER_MULTIPROC_METHOD
if
mp_method
!=
"spawn"
:
raise
RuntimeError
(
"XPU multiprocess executor only support spawn as mp method"
)
class
MultiprocessingXPUExecutorAsync
(
MultiprocessingXPUExecutor
,
MultiprocessingGPUExecutorAsync
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
driver_exec_model
=
make_async
(
self
.
driver_worker
.
execute_model
)
vllm/worker/xpu_model_runner.py
View file @
076169f6
...
...
@@ -12,6 +12,7 @@ from vllm.attention import get_attn_backend
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
from
vllm.distributed
import
get_pp_group
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
...
...
@@ -439,9 +440,11 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
"Setting it to the minimum value of 1."
,
expr
)
max_num_seqs
=
1
batch_size
=
0
for
group_id
in
range
(
max_num_seqs
):
seq_len
=
(
max_num_batched_tokens
//
max_num_seqs
+
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
batch_size
+=
seq_len
seq_data
,
dummy_multi_modal_data
=
self
.
input_registry
\
.
dummy_data_for_profiling
(
self
.
model_config
,
...
...
@@ -465,7 +468,13 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
finished_requests_ids
=
[
seq
.
request_id
for
seq
in
seqs
]
model_input
=
self
.
prepare_model_input
(
seqs
,
finished_requests_ids
=
finished_requests_ids
)
self
.
execute_model
(
model_input
,
kv_caches
)
intermediate_tensors
=
None
if
not
get_pp_group
().
is_first_rank
:
intermediate_tensors
=
self
.
model
.
make_empty_intermediate_tensors
(
batch_size
=
batch_size
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
)
self
.
execute_model
(
model_input
,
kv_caches
,
intermediate_tensors
)
torch
.
xpu
.
synchronize
()
return
...
...
@@ -537,7 +546,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
and
self
.
observability_config
.
collect_model_forward_time
):
model_forward_start_time
=
time
.
time
()
hidden_states
=
model_executable
(
hidden_
or_intermediate_
states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
kv_caches
=
kv_caches
,
...
...
@@ -545,12 +554,16 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
intermediate_tensors
=
intermediate_tensors
,
**
MultiModalInputs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
or
{},
device
=
self
.
device
))
# Compute the logits in the last pipeline stage.
if
not
get_pp_group
().
is_last_rank
:
return
hidden_or_intermediate_states
if
(
self
.
observability_config
is
not
None
and
self
.
observability_config
.
collect_model_forward_time
):
model_forward_end_time
=
time
.
time
()
# Compute the logits.
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
logits
=
self
.
model
.
compute_logits
(
hidden_
or_intermediate_
states
,
model_input
.
sampling_metadata
)
# Only perform sampling in the driver worker.
...
...
vllm/worker/xpu_worker.py
View file @
076169f6
...
...
@@ -14,6 +14,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
SpeculativeConfig
)
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.utils
import
is_xpu
...
...
@@ -198,3 +199,8 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
parallel_config
.
pipeline_parallel_size
)
if
parallel_config
.
pipeline_parallel_size
>
1
:
# torch-ccl xpu need a collective API warm up
# before calling send/recv API
get_pp_group
().
all_reduce
(
torch
.
zeros
(
1
).
xpu
())
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment