Unverified Commit 8438e056 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[Core] RayWorkerVllm --> WorkerWrapper to reduce duplication (#4024)

[Core] replace narrow-usage RayWorkerVllm to general WorkerWrapper to reduce code duplication (#4024)
parent 11d652bd
import multiprocessing import multiprocessing
import os
import pytest import pytest
import torch import torch
from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator, from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
ncclGetUniqueId) ncclGetUniqueId)
from vllm.utils import update_environment_variables
def distributed_run(fn, world_size): def distributed_run(fn, world_size):
number_of_processes = world_size number_of_processes = world_size
processes = [] processes = []
for i in range(number_of_processes): for i in range(number_of_processes):
env = os.environ.copy() env = {}
env['RANK'] = str(i) env['RANK'] = str(i)
env['LOCAL_RANK'] = str(i) env['LOCAL_RANK'] = str(i)
env['WORLD_SIZE'] = str(number_of_processes) env['WORLD_SIZE'] = str(number_of_processes)
...@@ -32,8 +32,7 @@ def update_env(fn): ...@@ -32,8 +32,7 @@ def update_env(fn):
# so we need to pass the environment variables as arguments # so we need to pass the environment variables as arguments
# and update the environment variables in the function # and update the environment variables in the function
def wrapper(env): def wrapper(env):
import os update_environment_variables(env)
os.environ.update(env)
fn() fn()
return wrapper return wrapper
......
import pickle import pickle
from typing import Callable, List, Optional, Tuple from typing import List, Optional, Tuple
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import get_ip, is_hip, set_cuda_visible_devices from vllm.utils import get_ip, is_hip
from vllm.worker.worker import Worker from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__) logger = init_logger(__name__)
try: try:
import ray import ray
class RayWorkerVllm: class RayWorkerWrapper(WorkerWrapperBase):
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be """Ray wrapper for vllm.worker.Worker, allowing Worker to be
lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES.""" lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
def __init__(self, init_cached_hf_modules=False) -> None: def __init__(self, *args, **kwargs) -> None:
if init_cached_hf_modules: super().__init__(*args, **kwargs)
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
self._worker: Optional[Worker] = None
# Since the compiled DAG runs a main execution # Since the compiled DAG runs a main execution
# in a different thread that calls cuda.set_device. # in a different thread that calls cuda.set_device.
# The flag indicates is set_device is called on # The flag indicates is set_device is called on
# that thread. # that thread.
self.compiled_dag_cuda_device_set = False self.compiled_dag_cuda_device_set = False
def init_worker(self, worker_init_fn: Callable[[], Worker]):
self._worker = worker_init_fn()
@property
def worker(self) -> Worker:
assert self._worker is not None
return self._worker
def __getattr__(self, name):
return getattr(self.worker, name)
def execute_method(self, method, *args, **kwargs):
try:
executor = getattr(self, method)
return executor(*args, **kwargs)
except Exception as e:
# exceptions in ray worker may cause deadlock
# see https://github.com/vllm-project/vllm/issues/3455
# print the error and inform the user to solve the error
msg = (f"Error executing method {method}. "
"This might cause deadlock in distributed execution.")
logger.exception(msg)
raise e
def get_node_ip(self) -> str: def get_node_ip(self) -> str:
return get_ip() return get_ip()
...@@ -58,9 +31,6 @@ try: ...@@ -58,9 +31,6 @@ try:
gpu_ids = ray.get_gpu_ids() gpu_ids = ray.get_gpu_ids()
return node_id, gpu_ids return node_id, gpu_ids
def set_cuda_visible_devices(self, device_ids) -> None:
set_cuda_visible_devices(device_ids)
def execute_model_compiled_dag_remote(self, ignored): def execute_model_compiled_dag_remote(self, ignored):
"""Used only when compiled DAG is enabled.""" """Used only when compiled DAG is enabled."""
import torch import torch
...@@ -77,7 +47,7 @@ except ImportError as e: ...@@ -77,7 +47,7 @@ except ImportError as e:
"For distributed inference, please install Ray with " "For distributed inference, please install Ray with "
"`pip install ray`.") "`pip install ray`.")
ray = None # type: ignore ray = None # type: ignore
RayWorkerVllm = None # type: ignore RayWorkerWrapper = None # type: ignore
def initialize_ray_cluster( def initialize_ray_cluster(
......
import asyncio import asyncio
import copy
import os import os
import pickle import pickle
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from vllm.engine.ray_utils import RayWorkerVllm, ray from vllm.engine.ray_utils import RayWorkerWrapper, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async, set_cuda_visible_devices) make_async)
if ray is not None: if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
...@@ -74,9 +73,9 @@ class RayGPUExecutor(ExecutorBase): ...@@ -74,9 +73,9 @@ class RayGPUExecutor(ExecutorBase):
# The driver dummy worker does not actually use any resources. # The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker. # It holds the resource for the driver worker.
self.driver_dummy_worker: RayWorkerVllm = None self.driver_dummy_worker: RayWorkerWrapper = None
# The remaining workers are the actual ray actors. # The remaining workers are the actual ray actors.
self.workers: List[RayWorkerVllm] = [] self.workers: List[RayWorkerWrapper] = []
if self.parallel_config.ray_workers_use_nsight: if self.parallel_config.ray_workers_use_nsight:
ray_remote_kwargs = self._configure_ray_workers_use_nsight( ray_remote_kwargs = self._configure_ray_workers_use_nsight(
...@@ -97,13 +96,20 @@ class RayGPUExecutor(ExecutorBase): ...@@ -97,13 +96,20 @@ class RayGPUExecutor(ExecutorBase):
num_gpus=num_gpus, num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy, scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs, **ray_remote_kwargs,
)(RayWorkerVllm).remote(self.model_config.trust_remote_code) )(RayWorkerWrapper).remote(
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
)
worker_ip = ray.get(worker.get_node_ip.remote()) worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None: if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it # If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process. # as the resource holder for the driver process.
self.driver_dummy_worker = worker self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
)
else: else:
# Else, added to the list of workers. # Else, added to the list of workers.
self.workers.append(worker) self.workers.append(worker)
...@@ -115,82 +121,56 @@ class RayGPUExecutor(ExecutorBase): ...@@ -115,82 +121,56 @@ class RayGPUExecutor(ExecutorBase):
"GPU node.") "GPU node.")
# Get the set of GPU IDs used on each node. # Get the set of GPU IDs used on each node.
driver_node_id, driver_gpu_ids = ray.get( worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
self.driver_dummy_worker.get_node_and_gpu_ids.remote()) use_dummy_driver=True)
worker_node_and_gpu_ids = ray.get(
[worker.get_node_and_gpu_ids.remote() for worker in self.workers])
node_workers = defaultdict(list) node_workers = defaultdict(list)
node_gpus = defaultdict(list) node_gpus = defaultdict(list)
node_workers[driver_node_id].append(0) for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
node_gpus[driver_node_id].extend(driver_gpu_ids)
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
start=1):
node_workers[node_id].append(i) node_workers[node_id].append(i)
node_gpus[node_id].extend(gpu_ids) node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items(): for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids) node_gpus[node_id] = sorted(gpu_ids)
# Set CUDA_VISIBLE_DEVICES for the driver and workers. # Set CUDA_VISIBLE_DEVICES for the driver and workers.
set_cuda_visible_devices(node_gpus[driver_node_id]) all_args_to_update_environment_variables = []
for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids): for (node_id, _) in worker_node_and_gpu_ids:
worker.set_cuda_visible_devices.remote(node_gpus[node_id]) all_args_to_update_environment_variables.append([{
"CUDA_VISIBLE_DEVICES":
",".join(map(str, node_gpus[node_id]))
}])
self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables)
distributed_init_method = get_distributed_init_method( distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port()) driver_ip, get_open_port())
# Lazy import the Worker to avoid importing torch.cuda/xformers def collect_arg_helper_func(**kwargs):
# before CUDA_VISIBLE_DEVICES is set in the Worker # avoid writing `{"name": value}` manually
from vllm.worker.worker import Worker return kwargs
model_config = copy.deepcopy(self.model_config) init_worker_all_kwargs = []
parallel_config = copy.deepcopy(self.parallel_config)
scheduler_config = copy.deepcopy(self.scheduler_config) # Initialize the actual workers inside worker wrapper.
load_config = copy.deepcopy(self.load_config) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, ):
device_config = copy.deepcopy(self.device_config)
lora_config = copy.deepcopy(self.lora_config)
cache_config = copy.deepcopy(self.cache_config)
vision_language_config = copy.deepcopy(self.vision_language_config)
# Initialize the actual workers with the Worker class.
for rank, (worker, (node_id, _)) in enumerate(
zip(self.workers, worker_node_and_gpu_ids),
start=1,
):
local_rank = node_workers[node_id].index(rank) local_rank = node_workers[node_id].index(rank)
worker.init_worker.remote( init_worker_all_kwargs.append(
lambda rank=rank, local_rank=local_rank: Worker( collect_arg_helper_func(
model_config=model_config, model_config=self.model_config,
parallel_config=parallel_config, parallel_config=self.parallel_config,
scheduler_config=scheduler_config, scheduler_config=self.scheduler_config,
device_config=device_config, device_config=self.device_config,
cache_config=cache_config, cache_config=self.cache_config,
load_config=load_config, load_config=self.load_config,
local_rank=local_rank, local_rank=local_rank,
rank=rank, rank=rank,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
lora_config=lora_config, lora_config=self.lora_config,
vision_language_config=vision_language_config, vision_language_config=self.vision_language_config,
is_driver_worker=rank == 0,
)) ))
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
# Initialize the driver worker with the Worker class.
driver_rank = 0
driver_local_rank = node_workers[driver_node_id].index(driver_rank)
self.driver_worker = Worker(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
local_rank=driver_local_rank,
rank=driver_rank,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
load_config=self.load_config,
is_driver_worker=True,
)
self._run_workers("init_device") self._run_workers("init_device")
self._run_workers( self._run_workers(
...@@ -279,13 +259,35 @@ class RayGPUExecutor(ExecutorBase): ...@@ -279,13 +259,35 @@ class RayGPUExecutor(ExecutorBase):
self, self,
method: str, method: str,
*args, *args,
driver_args: Optional[Tuple[Any, ...]] = None, driver_args: Optional[Tuple[Any]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None, driver_kwargs: Optional[Dict[str, Any]] = None,
all_args: Optional[List[List[Any]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None, max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False, use_ray_compiled_dag: bool = False,
**kwargs, **kwargs,
) -> Any: ) -> Any:
"""Runs the given method on all workers.""" """Runs the given method on all workers.
all_args and all_kwargs are used to pass heterogeneous arguments,
i.e. different arguments for each worker.
"""
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
# for mypy type checking
assert driver_args is not None
assert driver_kwargs is not None
if all_args is None:
all_args = [driver_args] + [args] * len(self.workers)
if all_kwargs is None:
all_kwargs = [driver_kwargs] + [kwargs] * len(self.workers)
# for mypy type checking
assert all_args is not None
assert all_kwargs is not None
if max_concurrent_workers: if max_concurrent_workers:
raise NotImplementedError( raise NotImplementedError(
...@@ -299,8 +301,10 @@ class RayGPUExecutor(ExecutorBase): ...@@ -299,8 +301,10 @@ class RayGPUExecutor(ExecutorBase):
else: else:
# Start the ray workers first. # Start the ray workers first.
ray_worker_outputs = [ ray_worker_outputs = [
worker.execute_method.remote(method, *args, **kwargs) worker.execute_method.remote(method, *worker_args,
for worker in self.workers **worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_args[1:], all_kwargs[1:])
] ]
if driver_args is None: if driver_args is None:
...@@ -309,9 +313,13 @@ class RayGPUExecutor(ExecutorBase): ...@@ -309,9 +313,13 @@ class RayGPUExecutor(ExecutorBase):
driver_kwargs = kwargs driver_kwargs = kwargs
# Start the driver worker after all the ray workers. # Start the driver worker after all the ray workers.
driver_worker_output = getattr(self.driver_worker, if not use_dummy_driver:
method)(*driver_args, **driver_kwargs) driver_worker_output = self.driver_worker.execute_method(
method, *all_args[0], **all_kwargs[0])
else:
driver_worker_output = ray.get(
self.driver_dummy_worker.execute_method.remote(
method, *all_args[0], **all_kwargs[0]))
# Get the results of the ray workers. # Get the results of the ray workers.
if self.workers: if self.workers:
if use_ray_compiled_dag: if use_ray_compiled_dag:
...@@ -386,8 +394,12 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase): ...@@ -386,8 +394,12 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
driver_kwargs = kwargs driver_kwargs = kwargs
# Run the driver worker asynchronously. # Run the driver worker asynchronously.
driver_executor = make_async(getattr(self.driver_worker, method)) def helper():
coros.append(driver_executor(*driver_args, **driver_kwargs)) return self.driver_worker.execute_method(method, *driver_args,
**driver_kwargs)
driver_executor = make_async(helper)
coros.append(driver_executor())
# Run the ray workers asynchronously. # Run the ray workers asynchronously.
for worker in self.workers: for worker in self.workers:
......
...@@ -271,8 +271,12 @@ def get_open_port() -> int: ...@@ -271,8 +271,12 @@ def get_open_port() -> int:
return s.getsockname()[1] return s.getsockname()[1]
def set_cuda_visible_devices(device_ids: List[int]) -> None: def update_environment_variables(envs: Dict[str, str]):
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids)) for k, v in envs.items():
if k in os.environ:
logger.warning(f"Overwriting environment variable {k} "
f"from '{os.environ[k]}' to '{v}'")
os.environ[k] = v
def chunk_list(lst, chunk_size): def chunk_list(lst, chunk_size):
...@@ -505,3 +509,11 @@ def merge_dicts(dict1: Dict[Any, List[Any]], ...@@ -505,3 +509,11 @@ def merge_dicts(dict1: Dict[Any, List[Any]],
merged_dict[key].extend(value) merged_dict[key].extend(value)
return dict(merged_dict) return dict(merged_dict)
def init_cached_hf_modules():
"""
Lazy initialization of the Hugging Face modules.
"""
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
...@@ -138,7 +138,10 @@ class CPUWorker(LoraNotSupportedWorkerBase): ...@@ -138,7 +138,10 @@ class CPUWorker(LoraNotSupportedWorkerBase):
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
if self.is_driver_worker: if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0." assert self.rank == 0, "The driver worker must have rank 0."
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.model_runner = CPUModelRunner(model_config, self.model_runner = CPUModelRunner(model_config,
parallel_config, parallel_config,
scheduler_config, scheduler_config,
......
...@@ -29,6 +29,10 @@ class NeuronWorker(LoraNotSupportedWorkerBase): ...@@ -29,6 +29,10 @@ class NeuronWorker(LoraNotSupportedWorkerBase):
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config self.device_config = device_config
self.cache_config = cache_config self.cache_config = cache_config
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.model_runner = NeuronModelRunner(model_config, parallel_config, self.model_runner = NeuronModelRunner(model_config, parallel_config,
scheduler_config, device_config) scheduler_config, device_config)
......
...@@ -60,6 +60,10 @@ class Worker(WorkerBase): ...@@ -60,6 +60,10 @@ class Worker(WorkerBase):
if self.is_driver_worker: if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0." assert self.rank == 0, "The driver worker must have rank 0."
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.vision_language_config = vision_language_config self.vision_language_config = vision_language_config
if self.vision_language_config: if self.vision_language_config:
assert not self.lora_config, ( assert not self.lora_config, (
......
import importlib
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import update_environment_variables
logger = init_logger(__name__)
class WorkerBase(ABC): class WorkerBase(ABC):
...@@ -82,3 +88,53 @@ class LoraNotSupportedWorkerBase(WorkerBase): ...@@ -82,3 +88,53 @@ class LoraNotSupportedWorkerBase(WorkerBase):
def list_loras(self) -> List[int]: def list_loras(self) -> List[int]:
raise ValueError(f"{type(self)} does not support LoRA") raise ValueError(f"{type(self)} does not support LoRA")
class WorkerWrapperBase:
"""
The whole point of this class is to lazily initialize the worker.
We first instantiate the WorkerWrapper, which remembers the worker module
and class name. Then, when we call `update_environment_variables`, and the
real initialization happens in `init_worker`.
"""
def __init__(self,
worker_module_name=None,
worker_class_name=None) -> None:
self.worker_module_name = worker_module_name
self.worker_class_name = worker_class_name
self.worker = None
def update_environment_variables(self, envs: Dict[str, str]) -> None:
key = 'CUDA_VISIBLE_DEVICES'
if key in envs and key in os.environ:
# overwriting CUDA_VISIBLE_DEVICES is desired behavior
# suppress the warning in `update_environment_variables`
del os.environ[key]
update_environment_variables(envs)
def init_worker(self, *args, **kwargs):
"""
Actual initialization of the worker class.
Arguments are passed to the worker class constructor.
"""
mod = importlib.import_module(self.worker_module_name)
worker_class = getattr(mod, self.worker_class_name)
self.worker = worker_class(*args, **kwargs)
def execute_method(self, method, *args, **kwargs):
try:
if hasattr(self, method):
executor = getattr(self, method)
else:
executor = getattr(self.worker, method)
return executor(*args, **kwargs)
except Exception as e:
# if the driver worker also execute methods,
# exceptions in the rest worker may cause deadlock in rpc like ray
# see https://github.com/vllm-project/vllm/issues/3455
# print the error and inform the user to solve the error
msg = (f"Error executing method {method}. "
"This might cause deadlock in distributed execution.")
logger.exception(msg)
raise e
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment