Unverified Commit 151b08e0 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[RLHF] use worker_extension_cls for compatibility with V0 and V1 (#14185)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent 81b2f4a4
......@@ -145,8 +145,10 @@ steps:
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
# TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests
- python3 ../examples/offline_inference/rlhf.py
- RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py
- pushd ../examples/offline_inference
- python3 rlhf.py
- RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
- popd
- label: Metrics, Tracing Test # 10min
num_gpus: 2
......
......@@ -18,72 +18,11 @@ import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from rlhf_utils import stateless_init_process_group
from transformers import AutoModelForCausalLM
from vllm import LLM, SamplingParams
from vllm.utils import get_ip, get_open_port
from vllm.worker.worker import Worker
def stateless_init_process_group(master_address, master_port, rank, world_size,
device):
"""
vLLM provides `StatelessProcessGroup` to create a process group
without considering the global process group in torch.distributed.
It is recommended to create `StatelessProcessGroup`, and then initialize
the data-plane communication (NCCL) between external (train processes)
and vLLM workers.
"""
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
pg = StatelessProcessGroup.create(host=master_address,
port=master_port,
rank=rank,
world_size=world_size)
pynccl = PyNcclCommunicator(pg, device=device)
return pynccl
class MyWorker(Worker):
"""
The `MyWorker` class inherits from `Worker` to provide custom functions.
For simplicity, we define the `MyWorker` class in this self-contained
script. Normally, we should define the `MyWorker` class in a separate
file and pass the qualified name of the class to the `worker_cls`
parameter.
"""
def init_weight_update_group(self, master_address, master_port,
rank_offset, world_size):
from vllm.distributed.parallel_state import get_world_group
rank = get_world_group().rank + rank_offset
self.model_update_group = stateless_init_process_group(
master_address,
master_port,
rank,
world_size,
self.device,
)
def update_weight(self, name, dtype, shape):
weight = torch.empty(shape, dtype=dtype, device="cuda")
self.model_update_group.broadcast(weight,
src=0,
stream=torch.cuda.current_stream())
self.model_runner.model.load_weights(weights=[(name, weight)])
del weight
def check_weights_changed(self):
"""
Check if the weights are updated to 0.
"""
weights_updated = True
for name, p in self.model_runner.model.named_parameters():
weights_updated = weights_updated and torch.allclose(
p, torch.zeros_like(p))
return weights_updated
class MyLLM(LLM):
......@@ -129,7 +68,7 @@ llm = ray.remote(
)(MyLLM).remote(
model="facebook/opt-125m",
enforce_eager=True,
worker_cls=MyWorker,
worker_extension_cls="rlhf_utils.WorkerExtension",
tensor_parallel_size=2,
distributed_executor_backend="ray",
)
......@@ -159,6 +98,7 @@ master_port = get_open_port()
handle = llm.collective_rpc.remote("init_weight_update_group",
args=(master_address, master_port, 1, 3))
model_update_group = stateless_init_process_group(master_address, master_port,
0, 3, torch.device("cuda:0"))
ray.get(handle)
......
......@@ -17,40 +17,6 @@ from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from vllm import LLM
from vllm.worker.worker import Worker
class MyWorker(Worker):
def report_device_id(self) -> str:
from vllm.platforms import current_platform
self.device_uuid = current_platform.get_device_uuid(self.device.index)
return self.device_uuid
def update_weights_from_ipc_handles(self, ipc_handles):
handles = ipc_handles[self.device_uuid]
device_id = self.device.index
weights = []
for name, handle in handles.items():
func, args = handle
list_args = list(args)
# the key is to change device id to the current device id
# in case two processes have different CUDA_VISIBLE_DEVICES
list_args[6] = device_id
tensor = func(*list_args)
weights.append((name, tensor))
self.model_runner.model.load_weights(weights=weights)
torch.cuda.synchronize()
def check_weights_changed(self):
"""
Check if the weights are updated to 0.
"""
weights_updated = True
for name, p in self.model_runner.model.named_parameters():
weights_updated = weights_updated and torch.allclose(
p, torch.zeros_like(p))
return weights_updated
class MyLLM(LLM):
......@@ -150,7 +116,7 @@ for (i, bundle_indices) in enumerate([[0, 1], [2, 3]]):
)(MyLLM).remote(
model="facebook/opt-125m",
enforce_eager=True,
worker_cls=MyWorker,
worker_extension_cls="rlhf_utils.ColocateWorkerExtension",
tensor_parallel_size=2,
distributed_executor_backend="ray",
gpu_memory_utilization=0.4,
......
# SPDX-License-Identifier: Apache-2.0
import torch
def stateless_init_process_group(master_address, master_port, rank, world_size,
device):
"""
vLLM provides `StatelessProcessGroup` to create a process group
without considering the global process group in torch.distributed.
It is recommended to create `StatelessProcessGroup`, and then initialize
the data-plane communication (NCCL) between external (train processes)
and vLLM workers.
"""
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
pg = StatelessProcessGroup.create(host=master_address,
port=master_port,
rank=rank,
world_size=world_size)
pynccl = PyNcclCommunicator(pg, device=device)
return pynccl
class WorkerExtension:
"""
The class for vLLM's worker to inherit from.
By defining an extension class, the code can work no matter what is
the underlying worker class. This way, the code can be compatible
with both vLLM V0 and V1.
NOTE: we define this class in a separate module, and the main module
should pass the full qualified name as `worker_extension_cls` argument.
"""
def init_weight_update_group(self, master_address, master_port,
rank_offset, world_size):
from vllm.distributed.parallel_state import get_world_group
rank = get_world_group().rank + rank_offset
self.model_update_group = stateless_init_process_group(
master_address,
master_port,
rank,
world_size,
self.device,
)
def update_weight(self, name, dtype, shape):
weight = torch.empty(shape, dtype=dtype, device="cuda")
self.model_update_group.broadcast(weight,
src=0,
stream=torch.cuda.current_stream())
self.model_runner.model.load_weights(weights=[(name, weight)])
del weight
def check_weights_changed(self):
"""
Check if the weights are updated to 0.
"""
weights_updated = True
for name, p in self.model_runner.model.named_parameters():
weights_updated = weights_updated and torch.allclose(
p, torch.zeros_like(p))
return weights_updated
class ColocateWorkerExtension:
"""
The class for vLLM's worker to inherit from, in the colocate setting.
By defining an extension class, the code can work no matter what is
the underlying worker class. This way, the code can be compatible
with both vLLM V0 and V1.
NOTE: we define this class in a separate module, and the main module
should pass the full qualified name as `worker_extension_cls` argument.
"""
def report_device_id(self) -> str:
from vllm.platforms import current_platform
self.device_uuid = current_platform.get_device_uuid(self.device.index)
return self.device_uuid
def update_weights_from_ipc_handles(self, ipc_handles):
handles = ipc_handles[self.device_uuid]
device_id = self.device.index
weights = []
for name, handle in handles.items():
func, args = handle
list_args = list(args)
# the key is to change device id to the current device id
# in case two processes have different CUDA_VISIBLE_DEVICES
list_args[6] = device_id
tensor = func(*list_args)
weights.append((name, tensor))
self.model_runner.model.load_weights(weights=weights)
torch.cuda.synchronize()
def check_weights_changed(self):
"""
Check if the weights are updated to 0.
"""
weights_updated = True
for name, p in self.model_runner.model.named_parameters():
weights_updated = weights_updated and torch.allclose(
p, torch.zeros_like(p))
return weights_updated
......@@ -1366,6 +1366,7 @@ class ParallelConfig:
# will be determined based on the platform.
worker_cls: str = "auto"
sd_worker_cls: str = "auto"
worker_extension_cls: str = ""
# world_size is TPxPP, it affects the number of workers we create.
world_size: int = field(init=False)
......@@ -1523,6 +1524,9 @@ class ParallelConfig:
raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.")
assert isinstance(self.worker_extension_cls, str), (
"worker_extension_cls must be a string (qualified class name).")
@dataclass
class SchedulerConfig:
......
......@@ -202,6 +202,7 @@ class EngineArgs:
override_pooler_config: Optional[PoolerConfig] = None
compilation_config: Optional[CompilationConfig] = None
worker_cls: str = "auto"
worker_extension_cls: str = ""
kv_transfer_config: Optional[KVTransferConfig] = None
......@@ -1015,6 +1016,13 @@ class EngineArgs:
type=str,
default="auto",
help='The worker class to use for distributed execution.')
parser.add_argument(
'--worker-extension-cls',
type=str,
default="",
help='The worker extension class on top of the worker cls, '
'it is useful if you just want to add new functions to the worker '
'class without changing the existing functions.')
parser.add_argument(
"--generation-config",
type=nullable_str,
......@@ -1209,6 +1217,7 @@ class EngineArgs:
ray_workers_use_nsight=self.ray_workers_use_nsight,
distributed_executor_backend=self.distributed_executor_backend,
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
)
max_model_len = model_config.max_model_len
......
......@@ -558,10 +558,37 @@ class WorkerWrapperBase:
worker_class = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_cls)
else:
logger.warning(
"passing worker_cls as a class object is strongly deprecated,"
" as the serialization of class objects can be tricky and"
" error-prone. To be safe, please keep the class in a separate"
" module and pass the qualified name of the class as a string."
)
assert isinstance(self.vllm_config.parallel_config.worker_cls,
bytes)
worker_class = cloudpickle.loads(
self.vllm_config.parallel_config.worker_cls)
if self.vllm_config.parallel_config.worker_extension_cls:
worker_extension_cls = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_extension_cls)
extended_calls = []
if worker_extension_cls not in worker_class.__bases__:
# check any conflicts between worker and worker_extension_cls
for attr in dir(worker_extension_cls):
if attr.startswith("__"):
continue
assert not hasattr(worker_class, attr), (
f"Worker class {worker_class} already has an attribute"
f" {attr}, which conflicts with the worker"
f" extension class {worker_extension_cls}.")
if callable(getattr(worker_extension_cls, attr)):
extended_calls.append(attr)
# dynamically inherit the worker extension class
worker_class.__bases__ = worker_class.__bases__ + (
worker_extension_cls, )
logger.info(
"Injected %s into %s for extended collective_rpc calls %s",
worker_extension_cls, worker_class, extended_calls)
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during worker initialization
self.worker = worker_class(**kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment