Unverified commit eb24dc4a, authored by youkaichao and committed by GitHub
Browse files

[v1] torchrun compatibility (#13642)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 9bebc951
...@@ -503,6 +503,7 @@ steps: ...@@ -503,6 +503,7 @@ steps:
- entrypoints/llm/test_collective_rpc.py - entrypoints/llm/test_collective_rpc.py
commands: commands:
- pytest -v -s entrypoints/llm/test_collective_rpc.py - pytest -v -s entrypoints/llm/test_collective_rpc.py
- VLLM_USE_V1=1 torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
- torchrun --nproc-per-node=2 distributed/test_torchrun_example.py - torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
- pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py - pytest -v -s ./compile/test_wrapper.py
......
...@@ -48,6 +48,12 @@ test_consistent_across_ranks( ...@@ -48,6 +48,12 @@ test_consistent_across_ranks(
test_consistent_across_ranks( test_consistent_across_ranks(
llm.llm_engine.vllm_config.cache_config.num_gpu_blocks) llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
# make sure we can access the model parameters from the calling process
# of the `LLM` instance.
params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner.
model.parameters())
test_consistent_across_ranks(len(params))
# all ranks should have the same outputs # all ranks should have the same outputs
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
......
...@@ -5,6 +5,7 @@ import threading ...@@ -5,6 +5,7 @@ import threading
import time import time
import uuid import uuid
from concurrent.futures import Future from concurrent.futures import Future
from typing import List
import pytest import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
...@@ -211,8 +212,9 @@ def test_engine_core_concurrent_batches(monkeypatch): ...@@ -211,8 +212,9 @@ def test_engine_core_concurrent_batches(monkeypatch):
class DummyExecutor(UniProcExecutor): class DummyExecutor(UniProcExecutor):
def initialize(self, kv_cache_config: KVCacheConfig) -> None: def initialize_from_config(
super().initialize(kv_cache_config) self, kv_cache_configs: List[KVCacheConfig]) -> None:
super().initialize_from_config(kv_cache_configs)
# This executor actually can only run 1 batch at a time # This executor actually can only run 1 batch at a time
self.semaphore = threading.Semaphore(1) self.semaphore = threading.Semaphore(1)
......
...@@ -1407,6 +1407,11 @@ class ParallelConfig: ...@@ -1407,6 +1407,11 @@ class ParallelConfig:
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
self.world_size_across_dp = self.world_size * self.data_parallel_size self.world_size_across_dp = self.world_size * self.data_parallel_size
if self.distributed_executor_backend == "external_launcher":
import os
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
logger.info("Disabling V1 multiprocessing for external launcher.")
ray_only_devices = ["tpu"] ray_only_devices = ["tpu"]
from vllm.platforms import current_platform from vllm.platforms import current_platform
if (current_platform.device_type in ray_only_devices if (current_platform.device_type in ray_only_devices
......
...@@ -541,7 +541,7 @@ class RayDistributedExecutor(DistributedExecutorBase): ...@@ -541,7 +541,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
# and the TP group executes in SPMD fashion. # and the TP group executes in SPMD fashion.
if self.use_v1: if self.use_v1:
outputs = [ outputs = [
worker.execute_model. worker.execute_model_ray.
bind( # type: ignore[attr-defined] bind( # type: ignore[attr-defined]
outputs[i]) for i, worker in enumerate(tp_group) outputs[i]) for i, worker in enumerate(tp_group)
] ]
......
...@@ -112,10 +112,12 @@ try: ...@@ -112,10 +112,12 @@ try:
torch.cuda.set_device(self.worker.device) torch.cuda.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True self.compiled_dag_cuda_device_set = True
def execute_model( def execute_model_ray(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> "ModelRunnerOutput": ) -> "ModelRunnerOutput":
# this method is used to compile ray CG,
# and it needs a special logic of self.setup_device_if_necessary()
self.setup_device_if_necessary() self.setup_device_if_necessary()
assert self.worker is not None, "Worker is not initialized" assert self.worker is not None, "Worker is not initialized"
if isinstance(scheduler_output, tuple): if isinstance(scheduler_output, tuple):
......
...@@ -93,9 +93,10 @@ class ExecutorWithExternalLauncher(UniProcExecutor): ...@@ -93,9 +93,10 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
("ExecutorWithExternalLauncher needs deterministic " ("ExecutorWithExternalLauncher needs deterministic "
"execution, so it" "execution, so it"
"does not support delay_factor in scheduling") "does not support delay_factor in scheduling")
assert not envs.VLLM_USE_V1, \ if envs.VLLM_USE_V1:
("V1 architecture cannot guarantee deterministic execution, " assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \
"so it is not supported in ExecutorWithExternalLauncher.") ("To get deterministic execution in V1, "
"please set VLLM_ENABLE_V1_MULTIPROCESSING=0")
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
rpc_rank=0) rpc_rank=0)
# engines are launched in torchrun-compatible launchers # engines are launched in torchrun-compatible launchers
......
...@@ -110,7 +110,7 @@ class EngineCore: ...@@ -110,7 +110,7 @@ class EngineCore:
num_cpu_blocks = 0 num_cpu_blocks = 0
# Initialize kv cache and warmup the execution # Initialize kv cache and warmup the execution
self.model_executor.initialize(kv_cache_configs) self.model_executor.initialize_from_config(kv_cache_configs)
elapsed = time.time() - start elapsed = time.time() - start
logger.info(("init engine (profile, create kv cache, " logger.info(("init engine (profile, create kv cache, "
......
...@@ -4,10 +4,10 @@ from typing import Dict, List, Mapping, Optional, Type, Union ...@@ -4,10 +4,10 @@ from typing import Dict, List, Mapping, Optional, Type, Union
from typing_extensions import TypeVar from typing_extensions import TypeVar
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig from vllm.config import ParallelConfig, VllmConfig
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase from vllm.engine.metrics_types import StatLoggerBase
from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -44,6 +44,7 @@ class LLMEngine: ...@@ -44,6 +44,7 @@ class LLMEngine:
use_cached_outputs: bool = False, use_cached_outputs: bool = False,
multiprocess_mode: bool = False, multiprocess_mode: bool = False,
) -> None: ) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
...@@ -83,6 +84,10 @@ class LLMEngine: ...@@ -83,6 +84,10 @@ class LLMEngine:
log_stats=False, # FIXME: implement log_stats=False, # FIXME: implement
) )
if not multiprocess_mode:
# for v0 compatibility
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore
@classmethod @classmethod
def from_engine_args( def from_engine_args(
cls, cls,
...@@ -97,7 +102,7 @@ class LLMEngine: ...@@ -97,7 +102,7 @@ class LLMEngine:
vllm_config = engine_args.create_engine_config(usage_context) vllm_config = engine_args.create_engine_config(usage_context)
executor_class = Executor.get_class(vllm_config) executor_class = Executor.get_class(vllm_config)
if VLLM_ENABLE_V1_MULTIPROCESSING: if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
logger.debug("Enabling multiprocessing for LLMEngine.") logger.debug("Enabling multiprocessing for LLMEngine.")
enable_multiprocessing = True enable_multiprocessing = True
......
...@@ -3,6 +3,9 @@ ...@@ -3,6 +3,9 @@
from concurrent.futures import Future from concurrent.futures import Future
from typing import List, Type, Union from typing import List, Type, Union
import torch
import torch.distributed as dist
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.executor.uniproc_executor import ( # noqa from vllm.executor.uniproc_executor import ( # noqa
...@@ -49,12 +52,14 @@ class Executor(ExecutorBase): ...@@ -49,12 +52,14 @@ class Executor(ExecutorBase):
f"{distributed_executor_backend}") f"{distributed_executor_backend}")
return executor_class return executor_class
def initialize(self, kv_cache_configs: List[KVCacheConfig]) -> None: def initialize_from_config(self,
kv_cache_configs: List[KVCacheConfig]) -> None:
""" """
Initialize the KV caches and begin the model execution loop of the Initialize the KV caches and begin the model execution loop of the
underlying workers. underlying workers.
""" """
self.collective_rpc("initialize_cache", args=(kv_cache_configs, )) self.collective_rpc("initialize_from_config",
args=(kv_cache_configs, ))
self.collective_rpc("compile_or_warm_up_model") self.collective_rpc("compile_or_warm_up_model")
def determine_available_memory(self) -> int: # in bytes def determine_available_memory(self) -> int: # in bytes
...@@ -89,4 +94,13 @@ class UniProcExecutor(UniProcExecutorV0, Executor): ...@@ -89,4 +94,13 @@ class UniProcExecutor(UniProcExecutorV0, Executor):
class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor): class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
pass
def determine_available_memory(self) -> int: # in bytes
# same as determine_num_available_blocks in v0,
# we need to get the min across all ranks.
memory = super().determine_available_memory()
from vllm.distributed.parallel_state import get_world_group
cpu_group = get_world_group().cpu_group
memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64)
dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
return memory_tensor.item()
...@@ -216,9 +216,10 @@ class WorkerProc: ...@@ -216,9 +216,10 @@ class WorkerProc:
"local_rank": local_rank, "local_rank": local_rank,
"rank": rank, "rank": rank,
"distributed_init_method": distributed_init_method, "distributed_init_method": distributed_init_method,
"is_driver_worker": rank == 0,
} }
wrapper.init_worker(all_kwargs) wrapper.init_worker(all_kwargs)
self.worker = wrapper.worker self.worker = wrapper
pid = os.getpid() pid = os.getpid()
_add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid) _add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid)
...@@ -239,7 +240,7 @@ class WorkerProc: ...@@ -239,7 +240,7 @@ class WorkerProc:
ready_socket.send_string(WorkerProc.READY_STR) ready_socket.send_string(WorkerProc.READY_STR)
ready_socket.send(payload) ready_socket.send(payload)
wrapper.init_device() self.worker.init_device()
self.worker.load_model() self.worker.load_model()
@staticmethod @staticmethod
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
"""A GPU worker class.""" """A GPU worker class."""
import gc import gc
import os import os
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, Optional
import torch import torch
import torch.distributed import torch.distributed
...@@ -185,9 +185,8 @@ class Worker(WorkerBase): ...@@ -185,9 +185,8 @@ class Worker(WorkerBase):
def get_kv_cache_spec(self) -> KVCacheSpec: def get_kv_cache_spec(self) -> KVCacheSpec:
return self.model_runner.get_kv_cache_spec() return self.model_runner.get_kv_cache_spec()
def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None: def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config.""" """Allocate GPU KV cache with the specified kv_cache_config."""
kv_cache_config = kv_cache_configs[self.rank]
if self.vllm_config.model_config.enable_sleep_mode: if self.vllm_config.model_config.enable_sleep_mode:
allocator = CuMemAllocator.get_instance() allocator = CuMemAllocator.get_instance()
context = allocator.use_memory_pool(tag="kv_cache") context = allocator.use_memory_pool(tag="kv_cache")
...@@ -225,7 +224,7 @@ class Worker(WorkerBase): ...@@ -225,7 +224,7 @@ class Worker(WorkerBase):
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]: ) -> Optional[ModelRunnerOutput]:
output = self.model_runner.execute_model(scheduler_output) output = self.model_runner.execute_model(scheduler_output)
return output if self.rank == 0 else None return output if self.is_driver_worker else None
def profile(self, is_start: bool = True): def profile(self, is_start: bool = True):
if self.profiler is None: if self.profiler is None:
......
...@@ -36,6 +36,7 @@ class TPUWorker: ...@@ -36,6 +36,7 @@ class TPUWorker:
distributed_init_method: str, distributed_init_method: str,
is_driver_worker: bool = False, is_driver_worker: bool = False,
): ):
self.is_driver_worker = is_driver_worker
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
...@@ -151,7 +152,7 @@ class TPUWorker: ...@@ -151,7 +152,7 @@ class TPUWorker:
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]: ) -> Optional[ModelRunnerOutput]:
output = self.model_runner.execute_model(scheduler_output) output = self.model_runner.execute_model(scheduler_output)
return output if self.rank == 0 else None return output if self.is_driver_worker else None
def load_model(self) -> None: def load_model(self) -> None:
self.model_runner.load_model() self.model_runner.load_model()
...@@ -170,9 +171,8 @@ class TPUWorker: ...@@ -170,9 +171,8 @@ class TPUWorker:
def get_kv_cache_spec(self) -> KVCacheSpec: def get_kv_cache_spec(self) -> KVCacheSpec:
return self.model_runner.get_kv_cache_spec() return self.model_runner.get_kv_cache_spec()
def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None: def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config.""" """Allocate GPU KV cache with the specified kv_cache_config."""
kv_cache_config = kv_cache_configs[self.rank]
self.model_runner.initialize_kv_cache(kv_cache_config) self.model_runner.initialize_kv_cache(kv_cache_config)
def check_health(self) -> None: def check_health(self) -> None:
......
...@@ -567,6 +567,10 @@ class WorkerWrapperBase: ...@@ -567,6 +567,10 @@ class WorkerWrapperBase:
self.worker = worker_class(**kwargs) self.worker = worker_class(**kwargs)
assert self.worker is not None assert self.worker is not None
def initialize_from_config(self, kv_cache_configs: List[Any]) -> None:
kv_cache_config = kv_cache_configs[self.rpc_rank]
self.worker.initialize_from_config(kv_cache_config) # type: ignore
def init_device(self): def init_device(self):
with set_current_vllm_config(self.vllm_config): with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during device initialization # To make vLLM config available during device initialization
...@@ -574,8 +578,11 @@ class WorkerWrapperBase: ...@@ -574,8 +578,11 @@ class WorkerWrapperBase:
def execute_method(self, method: Union[str, bytes], *args, **kwargs): def execute_method(self, method: Union[str, bytes], *args, **kwargs):
try: try:
target = self if self.worker is None else self.worker # method resolution order:
return run_method(target, method, args, kwargs) # if a method is defined in this class, it will be called directly.
# otherwise, since we define `__getattr__` and redirect attribute
# query to `self.worker`, the method will be called on the worker.
return run_method(self, method, args, kwargs)
except Exception as e: except Exception as e:
# if the driver worker also execute methods, # if the driver worker also execute methods,
# exceptions in the rest worker may cause deadlock in rpc like ray # exceptions in the rest worker may cause deadlock in rpc like ray
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment