# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates how to co-locate vLLM inference workers and training actors on
the same set of GPUs for reinforcement learning from human feedback (RLHF)
workloads.

Ray serves as the distributed execution framework in this example. Ray
placement groups allocate both training actors and vLLM workers to the
same GPU bundles, enabling fast, in-GPU communication between the two
components.

The script shows how to do the following:

* Configure environment variables (`VLLM_RAY_PER_WORKER_GPUS` and
  `VLLM_RAY_BUNDLE_INDICES`) so that vLLM workers land on the desired
  devices.
* Exchange tensors between processes by means of CUDA inter-process
  communication (IPC), with ZeroMQ carrying the IPC handles and tensor
  metadata. CUDA IPC sidesteps NCCL limitations that occur when multiple
  processes share a single GPU.

Note that this example assumes a single-node cluster with four GPUs, but Ray
supports multi-node clusters. vLLM expects exclusive use of the GPUs during
its initialization for memory profiling. Residual GPU activity interferes
with vLLM memory profiling and causes unexpected behavior.

Learn more about Ray placement groups:
https://docs.ray.io/en/latest/placement-groups.html
"""

import gc
import os
import sys

import ray
import torch
import zmq
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from torch.multiprocessing.reductions import reduce_tensor

from vllm import LLM

# This example relies on Ray, which vLLM does not support on ROCm; exit
# cleanly instead of failing later.
if torch.version.hip is not None:
    print("Skipping test for ROCm. Ray is unsupported on vLLM ROCm.")
    sys.exit(0)

class MyLLM(LLM):
    """Configure the vLLM worker for Ray placement group execution.

    The constructor sets environment variables that allow multiple vLLM
    workers to share a single physical GPU and that encode the bundle
    indices assigned by the placement group, then defers to `vllm.LLM`.

    Args:
        *args: Positional arguments forwarded to `vllm.LLM`.
        bundle_indices: Placement-group bundle indices assigned to this
            engine's workers (joined into `VLLM_RAY_BUNDLE_INDICES`).
        **kwargs: Keyword arguments forwarded to `vllm.LLM`.
    """

    def __init__(self, *args, bundle_indices: list[int], **kwargs):
        # Stop Ray from manipulating the top-level CUDA_VISIBLE_DEVICES
        # variable so that vLLM can manage its own device placement inside
        # the worker.
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        # Each vLLM worker claims 0.4 of a GPU, leaving room for a
        # co-located training actor on the same GPU.
        os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
        os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))
        print(f"creating LLM with bundle_indices={bundle_indices}")
        super().__init__(*args, **kwargs)


class RayTrainingActor:
75
76
77
78
79
80
81
    """Training actor that hosts a Facebook OPT-125M model from Hugging Face.

    The model is loaded onto the first GPU assigned to this actor, and expose
    the CUDA IPC handles so that colocated vLLM workers can map tensors
    directly.
    """

82
    def __init__(self):
83
        # Ray sets CUDA_VISIBLE_DEVICES to the GPUs assigned to this actor.
84
        from transformers import AutoModelForCausalLM
85

86
87
        self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
        self.model.to("cuda:0")
88
        # Zero out all the parameters.
89
90
        for name, p in self.model.named_parameters():
            p.data.zero_()
91
        torch.accelerator.synchronize()
92
93
        # The argument for `get_device_uuid` is the index of the GPU in the
        # list of visible devices.
94
        from vllm.platforms import current_platform
95

96
        self.device_uuid = current_platform.get_device_uuid(0)
97
98
99
        self.zmq_context = zmq.Context()
        self.zmq_address_counter = 0
        self.zmq_handle = None
100
101
102
103

    def report_device_id(self) -> str:
        return self.device_uuid

104
105
106
107
108
    def get_zmq_handles(self) -> dict[str, str]:
        suffix = f"{self.device_uuid}-{self.zmq_address_counter}"
        self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{suffix}.sock"
        self.zmq_address_counter += 1
        return {self.device_uuid: self.zmq_handle}
109

110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
    def update_weights(self):
        # align size to avoid misaligned address
        align_size = 256

        def get_size(p: torch.Tensor) -> int:
            return (p.nbytes + align_size - 1) // align_size * align_size

        named_parameters: dict[str, torch.nn.Parameter] = dict(
            self.model.named_parameters()
        )
        max_tensor_size = max(get_size(p) for p in named_parameters.values())
        # use max_tensor_size * 2 as buffer size
        buffer = torch.empty(max_tensor_size * 2, dtype=torch.uint8, device="cuda:0")
        s = self.zmq_context.socket(zmq.REQ)
        s.bind(self.zmq_handle)
        handle = reduce_tensor(buffer)

        offset = 0
        buckets: list[tuple[list[dict], list[torch.Tensor]]] = []
        named_tensors: list[dict] = []
        real_tensors: list[torch.Tensor] = []
        for name, p in named_parameters.items():
            size = get_size(p)
            if offset + size > buffer.numel():
                buckets.append((named_tensors, real_tensors))
                named_tensors, real_tensors = [], []
                offset = 0
            # assume tensors are contiguous
            named_tensors.append(
                {"name": name, "dtype": p.dtype, "shape": p.shape, "offset": offset}
            )
            real_tensors.append(p)
            offset += size
        if named_tensors:
            buckets.append((named_tensors, real_tensors))
        s.send_pyobj(handle)
        s.recv()
        for named_tensors, real_tensors in buckets:
            offset = 0
            for p in real_tensors:
                buffer[offset : offset + p.nbytes].data.copy_(
                    p.data.view(-1).view(dtype=torch.uint8), non_blocking=True
                )
                offset += get_size(p)
154
            torch.accelerator.synchronize()
155
156
157
158
159
160
161
            s.send_pyobj(named_tensors)
            s.recv()
        s.send_pyobj(None)
        s.recv()
        s.close()
        del buffer
        gc.collect()
162
        torch.accelerator.empty_cache()
163
164


# Ray manages four GPUs: expose them explicitly before starting Ray.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
ray.init()

# Co-locate vLLM instances and training actors on the same set of GPUs:
#   * GPU 0 and 1: training actor 0, training actor 1, and vLLM instance 0
#     (tensor parallelism = 2).
#   * GPU 2 and 3: training actor 2, training actor 3, and vLLM instance 1
#     (tensor parallelism = 2).

# One bundle per GPU; training actors and vLLM workers share these bundles.
pg = placement_group([{"GPU": 1, "CPU": 0}] * 4)
ray.get(pg.ready())
print(f"placement group has bundles {pg.bundle_specs=}")

training_actors = []
inference_engines = []
inference_engine_device_ids = []

# One training actor per bundle. Each claims 0.4 of its bundle's GPU so a
# vLLM worker can be scheduled onto the same bundle later.
for bundle_index in range(len(pg.bundle_specs)):
    training_actor = ray.remote(
        num_cpus=0,
        num_gpus=0.4,
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_capture_child_tasks=True,
            placement_group_bundle_index=bundle_index,
        ),
    )(RayTrainingActor).remote()
    training_actors.append(training_actor)

# Query every actor in one batch instead of one round-trip per actor; the
# results come back in actor (i.e. bundle) order.
training_actor_device_ids = ray.get(
    [actor.report_device_id.remote() for actor in training_actors]
)
for bundle_index, device_id in enumerate(training_actor_device_ids):
    print(f"training actor {bundle_index} is on {device_id}")

# One vLLM engine (tensor parallelism = 2) per pair of bundles.
for bundle_indices in ([0, 1], [2, 3]):
    # Use the following syntax instead of the @ray.remote decorator so that
    # the placement group is customized for each bundle.
    # The engine actor itself reserves no GPU (num_gpus=0); the GPU share of
    # its workers is set through the environment variables in MyLLM.
    llm = ray.remote(
        num_cpus=0,
        num_gpus=0,
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_capture_child_tasks=True,
        ),
    )(MyLLM).remote(
        model="facebook/opt-125m",
        enforce_eager=True,
        worker_extension_cls="rlhf_utils.ColocateWorkerExtension",
        tensor_parallel_size=2,
        distributed_executor_backend="ray",
        gpu_memory_utilization=0.4,
        bundle_indices=bundle_indices,
    )
    inference_engines.append(llm)
    # Do not call any method on the inference engine at this point; the call
    # blocks until the vLLM instance finishes initialization.

for i, llm in enumerate(inference_engines):
    # collective_rpc yields one result per worker, so each entry is the list
    # of device UUIDs used by that engine's tensor-parallel workers.
    inference_engine_device_ids.append(
        ray.get(llm.collective_rpc.remote("report_device_id", args=()))
    )
    print(f"inference engine {i} is on {inference_engine_device_ids[-1]}")

# Verify placement: the first two training actors share the same GPUs as
# the first inference engine.
assert training_actor_device_ids[:2] == inference_engine_device_ids[0], (
    f"training actors 0-1 on {training_actor_device_ids[:2]}, "
    f"inference engine 0 on {inference_engine_device_ids[0]}"
)
# Verify placement: the last two training actors share the same GPUs as
# the second inference engine.
assert training_actor_device_ids[2:] == inference_engine_device_ids[1], (
    f"training actors 2-3 on {training_actor_device_ids[2:]}, "
    f"inference engine 1 on {inference_engine_device_ids[1]}"
)

print("Gather all the ZMQ handles from the training actors.")
# Each actor returns {device_uuid: zmq_address}; merge them into a single
# mapping that is handed to every vLLM worker. Fetch them in one batch.
zmq_handles = {}
for handles in ray.get(
    [actor.get_zmq_handles.remote() for actor in training_actors]
):
    zmq_handles.update(handles)
print(f"ZMQ handles: {zmq_handles}")

print("Update the weights of the inference engines.")
# Launch both sides in one batch: each training actor blocks on a ZeroMQ
# request/reply handshake, so the vLLM workers must be receiving at the same
# time or the exchange would deadlock.
ray.get(
    [actor.update_weights.remote() for actor in training_actors]
    + [
        llm.collective_rpc.remote("update_weights_from_ipc", args=(zmq_handles,))
        for llm in inference_engines
    ]
)

print("Check if the weights are updated.")
for llm in inference_engines:
    # collective_rpc returns one result per worker. A non-empty list is
    # always truthy, so asserting the list itself would never fail; require
    # every worker to report success instead.
    results = ray.get(llm.collective_rpc.remote("check_weights_changed", args=()))
    assert all(results), f"weights not updated on every worker: {results}"