Unverified Commit 0229c386 authored by FlorianJoncour, committed by GitHub

Better integration with Ray Serve (#1821)


Co-authored-by: FlorianJoncour <florian@zetta-sys.com>
parent a7b3e330
vllm/engine/llm_engine.py

@@ -7,7 +7,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.core.scheduler import Scheduler, SchedulerOutputs
 from vllm.engine.arg_utils import EngineArgs
-from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray
+from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
@@ -162,12 +162,12 @@ class LLMEngine:
                 continue
             worker = ray.remote(
                 num_cpus=0,
-                num_gpus=1,
+                num_gpus=self.cache_config.gpu_memory_utilization,
                 scheduling_strategy=PlacementGroupSchedulingStrategy(
                     placement_group=placement_group,
                     placement_group_capture_child_tasks=True),
                 **ray_remote_kwargs,
-            )(RayWorker).remote(self.model_config.trust_remote_code)
+            )(RayWorkerVllm).remote(self.model_config.trust_remote_code)
             self.workers.append(worker)
 
         # Initialize torch distributed process group for the workers.
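The substantive change in this hunk is the `num_gpus` request: instead of reserving a whole GPU per worker, each `RayWorkerVllm` actor now asks Ray for a fraction equal to `gpu_memory_utilization`, leaving headroom on the same device for other actors, such as a Ray Serve replica. Below is a minimal sketch of the fractional GPU scheduling this relies on; the `Worker` actor and the 0.9 fraction are illustrative, not part of vLLM.

```python
# Minimal sketch of fractional GPU scheduling in Ray (illustrative only).
# Requesting num_gpus=0.9 leaves 0.1 GPU of headroom on the same device
# for another actor, e.g. a Ray Serve replica driving this worker.
import os

import ray

ray.init()  # assumes a local Ray instance with at least one GPU


@ray.remote(num_cpus=0, num_gpus=0.9)
class Worker:
    def device(self) -> str:
        # Ray sets CUDA_VISIBLE_DEVICES before any actor code runs,
        # which is why vllm.worker.Worker is initialized lazily.
        return os.environ.get("CUDA_VISIBLE_DEVICES", "")


worker = Worker.remote()
print(ray.get(worker.device.remote()))
```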
vllm/engine/ray_utils.py

@@ -10,7 +10,7 @@ try:
     import ray
     from ray.air.util.torch_dist import TorchDistributedWorker
 
-    class RayWorker(TorchDistributedWorker):
+    class RayWorkerVllm(TorchDistributedWorker):
         """Ray wrapper for vllm.worker.Worker, allowing Worker to be
         lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
@@ -36,7 +36,7 @@ except ImportError as e:
                    "`pip install ray pandas pyarrow`.")
     ray = None
     TorchDistributedWorker = None
-    RayWorker = None
+    RayWorkerVllm = None
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
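Together with the rename to `RayWorkerVllm` (which keeps the class from colliding with similarly named workers in a shared Ray application), the fractional GPU request is what makes it practical to colocate the engine with a Ray Serve replica. The sketch below shows one plausible wiring; `AsyncLLMEngine`, `AsyncEngineArgs`, `serve.deployment`, and `SamplingParams` are real APIs from this era of vLLM and Ray, but the deployment class, model choice, and GPU fractions are assumptions for illustration.

```python
# Hypothetical sketch of serving vLLM behind Ray Serve (the use case this
# commit targets). The exact wiring, model name, and GPU fractions are
# assumptions; adjust them to your cluster.
from ray import serve
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid


@serve.deployment(ray_actor_options={"num_gpus": 0.1})
class VLLMDeployment:
    def __init__(self, model: str) -> None:
        # The engine spawns RayWorkerVllm actors which, after this commit,
        # request only gpu_memory_utilization GPUs each, so they can share
        # a device with this replica.
        self.engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model=model,
                            worker_use_ray=True,
                            gpu_memory_utilization=0.85))

    async def __call__(self, prompt: str) -> str:
        params = SamplingParams(max_tokens=64)
        final_output = None
        # generate() yields incremental RequestOutputs; keep the last one.
        async for output in self.engine.generate(prompt, params, random_uuid()):
            final_output = output
        return final_output.outputs[0].text


app = VLLMDeployment.bind("facebook/opt-125m")
# serve.run(app)  # requires a running Ray cluster with GPU resources
```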