# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates how to co-locate vLLM inference engines and training actors
on the same set of GPUs for reinforcement learning from human feedback
(RLHF) workloads.

Ray serves as the distributed execution framework in this example. Ray
placement groups allocate both training actors and vLLM workers to the
same GPU bundles, enabling fast, in-GPU communication between the two
components.

The script shows how to do the following:

* Configure environment variables (`VLLM_RAY_PER_WORKER_GPUS` and
  `VLLM_RAY_BUNDLE_INDICES`) so that vLLM workers land on the desired
  devices.
* Exchange tensors between processes by means of CUDA inter-process
  communication (IPC). CUDA IPC sidesteps NCCL limitations that occur
  when multiple processes share a single GPU.

Note that this example assumes a single-node cluster with four GPUs, but Ray
supports multi-node clusters. vLLM expects exclusive use of the GPUs during
its initialization for memory profiling. Residual GPU activity interferes
with vLLM memory profiling and causes unexpected behavior.

Learn more about Ray placement groups:
https://docs.ray.io/en/latest/placement-groups.html
"""
30

31
32
33
import os

import ray
34
import torch
35
36
37
38
39
40
41
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from vllm import LLM


class MyLLM(LLM):
    """A `vllm.LLM` whose Ray workers are placed on chosen placement-group bundles.

    Before delegating to `vllm.LLM`, the constructor exports the environment
    variables vLLM reads to place its Ray workers: the GPU fraction each
    worker requests, and the placement-group bundles the workers occupy.

    Args:
        *args: Positional arguments forwarded to `vllm.LLM`.
        bundle_indices (list[int]): Placement-group bundle indices assigned
            to this instance (for example `[0, 1]` alongside
            `tensor_parallel_size=2`).
        **kwargs: Keyword arguments forwarded to `vllm.LLM`.
    """

    def __init__(self, *args, bundle_indices: list[int], **kwargs):
        # Drop any CUDA_VISIBLE_DEVICES inherited by this actor process
        # (presumably set by Ray) so that it does not override the device
        # placement vLLM performs for its own workers.
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)

        placement_env = {
            # Each worker requests 0.4 of a GPU, leaving room for a training
            # actor (which also requests 0.4 GPU) on the same device.
            "VLLM_RAY_PER_WORKER_GPUS": "0.4",
            # Comma-separated bundle indices, e.g. "0,1".
            "VLLM_RAY_BUNDLE_INDICES": ",".join(
                str(index) for index in bundle_indices
            ),
        }
        os.environ.update(placement_env)

        print(f"creating LLM with bundle_indices={bundle_indices}")
        super().__init__(*args, **kwargs)


class RayTrainingActor:
    """Training actor that hosts a Facebook OPT-125M model from Hugging Face.

    The model is loaded onto the first GPU visible to this actor, and the
    actor exposes CUDA IPC handles for its parameters so that co-located
    vLLM workers can map the tensors directly.
    """

    def __init__(self):
        from transformers import AutoModelForCausalLM

        self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
        # Ray sets CUDA_VISIBLE_DEVICES to the GPUs assigned to this actor,
        # so "cuda:0" is the first of those GPUs.
        self.model.to("cuda:0")

        # Zero out all the parameters; these are the weights later exported
        # to the inference engines via `get_weight_ipc_handles`.
        with torch.no_grad():
            for param in self.model.parameters():
                param.zero_()
        # Wait for the zeroing kernels to finish before anything reads the
        # parameters from another process.
        torch.cuda.synchronize()

        from vllm.platforms import current_platform

        # The argument for `get_device_uuid` is the index of the GPU in the
        # list of visible devices.
        self.device_uuid = current_platform.get_device_uuid(0)

    def report_device_id(self) -> str:
        """Return the UUID of the GPU that holds this actor's model."""
        return self.device_uuid

    def get_weight_ipc_handles(self) -> dict[str, dict[str, tuple]]:
        """Return IPC handles for every model parameter, keyed by GPU UUID.

        Returns:
            A single-entry mapping `{device_uuid: {param_name: handle}}`,
            where each handle is the tuple produced by
            `torch.multiprocessing.reductions.reduce_tensor`.
        """
        from torch.multiprocessing.reductions import reduce_tensor

        # A training actor might hold only a subset of the weights and may
        # need to gather weights from other actors. For demonstration
        # purposes, each training actor owns the full weight set.
        handles = {
            name: reduce_tensor(param.detach())
            for name, param in self.model.named_parameters()
        }
        return {self.device_uuid: handles}


# Ray manages four GPUs. Set the variable before `ray.init()` so the local
# Ray instance starts with exactly these devices visible.
os.environ.update(CUDA_VISIBLE_DEVICES="0,1,2,3")
ray.init()

# Co-locate vLLM instances and training actors on the same set of GPUs:
#   * GPU 0 and 1: training actor 0, training actor 1, and vLLM instance 0
#     (tensor parallelism = 2).
#   * GPU 2 and 3: training actor 2, training actor 3, and vLLM instance 1
#     (tensor parallelism = 2).

# One bundle per GPU. The bundles reserve no CPU, only GPU capacity. The
# bundle count is defined once so the placement group and the actor loop
# cannot drift apart.
num_bundles = 4
pg = placement_group([{"GPU": 1, "CPU": 0}] * num_bundles)
ray.get(pg.ready())
print(f"placement group has bundles {pg.bundle_specs=}")

training_actors = []
training_actor_device_ids = []
inference_engines = []
inference_engine_device_ids = []

for bundle_index in range(num_bundles):
    # Each training actor takes 0.4 of its bundle's GPU, leaving room for the
    # vLLM worker that is placed on the same bundle.
    training_actor = ray.remote(
        num_cpus=0,
        num_gpus=0.4,
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_capture_child_tasks=True,
            placement_group_bundle_index=bundle_index,
        ),
    )(RayTrainingActor).remote()
    training_actors.append(training_actor)

# Query every actor in one `ray.get`; results come back in actor order.
training_actor_device_ids.extend(
    ray.get([actor.report_device_id.remote() for actor in training_actors])
)
for bundle_index, device_id in enumerate(training_actor_device_ids):
    print(f"training actor {bundle_index} is on {device_id}")

# Create one vLLM instance per pair of bundles (tensor parallelism = 2).
for bundle_indices in ([0, 1], [2, 3]):
    # Wrap the class with `ray.remote(...)` instead of the `@ray.remote`
    # decorator so the actor options can be set per instance. The engine
    # actor itself requests no GPU. `bundle_indices` (exported by `MyLLM`
    # as VLLM_RAY_BUNDLE_INDICES) selects the bundles its workers land on.
    llm = ray.remote(
        num_cpus=0,
        num_gpus=0,
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_capture_child_tasks=True,
        ),
    )(MyLLM).remote(
        model="facebook/opt-125m",
        enforce_eager=True,
        worker_extension_cls="rlhf_utils.ColocateWorkerExtension",
        tensor_parallel_size=2,
        distributed_executor_backend="ray",
        gpu_memory_utilization=0.4,
        bundle_indices=bundle_indices,
    )
    inference_engines.append(llm)
    # Do not call any method on the inference engine at this point: the call
    # blocks until this vLLM instance finishes initialization, which would
    # delay creating the next instance.

# `collective_rpc` returns one result per worker. Wait for all engines in a
# single `ray.get`; results come back in engine order.
inference_engine_device_ids.extend(
    ray.get(
        [
            engine.collective_rpc.remote("report_device_id", args=tuple())
            for engine in inference_engines
        ]
    )
)
for engine_index, device_ids in enumerate(inference_engine_device_ids):
    print(f"inference engine {engine_index} is on {device_ids}")

# Verify placement. Explicit checks (rather than `assert`) keep the
# verification active under `python -O` and report the mismatching IDs.

# The first two training actors must share GPUs with the first inference
# engine.
if training_actor_device_ids[:2] != inference_engine_device_ids[0]:
    raise RuntimeError(
        "training actors 0-1 are not co-located with inference engine 0: "
        f"{training_actor_device_ids[:2]} != {inference_engine_device_ids[0]}"
    )

# The last two training actors must share GPUs with the second inference
# engine.
if training_actor_device_ids[2:] != inference_engine_device_ids[1]:
    raise RuntimeError(
        "training actors 2-3 are not co-located with inference engine 1: "
        f"{training_actor_device_ids[2:]} != {inference_engine_device_ids[1]}"
    )

print("Gather all the IPC handles from the training actors.")
# Each training actor returns {device_uuid: {param_name: handle}}. Merge
# them into one mapping keyed by GPU UUID.
ipc_handles = {}
for actor_handles in ray.get(
    [actor.get_weight_ipc_handles.remote() for actor in training_actors]
):
    ipc_handles.update(actor_handles)

print("Update the weights of the inference engines.")
for llm in inference_engines:
    ray.get(
        llm.collective_rpc.remote(
            "update_weights_from_ipc_handles", args=(ipc_handles,)
        )
    )

print("Check if the weights are updated.")
for engine_index, llm in enumerate(inference_engines):
    # `collective_rpc` returns one result per worker. A non-empty list is
    # always truthy, so every worker's result must be checked individually.
    worker_results = ray.get(
        llm.collective_rpc.remote("check_weights_changed", args=tuple())
    )
    if not all(worker_results):
        raise RuntimeError(
            f"inference engine {engine_index}: weights were not updated on "
            f"every worker (per-worker results: {worker_results})"
        )