Unverified commit 8f0d7eae, authored by Fanli Lin and committed by GitHub
Browse files

[XPU] Fix OOM issue for data parallel with Ray backend (#22500)


Signed-off-by: Fanli Lin <fanli.lin@intel.com>
Signed-off-by: Fanli Lin <fanli0116@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent e0394076
...@@ -39,7 +39,8 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, ...@@ -39,7 +39,8 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, EngineCoreRequestType,
ReconfigureDistributedRequest, ReconfigureRankType, ReconfigureDistributedRequest, ReconfigureRankType,
UtilityOutput, UtilityResult) UtilityOutput, UtilityResult)
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses from vllm.v1.engine.utils import (EngineHandshakeMetadata, EngineZmqAddresses,
get_device_indices)
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.metrics.stats import SchedulerStats
...@@ -1169,22 +1170,30 @@ class DPEngineCoreActor(DPEngineCoreProc): ...@@ -1169,22 +1170,30 @@ class DPEngineCoreActor(DPEngineCoreProc):
# https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501 # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
# and get_accelerator_ids_for_accelerator_resource() in worker.py # and get_accelerator_ids_for_accelerator_resource() in worker.py
# of ray. # of ray.
self._set_cuda_visible_devices(vllm_config, local_dp_rank) self._set_visible_devices(vllm_config, local_dp_rank)
super().__init__(vllm_config, local_client, "", executor_class, super().__init__(vllm_config, local_client, "", executor_class,
log_stats) log_stats)
def _set_cuda_visible_devices(self, vllm_config: VllmConfig, def _set_visible_devices(self, vllm_config: VllmConfig,
local_dp_rank: int): local_dp_rank: int):
from vllm.platforms import current_platform from vllm.platforms import current_platform
device_control_env_var = current_platform.device_control_env_var if current_platform.is_xpu():
pass
else:
device_control_env_var = current_platform.device_control_env_var
self._set_cuda_visible_devices(vllm_config, local_dp_rank,
device_control_env_var)
def _set_cuda_visible_devices(self, vllm_config: VllmConfig,
local_dp_rank: int,
device_control_env_var: str):
world_size = vllm_config.parallel_config.world_size world_size = vllm_config.parallel_config.world_size
# Set CUDA_VISIBLE_DEVICES or equivalent. # Set CUDA_VISIBLE_DEVICES or equivalent.
try: try:
os.environ[device_control_env_var] = ",".join( value = get_device_indices(device_control_env_var, local_dp_rank,
str(current_platform.device_id_to_physical_device_id(i)) world_size)
for i in range(local_dp_rank * os.environ[device_control_env_var] = value
world_size, (local_dp_rank + 1) * world_size))
except IndexError as e: except IndexError as e:
raise Exception( raise Exception(
f"Error setting {device_control_env_var}: " f"Error setting {device_control_env_var}: "
......
...@@ -164,19 +164,33 @@ def set_device_control_env_var(vllm_config: VllmConfig, ...@@ -164,19 +164,33 @@ def set_device_control_env_var(vllm_config: VllmConfig,
""" """
world_size = vllm_config.parallel_config.world_size world_size = vllm_config.parallel_config.world_size
evar = current_platform.device_control_env_var evar = current_platform.device_control_env_var
value = get_device_indices(evar, local_dp_rank, world_size)
with patch.dict(os.environ, values=((evar, value), )):
yield
def get_device_indices(device_control_env_var: str, local_dp_rank: int,
world_size: int):
"""
Returns a comma-separated string of device indices for the specified
data parallel rank.
For example, if world_size=2 and local_dp_rank=1, and there are 4 devices,
this will select devices 2 and 3 for local_dp_rank=1.
"""
try: try:
value = ",".join( value = ",".join(
str(current_platform.device_id_to_physical_device_id(i)) str(current_platform.device_id_to_physical_device_id(i))
for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * for i in range(local_dp_rank * world_size, (local_dp_rank + 1) *
world_size)) world_size))
except IndexError as e: except IndexError as e:
raise Exception(f"Error setting {evar}: " raise Exception(f"Error setting {device_control_env_var}: "
f"local range: [{local_dp_rank * world_size}, " f"local range: [{local_dp_rank * world_size}, "
f"{(local_dp_rank + 1) * world_size}) " f"{(local_dp_rank + 1) * world_size}) "
"base value: " "base value: "
f"\"{os.getenv(evar)}\"") from e f"\"{os.getenv(device_control_env_var)}\"") from e
with patch.dict(os.environ, values=((evar, value), )): return value
yield
class CoreEngineActorManager: class CoreEngineActorManager:
...@@ -254,6 +268,19 @@ class CoreEngineActorManager: ...@@ -254,6 +268,19 @@ class CoreEngineActorManager:
dp_vllm_config = copy.deepcopy(vllm_config) dp_vllm_config = copy.deepcopy(vllm_config)
dp_vllm_config.parallel_config.placement_group = pg dp_vllm_config.parallel_config.placement_group = pg
local_client = index < local_engine_count local_client = index < local_engine_count
# Ray XPU known issue: dpctl initializes the GPU runtime early, so
# setting device env vars in Ray actor's initialization method
# will not affect device selection. See:
# https://github.com/ray-project/ray/blob/master/python/ray/_private/accelerators/intel_gpu.py#L56 # noqa: E501
if current_platform.is_xpu():
device_evar = current_platform.device_control_env_var
device_indices = get_device_indices(device_evar, local_index,
world_size)
actor_env_vars = self.env_vars_dict.copy()
actor_env_vars[device_evar] = device_indices
runtime_env = RuntimeEnv(env_vars=actor_env_vars)
actor = ray.remote(DPEngineCoreActor).options( actor = ray.remote(DPEngineCoreActor).options(
scheduling_strategy=PlacementGroupSchedulingStrategy( scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg, placement_group=pg,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment