Unverified Commit 9925c179 authored by Antoni Baum, committed by GitHub

Ray placement group support (#397)

parent 8c4b2592
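
The change swaps vLLM's hand-rolled node/GPU bookkeeping for Ray placement groups: the engine reserves one 1-GPU bundle per worker and schedules its workers into those bundles. As orientation, here is a minimal sketch (not taken from the commit) of the two Ray primitives the diff leans on, assuming Ray >= 2.5.1 and a cluster with at least two free GPUs:

# Minimal sketch (not from the commit) of the Ray primitives this change uses.
# Assumes Ray >= 2.5.1 and a cluster with at least `world_size` free GPUs.
import ray
from ray.util.placement_group import placement_group

ray.init()  # or ray.init(address="auto") to join a running cluster

world_size = 2  # hypothetical tensor-parallel degree
# Reserve one 1-GPU bundle per worker, as initialize_cluster() does below.
pg = placement_group([{"GPU": 1}] * world_size)
# Block until every bundle is reserved (the diff waits up to 1800 s).
ray.get(pg.ready(), timeout=1800)
print(pg.bundle_specs)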
 ninja # For faster builds.
 psutil
-ray
+ray >= 2.5.1
 sentencepiece # Required for LLaMA tokenizer.
 numpy
 torch >= 2.0.0
 ...
@@ -226,14 +226,14 @@ class AsyncLLMEngine:
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
         # Initialize the cluster.
-        distributed_init_method, devices = initialize_cluster(
+        distributed_init_method, placement_group = initialize_cluster(
             parallel_config, engine_args.engine_use_ray)
         # Create the async LLM engine.
         engine = cls(engine_args.worker_use_ray,
                      engine_args.engine_use_ray,
                      *engine_configs,
                      distributed_init_method,
-                     devices,
+                     placement_group,
                      log_requests=not engine_args.disable_log_requests,
                      log_stats=not engine_args.disable_log_stats)
         return engine
 import time
-from typing import Any, List, Optional
+from functools import partial
+from typing import Any, List, Optional, TYPE_CHECKING
 
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.core.scheduler import Scheduler
 from vllm.engine.arg_utils import EngineArgs
-from vllm.engine.ray_utils import DeviceID, initialize_cluster, ray
+from vllm.engine.ray_utils import initialize_cluster, ray, RayWorker
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
@@ -13,7 +14,13 @@ from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
 from vllm.utils import Counter
-from vllm.worker.worker import Worker
+
+if ray:
+    from ray.air.util.torch_dist import init_torch_dist_process_group
+    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
 
 logger = init_logger(__name__)
@@ -54,7 +61,7 @@ class LLMEngine:
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         distributed_init_method: str,
-        stage_devices: List[List[DeviceID]],
+        placement_group: Optional["PlacementGroup"],
         log_stats: bool,
     ) -> None:
         logger.info(
@@ -85,31 +92,73 @@ class LLMEngine:
         self.seq_counter = Counter()
 
         # Create the parallel GPU workers.
-        self.workers: List[Worker] = []
-        assert len(stage_devices) == 1, "Only support one stage for now."
-        for rank, node_resource, _ in stage_devices[0]:
-            worker_cls = Worker
-            if self.parallel_config.worker_use_ray:
-                worker_cls = ray.remote(
-                    num_cpus=0,
-                    num_gpus=1,
-                    resources={node_resource: 1e-3},
-                )(worker_cls).remote
-            worker = worker_cls(
-                model_config,
-                parallel_config,
-                scheduler_config,
-                rank,
-                distributed_init_method,
-            )
-            self.workers.append(worker)
+        if self.parallel_config.worker_use_ray:
+            self._init_workers_ray(placement_group)
+        else:
+            self._init_workers(distributed_init_method)
 
         # Profile the memory usage and initialize the cache.
         self._init_cache()
 
         # Create the scheduler.
         self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)
 
+    def _init_workers(self, distributed_init_method: str):
+        # Lazy import the Worker to avoid importing torch.cuda/xformers
+        # before CUDA_VISIBLE_DEVICES is set in the Worker
+        from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel
+
+        assert self.parallel_config.world_size == 1, (
+            "Ray is required if parallel_config.world_size > 1.")
+
+        self.workers: List[Worker] = []
+        worker = Worker(
+            self.model_config,
+            self.parallel_config,
+            self.scheduler_config,
+            0,
+            distributed_init_method,
+        )
+        self.workers.append(worker)
+        self._run_workers(
+            "init_model",
+            get_all_outputs=True,
+        )
+
+    def _init_workers_ray(self, placement_group: "PlacementGroup"):
+        # Lazy import the Worker to avoid importing torch.cuda/xformers
+        # before CUDA_VISIBLE_DEVICES is set in the Worker
+        from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel
+
+        self.workers: List[Worker] = []
+        for bundle in placement_group.bundle_specs:
+            if not bundle.get("GPU", 0):
+                continue
+            worker = ray.remote(
+                num_cpus=0,
+                num_gpus=1,
+                scheduling_strategy=PlacementGroupSchedulingStrategy(
+                    placement_group=placement_group,
+                    placement_group_capture_child_tasks=True),
+            )(RayWorker).remote()
+            self.workers.append(worker)
+
+        # Initialize torch distributed process group for the workers.
+        init_torch_dist_process_group(self.workers, backend="nccl")
+        self._run_workers("init_worker",
+                          get_all_outputs=True,
+                          worker_init_fn=lambda: Worker(
+                              self.model_config,
+                              self.parallel_config,
+                              self.scheduler_config,
+                              None,
+                              None,
+                          ))
+        self._run_workers(
+            "init_model",
+            get_all_outputs=True,
+        )
+
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
         self.cache_config.verify_with_parallel_config(self.parallel_config)
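
In `_init_workers_ray` above, one `RayWorker` actor is created per GPU bundle and pinned to the group via `PlacementGroupSchedulingStrategy`. A self-contained sketch of that placement pattern, using a trivial `EchoActor` stand-in instead of vLLM's `RayWorker` and assuming two free GPUs:

# Sketch: one actor per GPU bundle, pinned to the placement group.
# EchoActor is an illustrative stand-in for vllm.engine.ray_utils.RayWorker.
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init()
pg = placement_group([{"GPU": 1}] * 2)
ray.get(pg.ready())

class EchoActor:
    def ping(self):
        return "pong"

workers = []
for bundle in pg.bundle_specs:
    if not bundle.get("GPU", 0):
        continue  # skip any CPU-only bundles, as _init_workers_ray does
    worker = ray.remote(
        num_cpus=0,
        num_gpus=1,
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_capture_child_tasks=True),
    )(EchoActor).remote()
    workers.append(worker)

print(ray.get([w.ping.remote() for w in workers]))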
@@ -152,11 +201,12 @@ class LLMEngine:
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
         # Initialize the cluster.
-        distributed_init_method, devices = initialize_cluster(parallel_config)
+        distributed_init_method, placement_group = initialize_cluster(
+            parallel_config)
         # Create the LLM engine.
         engine = cls(*engine_configs,
                      distributed_init_method,
-                     devices,
+                     placement_group,
                      log_stats=not engine_args.disable_log_stats)
         return engine
@@ -326,9 +376,10 @@ class LLMEngine:
         """Runs the given method on all workers."""
         all_outputs = []
         for worker in self.workers:
-            executor = getattr(worker, method)
             if self.parallel_config.worker_use_ray:
-                executor = executor.remote
+                executor = partial(worker.execute_method.remote, method)
+            else:
+                executor = getattr(worker, method)
 
             output = executor(*args, **kwargs)
             all_outputs.append(output)
 ...
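
The rewritten `_run_workers` dispatches by method name: for Ray actors it binds the name into `RayWorker.execute_method` with `functools.partial`, while the single-process path still uses `getattr`. A Ray-free sketch of the same dispatch idea; `LocalWorker`, `RemoteLikeWorker`, and `run_on_all` are illustrative stand-ins, not vLLM or Ray APIs:

# Ray-free sketch of the dispatch-by-name pattern in _run_workers.
# LocalWorker, RemoteLikeWorker and run_on_all are illustrative, not vLLM APIs.
from functools import partial

class LocalWorker:
    def init_model(self, seed: int = 0):
        return f"local init, seed={seed}"

class RemoteLikeWorker:
    """Mimics RayWorker.execute_method dispatch, minus the .remote() call."""
    def init_model(self, seed: int = 0):
        return f"remote-style init, seed={seed}"

    def execute_method(self, method, *args, **kwargs):
        return getattr(self, method)(*args, **kwargs)

def run_on_all(workers, method, *args, **kwargs):
    outputs = []
    for worker in workers:
        if hasattr(worker, "execute_method"):
            # Ray path in the diff: partial(worker.execute_method.remote, method)
            executor = partial(worker.execute_method, method)
        else:
            executor = getattr(worker, method)
        outputs.append(executor(*args, **kwargs))
    return outputs

print(run_on_all([LocalWorker(), RemoteLikeWorker()], "init_model", seed=42))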
 import socket
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple, TYPE_CHECKING
+
+from vllm.config import ParallelConfig
 
 try:
     import ray
+    from ray.air.util.torch_dist import TorchDistributedWorker
+
+    class RayWorker(TorchDistributedWorker):
+        """Ray wrapper for vllm.worker.Worker, allowing Worker to be
+        lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
+
+        def __init__(self) -> None:
+            self.worker = None
+
+        def init_worker(self, worker_init_fn):
+            self.worker = worker_init_fn()
+
+        def __getattr__(self, name):
+            return getattr(self.worker, name)
+
+        def execute_method(self, method, *args, **kwargs):
+            executor = getattr(self, method)
+            return executor(*args, **kwargs)
+
 except ImportError:
     ray = None
+    TorchDistributedWorker = None
 
-from vllm.config import ParallelConfig
-
-# rank, node resource (node IP), device id
-DeviceID = Tuple[int, Optional[str], int]
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
 
 
 def get_open_port():
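
`RayWorker` exists so the actor can be created cheaply, and the real `Worker` (which pulls in torch.cuda/xformers) is only constructed after Ray has set `CUDA_VISIBLE_DEVICES` in the actor process. A standalone sketch of that lazy-init/delegation pattern; `HeavyWorker` and `LazyWrapper` are hypothetical stand-ins and no Ray is involved:

# Standalone sketch of the lazy-init wrapper pattern behind RayWorker (no Ray).
# HeavyWorker stands in for vllm.worker.Worker; LazyWrapper for RayWorker.
import os

class HeavyWorker:
    def __init__(self):
        # In vLLM this is where torch.cuda/xformers get imported, so it must
        # run only after Ray has set CUDA_VISIBLE_DEVICES in the actor process.
        self.visible = os.environ.get("CUDA_VISIBLE_DEVICES", "<unset>")

    def describe(self):
        return f"HeavyWorker sees CUDA_VISIBLE_DEVICES={self.visible}"

class LazyWrapper:
    def __init__(self) -> None:
        self.worker = None  # nothing expensive happens at construction time

    def init_worker(self, worker_init_fn):
        self.worker = worker_init_fn()

    def __getattr__(self, name):
        # Fall through to the wrapped worker once it exists.
        return getattr(self.worker, name)

    def execute_method(self, method, *args, **kwargs):
        return getattr(self, method)(*args, **kwargs)

wrapper = LazyWrapper()
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # pretend Ray set this for the actor
wrapper.init_worker(HeavyWorker)
print(wrapper.execute_method("describe"))

Because `__getattr__` forwards unknown attributes, callers can treat the wrapper exactly like the wrapped worker once `init_worker` has run.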
@@ -22,7 +42,7 @@ def initialize_cluster(
     parallel_config: ParallelConfig,
     engine_use_ray: bool = False,
     ray_address: Optional[str] = None,
-) -> Tuple[str, List[List[DeviceID]]]:
+) -> Tuple[str, Optional["PlacementGroup"]]:
     """Initialize the distributed cluster, possibly with Ray.
 
     Args:
@@ -52,63 +72,36 @@ def initialize_cluster(
         # We need to setup the distributed init method to make sure
         # the distributed megatron code (e.g., get world size) works correctly.
         distributed_init_method = f"tcp://localhost:{port}"
-        all_stage_devices = [[(0, None, 0)]]
-        return distributed_init_method, all_stage_devices
+        return distributed_init_method, None
 
-    # Assume we have a uniform cluster that each node has the same number of
-    # GPUs for now.
-    valid_node_resources = []
-    num_devices_per_node = None
-    for node in ray.nodes():
-        if (not node["Alive"]) or node["Resources"]["GPU"] <= 0:
-            continue
-        if num_devices_per_node is None:
-            num_devices_per_node = node["Resources"]["GPU"]
-        else:
-            assert num_devices_per_node == node["Resources"]["GPU"], (
                "The number of GPUs per node is not uniform.")
-        for key in node["Resources"]:
-            if key.startswith("node:"):
-                valid_node_resources.append(key)
-
-    # Verify the parallel config.
-    num_nodes = len(valid_node_resources)
-    if parallel_config.world_size > num_nodes * num_devices_per_node:
-        raise ValueError(
-            "The number of required GPUs exceeds the total number of "
-            "available GPUs.")
-    if parallel_config.tensor_parallel_size >= num_devices_per_node:
-        if parallel_config.tensor_parallel_size % num_devices_per_node != 0:
-            raise ValueError(
-                "The number of tensor parallelism is not divisible by the "
-                "number of GPUs per node.")
-    else:
-        if num_devices_per_node % parallel_config.tensor_parallel_size != 0:
-            raise ValueError(
-                "The number of GPUs per node is not divisible by the number "
-                "of tensor parallelism.")
-
-    # Assign GPUs to pipeline stages.
-    rank = 0
-    current_node_id = 0
-    current_device_id = 0
-    distributed_init_method = None
-    all_stage_devices = []
-
-    for _ in range(parallel_config.pipeline_parallel_size):
-        stage_devices = []
-        for _ in range(parallel_config.tensor_parallel_size):
-            node_resource = valid_node_resources[current_node_id]
-            stage_devices.append((rank, node_resource, current_device_id))
-            if distributed_init_method is None:
-                ip = node_resource.split("node:")[-1]
-                port = get_open_port()
-                distributed_init_method = f"tcp://{ip}:{port}"
-            rank += 1
-            current_device_id += 1
-            if current_device_id >= num_devices_per_node:
-                current_node_id += 1
-                current_device_id = 0
-        all_stage_devices.append(stage_devices)
-
-    return distributed_init_method, all_stage_devices
+    current_placement_group = ray.util.get_current_placement_group()
+    if current_placement_group:
+        # We are in a placement group
+        bundles = current_placement_group.bundle_specs
+        # Verify that we can use the placement group.
+        gpu_bundles = 0
+        for bundle in bundles:
+            assert bundle.get("GPU", 0) <= 1, (
+                "Placement group bundles cannot have more than 1 GPU")
+            if bundle.get("GPU", 0):
+                gpu_bundles += 1
+        if parallel_config.world_size > gpu_bundles:
+            raise ValueError(
+                "The number of required GPUs exceeds the total number of "
+                "available GPUs in the placement group.")
+    else:
+        num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
+        if parallel_config.world_size > num_gpus_in_cluster:
+            raise ValueError(
+                "The number of required GPUs exceeds the total number of "
+                "available GPUs in the cluster.")
+        # Create a new placement group
+        current_placement_group = ray.util.placement_group([{
+            "GPU": 1
+        }] * parallel_config.world_size)
+        # Wait until PG is ready - this will block until all
+        # requested resources are available, and will timeout
+        # if they cannot be provisioned.
+        ray.get(current_placement_group.ready(), timeout=1800)
+
+    return None, current_placement_group
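
`initialize_cluster` now reuses a placement group the driver is already running inside, as reported by `ray.util.get_current_placement_group()`, and only reserves a new one otherwise. A sketch of how that detection behaves; it uses a CPU bundle instead of the diff's GPU bundles so it can run on any machine, and `report_pg` is an illustrative task name:

# Sketch: how get_current_placement_group() behaves inside vs. outside a group.
# Uses a CPU bundle (instead of the diff's GPU bundles) so it runs anywhere;
# report_pg is an illustrative task, not part of vLLM.
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init()

@ray.remote(num_cpus=1)
def report_pg():
    pg = ray.util.get_current_placement_group()
    return None if pg is None else pg.bundle_specs

# Outside any placement group the driver sees None, so initialize_cluster
# would take the "create a new placement group" branch.
print(ray.util.get_current_placement_group())

pg = placement_group([{"CPU": 1}])
ray.get(pg.ready())

# A task scheduled into the group with child-task capture sees the group,
# which is the case initialize_cluster reuses.
strategy = PlacementGroupSchedulingStrategy(
    placement_group=pg, placement_group_capture_child_tasks=True)
print(ray.get(report_pg.options(scheduling_strategy=strategy).remote()))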
"""A GPU worker class.""" """A GPU worker class."""
from typing import Dict, List, Tuple import os
from typing import Dict, List, Tuple, Optional
import torch import torch
import torch.distributed
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig) SchedulerConfig)
...@@ -27,8 +29,8 @@ class Worker: ...@@ -27,8 +29,8 @@ class Worker:
model_config: ModelConfig, model_config: ModelConfig,
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
rank: int, rank: Optional[int] = None,
distributed_init_method: str, distributed_init_method: Optional[str] = None,
) -> None: ) -> None:
self.model_config = model_config self.model_config = model_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
...@@ -36,27 +38,39 @@ class Worker: ...@@ -36,27 +38,39 @@ class Worker:
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
# Uninitialized cache engine. Will be initialized by
# self.init_cache_engine().
self.cache_config = None
self.block_size = None
self.cache_engine = None
self.cache_events = None
self.gpu_cache = None
def init_model(self):
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
# Env vars will be set by Ray.
self.rank = self.rank if self.rank is not None else int(
os.getenv("RANK", "-1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
self.device = torch.device(f"cuda:{local_rank}")
if self.rank < 0:
raise ValueError("Invalid or unspecified rank.")
torch.cuda.set_device(self.device)
# Initialize the distributed environment. # Initialize the distributed environment.
_init_distributed_environment(parallel_config, rank, _init_distributed_environment(self.parallel_config, self.rank,
distributed_init_method) self.distributed_init_method)
# Initialize the model. # Initialize the model.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
self.model = get_model(model_config) self.model = get_model(self.model_config)
initialize_all_reduce_launcher( initialize_all_reduce_launcher(
self.scheduler_config.max_num_batched_tokens, self.scheduler_config.max_num_batched_tokens,
self.model_config.get_hidden_size(), self.model_config.get_hidden_size(),
self.model_config.dtype, self.model_config.dtype,
) )
# Uninitialized cache engine. Will be initialized by
# self.init_cache_engine().
self.cache_config = None
self.block_size = None
self.cache_engine = None
self.cache_events = None
self.gpu_cache = None
@torch.inference_mode() @torch.inference_mode()
def profile_num_available_blocks( def profile_num_available_blocks(
self, self,
...@@ -294,15 +308,28 @@ class Worker: ...@@ -294,15 +308,28 @@ class Worker:
def _init_distributed_environment( def _init_distributed_environment(
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
rank: int, rank: int,
distributed_init_method: str, distributed_init_method: Optional[str] = None,
) -> None: ) -> None:
"""Initialize the distributed environment.""" """Initialize the distributed environment."""
torch.distributed.init_process_group( if torch.distributed.is_initialized():
backend="nccl", torch_world_size = torch.distributed.get_world_size()
world_size=parallel_config.world_size, if torch_world_size != parallel_config.world_size:
rank=rank, raise RuntimeError(
init_method=distributed_init_method, "torch.distributed is already initialized but the torch world "
) "size does not match parallel_config.world_size "
f"({torch_world_size} vs. {parallel_config.world_size}).")
elif not distributed_init_method:
raise ValueError(
"distributed_init_method must be set if torch.distributed "
"is not already initialized")
else:
torch.distributed.init_process_group(
backend="nccl",
world_size=parallel_config.world_size,
rank=rank,
init_method=distributed_init_method,
)
# A small all_reduce for warmup. # A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda()) torch.distributed.all_reduce(torch.zeros(1).cuda())
initialize_model_parallel(parallel_config.tensor_parallel_size, initialize_model_parallel(parallel_config.tensor_parallel_size,
......
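
`_init_distributed_environment` now tolerates a process group that was already created externally (Ray's `init_torch_dist_process_group` does this for the workers) and only calls `init_process_group` itself when handed an init method. A minimal sketch of the same guard, using the gloo backend and world_size 1 so it runs on a CPU-only machine; `ensure_process_group` is an illustrative helper, not vLLM code:

# Minimal sketch of the "reuse if initialized, otherwise create" guard above,
# with gloo and world_size=1 so it runs on a CPU-only machine.
# ensure_process_group is an illustrative helper, not a vLLM function.
from typing import Optional

import torch
import torch.distributed

def ensure_process_group(world_size: int, rank: int,
                         init_method: Optional[str] = None) -> None:
    if torch.distributed.is_initialized():
        if torch.distributed.get_world_size() != world_size:
            raise RuntimeError("existing process group has the wrong size")
    elif not init_method:
        raise ValueError("init_method must be set if no group exists yet")
    else:
        torch.distributed.init_process_group(backend="gloo",
                                             world_size=world_size,
                                             rank=rank,
                                             init_method=init_method)

ensure_process_group(1, 0, "tcp://127.0.0.1:29500")
ensure_process_group(1, 0)  # second call just reuses the existing group
torch.distributed.all_reduce(torch.zeros(1))  # tiny warmup, as in the diff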