rlhf_async_new_apis.py 14.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates async reinforcement learning using vLLM and Ray,
with native weight syncing APIs at engine instance.

The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformers model occupies one GPU as the "trainer", while a
single-GPU vLLM inference engine occupies another.

The example performs the following steps:
* Load the training model (Qwen3-1.7B, "V2") on one GPU, scheduled via Ray.
* Launch the vLLM inference engine with the Qwen3-1.7B-Base ("V1") weights
  on a separate GPU, with NCCL weight transfer enabled.
* Generate completions for a list of prompts concurrently.
* Pause generation (pause_generation(mode="keep")) once any request has
  produced PAUSE_TOKEN_THRESHOLD tokens.
* Broadcast the trainer's weights to the paused inference engine over NCCL
  using vLLM's native weight-transfer APIs.
* Resume generation and print each completion, split into the tokens
  produced with the old weights and those produced with the new weights.
* Validate the post-sync tokens against a fresh vLLM instance that loads the
  V2 weights directly.

This example assumes a single-node cluster with at least two GPUs, but Ray
supports multi-node clusters. vLLM expects its GPUs to be used only for vLLM
workloads; residual GPU activity interferes with vLLM memory profiling and
causes unexpected behavior.
"""

29
import asyncio
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import uuid
from dataclasses import asdict

import ray
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import vllm
from vllm import SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
from vllm.distributed.weight_transfer.nccl_engine import (
45
    NCCLTrainerSendWeightsArgs,
46
47
48
49
    NCCLWeightTransferEngine,
    NCCLWeightTransferInitInfo,
    NCCLWeightTransferUpdateInfo,
)
50
from vllm.platforms import current_platform
51
52
53
from vllm.utils.network_utils import get_ip, get_open_port
from vllm.v1.executor import Executor

54
55
56
# The inference engine starts from the base checkpoint (V1); the trainer holds
# V2 and pushes its weights into the engine mid-generation.
MODEL_NAME_V1 = "Qwen/Qwen3-1.7B-Base"
MODEL_NAME_V2 = "Qwen/Qwen3-1.7B"
# Generated-token count at which a request asks the engine to pause.
PAUSE_TOKEN_THRESHOLD = 10
# ROCm runs the Triton attention backend; other platforms use FlashAttention.
if current_platform.is_rocm():
    ATTN_BACKEND = "TRITON_ATTN"
else:
    ATTN_BACKEND = "FLASH_ATTN"


class MyLLM(vllm.AsyncLLMEngine):
    """Async vLLM engine that can pause itself mid-generation for a weight sync.

    The script wraps this class with ``ray.remote`` so the engine runs inside
    a Ray actor. ``do_generate`` raises a pause request once any request has
    produced ``PAUSE_TOKEN_THRESHOLD`` tokens. ``pause_after_n_tokens`` waits
    for that request and then pauses the engine.
    """

    def __init__(self, **kwargs):
        """Build the engine from ``vllm.AsyncEngineArgs(**kwargs)``."""
        engine_args = vllm.AsyncEngineArgs(**kwargs)
        vllm_config = engine_args.create_engine_config()
        executor_class = Executor.get_class(vllm_config)
        super().__init__(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_requests=engine_args.enable_log_requests,
            log_stats=not engine_args.disable_log_stats,
        )
        # Set by pause_after_n_tokens once the engine has been paused.
        # do_generate reads it to record how many tokens came from the old
        # weights.
        self._generation_paused = False
        # Set by the first request that reaches PAUSE_TOKEN_THRESHOLD tokens.
        self._request_pause_flag = False

    async def do_generate(
        self, prompt_token_ids: list[int], sampling_params: vllm.SamplingParams
    ) -> tuple[vllm.RequestOutput, int]:
        """Generate one request and track where the weight change happened.

        Sets the request-pause flag once this request has generated at least
        ``PAUSE_TOKEN_THRESHOLD`` tokens.

        Args:
            prompt_token_ids: Already-tokenized prompt.
            sampling_params: Sampling parameters for this request.

        Returns:
            ``(output, pause_token_index)``. ``output`` is the final streamed
            ``RequestOutput``. ``pause_token_index`` is the token count of the
            last output received before the engine was marked paused, i.e.
            the number of tokens generated with the old weights. It is -1 if
            the request finished without observing the pause.

        Raises:
            RuntimeError: If the engine streamed no output for the request.
        """
        output = None
        pause_token_index = -1
        prev_token_count = 0
        async for request_output in self.generate(
            {"prompt_token_ids": prompt_token_ids},
            sampling_params,
            request_id=str(uuid.uuid4()),
        ):
            output = request_output
            cur_token_count = len(output.outputs[0].token_ids)
            if (
                cur_token_count >= PAUSE_TOKEN_THRESHOLD
                and not self._request_pause_flag
            ):
                self._request_pause_flag = True
            # The first output seen after the pause carries new-weight tokens,
            # so the previous count marks the old/new boundary.
            if self._generation_paused and pause_token_index == -1:
                pause_token_index = prev_token_count
            prev_token_count = cur_token_count
        if output is None:
            # Without this guard the return below would raise an opaque
            # UnboundLocalError.
            raise RuntimeError("generate() produced no output for the request")
        return output, pause_token_index

    async def pause_after_n_tokens(self, poll_interval: float = 0.001):
        """Wait for any request to set the pause flag, then pause the engine.

        Args:
            poll_interval: Seconds to sleep between checks of the pause flag.
        """
        # Poll with a short non-zero sleep. ``asyncio.sleep(0)`` would keep
        # the actor's event loop spinning at full CPU while waiting.
        while not self._request_pause_flag:
            await asyncio.sleep(poll_interval)
        await super().pause_generation(mode="keep")
        # NOTE(review): this fixed 5 s wait before flagging the pause looks
        # like a grace period. It would let outputs produced before the pause
        # reach do_generate first, so they count as old-weight tokens. Confirm.
        await asyncio.sleep(5)
        self._generation_paused = True


@ray.remote(num_gpus=1)
class TrainModel:
    """Ray actor hosting the Hugging Face "training" model on its own GPU.

    It owns the trainer side of the NCCL weight-transfer group and sends its
    parameters to the vLLM engine on request.
    """

    def __init__(self, model_name: str):
        """Load ``model_name`` in bfloat16 onto ``cuda:0``.

        Also picks the rendezvous address and port for the weight-transfer
        group.
        """
        from vllm.model_executor.layers.batch_invariant import (
            init_batch_invariance,
        )
        from vllm.platforms import current_platform
        from vllm.v1.attention.backends.registry import AttentionBackendEnum

        # Batch invariance sets env vars that affect NCCL ops, so it is
        # initialised here inside the actor process.
        if current_platform.is_rocm():
            backend = AttentionBackendEnum.TRITON_ATTN
        else:
            backend = AttentionBackendEnum.FLASH_ATTN
        init_batch_invariance(backend)

        hf_model = AutoModelForCausalLM.from_pretrained(
            model_name, dtype=torch.bfloat16
        )
        # NOTE(review): Ray is told not to set CUDA_VISIBLE_DEVICES
        # (RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR), so "cuda:0" may not be the
        # GPU Ray reserved for this actor. Confirm.
        self.model = hf_model.to("cuda:0")
        self.port = get_open_port()
        self.master_address = get_ip()

    def get_master_address_and_port(self):
        """Return ``(master_address, port)`` for the weight-transfer group."""
        return self.master_address, self.port

    def get_weight_metadata(self):
        """Describe every named parameter for the engine-side weight update.

        Returns:
            Three parallel lists:
            - parameter names;
            - dtype names without the ``torch.`` prefix (e.g. ``"bfloat16"``);
            - shapes as lists of ints.
        """
        params = list(self.model.named_parameters())
        names = [param_name for param_name, _ in params]
        dtype_names = [str(tensor.dtype).rpartition(".")[2] for _, tensor in params]
        shapes = [list(tensor.shape) for _, tensor in params]
        return names, dtype_names, shapes

    def init_weight_transfer_group(self, world_size):
        """Join the NCCL weight-transfer group as the trainer side."""
        init_info = {
            "master_address": self.master_address,
            "master_port": self.port,
            "world_size": world_size,
        }
        self.model_update_group = NCCLWeightTransferEngine.trainer_init(init_info)

    def broadcast_weights(self, packed: bool = True):
        """Send all parameters to the inference engine over the NCCL group.

        Requires ``init_weight_transfer_group`` to have been called first. The
        script passes the same ``packed`` value to the engine's update request.
        """
        NCCLWeightTransferEngine.trainer_send_weights(
            iterator=self.model.named_parameters(),
            trainer_args=NCCLTrainerSendWeightsArgs(
                group=self.model_update_group,
                packed=packed,
            ),
        )

    @torch.inference_mode()
    def generate(self, token_ids: list[int], max_new_tokens: int) -> list[int]:
        """Greedily decode (``do_sample=False``) after ``token_ids``.

        Generates at most ``max_new_tokens`` tokens and returns only the newly
        generated token ids, with the prompt stripped.
        """
        prompt_len = len(token_ids)
        sequences = self.model.generate(
            torch.tensor([token_ids], device="cuda:0"),
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
        return sequences[0, prompt_len:].tolist()
# Environment variables propagated to every Ray worker. Ray is told not to set
# CUDA_VISIBLE_DEVICES for its workers.
ray_env_vars = {"RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1"}
if not current_platform.is_rocm():
    # NVIDIA: batch-invariant vLLM gives deterministic outputs.
    ray_env_vars["VLLM_BATCH_INVARIANT"] = "1"
else:
    # ROCm: batch-invariant vLLM is not supported; skinny GEMM is switched off
    # instead (presumably to reduce non-determinism).
    ray_env_vars["VLLM_ROCM_USE_SKINNY_GEMM"] = "0"

ray.init(runtime_env={"env_vars": ray_env_vars})

# Launch the training model actor. Ray's scheduler reserves one GPU for it
# (num_gpus=1 on the TrainModel decorator). No outer placement group is
# created in this script, so the vLLM engine's GPU workers are placed later by
# vLLM's Ray executor, presumably on the remaining GPUs.
train_model = TrainModel.remote(MODEL_NAME_V2)
# ROCm lacks full batch invariance. To minimize non-determinism there, use a
# fixed seed, no prefix caching, and sequential request processing
# (max_num_seqs=1). Other platforms add nothing.
rocm_determinism_kwargs = {}
if current_platform.is_rocm():
    rocm_determinism_kwargs = {
        "seed": 0,
        "enable_prefix_caching": False,
        "max_num_seqs": 1,
    }

# Engine arguments for the weight-synced inference engine. It starts from the
# V1 (base) checkpoint and is configured for NCCL weight transfer.
llm_kwargs = dict(
    model=MODEL_NAME_V1,
    enforce_eager=True,
    max_model_len=8192,
    distributed_executor_backend="ray",
    attention_backend=ATTN_BACKEND,
    gpu_memory_utilization=0.75,
    weight_transfer_config=WeightTransferConfig(backend="nccl"),
)
llm_kwargs.update(rocm_determinism_kwargs)

# Launch the vLLM inference engine as a Ray actor that itself reserves no CPUs
# or GPUs. With distributed_executor_backend="ray", vLLM's Ray executor places
# its own GPU workers (presumably in a placement group it creates). For that
# reason no outer placement group is created here: it would reserve GPUs and
# hide them from vLLM's own resource request.
llm = ray.remote(
    num_cpus=0,
    num_gpus=0,
)(MyLLM).remote(**llm_kwargs)
# Prompts generated concurrently; the weight sync happens while they are
# in flight.
PROMPTS = [
    "The president of the United States is",
    "The capital of France is",
    "The largest ocean on Earth is",
    "The speed of light in a vacuum is",
    "The chemical formula for water is",
    "The tallest mountain in the world is",
    "The first person to walk on the moon was",
    "The Great Wall of China was built to",
    "Photosynthesis is the process by which",
    "The theory of general relativity was proposed by",
    "The boiling point of water at sea level is",
    "The largest planet in our solar system is",
    "DNA stands for deoxyribonucleic acid and it",
]

# Pre-tokenize with the V1 tokenizer and no special tokens; the engine
# receives raw token ids via do_generate.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_V1)
batch_prompt_token_ids = [
    tokenizer.encode(text, add_special_tokens=False) for text in PROMPTS
]


# The training actor hosts the rendezvous point of the trainer <-> inference
# NCCL group.
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())

world_size = 2  # 1 trainer + 1 inference worker

# rank_offset=1: the inference side presumably takes ranks from 1 upward,
# leaving rank 0 to the trainer.
init_info = NCCLWeightTransferInitInfo(
    master_address=master_address,
    master_port=master_port,
    rank_offset=1,
    world_size=world_size,
)

# Submit both sides before waiting on either: the group presumably cannot
# form until the trainer and the engine have both joined.
inference_handle = llm.init_weight_transfer_engine.remote(
    WeightTransferInitRequest(init_info=asdict(init_info))
)
train_handle = train_model.init_weight_transfer_group.remote(world_size)
ray.get([train_handle, inference_handle])


282
# Extra token budget per request on top of PAUSE_TOKEN_THRESHOLD.
N_NEW_TOKENS = 100

# The trainer's parameter metadata is fetched once and reused for the update.
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())

# ── Phase 1: concurrent requests with weight sync ───────────────────
print("\n" + "=" * 50)
print(f"Prompts ({len(PROMPTS)}):")
for p in PROMPTS:
    print(f"  - {p!r}")
print("=" * 50)

# Greedy decoding (temperature=0), which the later exact-match validation
# relies on.
sampling_params = SamplingParams(
    max_tokens=PAUSE_TOKEN_THRESHOLD + N_NEW_TOKENS,
    temperature=0,
)

# Submit every request up front so they are all in flight on the engine actor.
gen_futures = [
    llm.do_generate.remote(token_ids, sampling_params)
    for token_ids in batch_prompt_token_ids
]

# Block until the engine has paused; the pause is triggered once any request
# reaches PAUSE_TOKEN_THRESHOLD tokens.
ray.get(llm.pause_after_n_tokens.remote())

# Push the trainer's (V2) weights into the paused engine. Both calls are
# submitted before waiting, since the engine receives while the trainer sends.
inference_handle = llm.update_weights.remote(
    WeightTransferUpdateRequest(
        update_info=asdict(
            NCCLWeightTransferUpdateInfo(
                names=names,
                dtype_names=dtype_names,
                shapes=shapes,
                packed=True,
            )
        )
    )
)
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])

# Resume the paused requests, then collect every (output, pause_idx) pair.
ray.get(llm.resume_generation.remote())
results = ray.get(gen_futures)

for i, (output, pause_idx) in enumerate(results):
    all_token_ids = list(output.outputs[0].token_ids)
    # pause_idx == -1 means the request finished before the pause, so all of
    # its tokens came from the old weights. Slicing with -1 directly would
    # misattribute the last token.
    split_idx = pause_idx if pause_idx >= 0 else len(all_token_ids)
    before_text = tokenizer.decode(all_token_ids[:split_idx])
    after_text = tokenizer.decode(all_token_ids[split_idx:])
    print(f"\n  Request {i} ({PROMPTS[i]!r}):")
    print(f"    Old weights ({split_idx} tokens): {before_text!r}")
    n_after = len(all_token_ids) - split_idx
    print(f"    New weights ({n_after} tokens): {after_text!r}")

# ── Phase 2: validate with a fresh V2 vLLM instance ────────────────
# The check compares the weight-synced engine's post-sync tokens with those of
# a fresh V2 engine, which only works if generation is batch-invariant
# (deterministic).
# - NVIDIA: batch invariance is fully supported, so every prompt must match
#   exactly.
# - ROCm: batch invariance is not fully implemented yet (see
#   https://github.com/vllm-project/vllm/issues/27433 and
#   https://github.com/vllm-project/vllm/issues/33123). Leftover
#   non-determinism (e.g. GEMM accumulation order, missing kernel overrides)
#   can flip single tokens without any weight-sync failure, so 90% suffices.
#   A genuinely broken weight transfer would pass close to 0%, not 90%+.
if current_platform.is_rocm():
    MIN_PASS_RATE = 0.9
else:
    MIN_PASS_RATE = 1.0

print("\n" + "=" * 50)
print("VALIDATION: comparing weight-synced vLLM with fresh V2 instance")
if current_platform.is_rocm():
    print(f"  (ROCm mode: requiring >= {MIN_PASS_RATE:.0%} exact match rate)")
print("=" * 50)

# Tear down the weight-synced engine and the trainer, then start a reference
# engine that loads the V2 checkpoint directly.
ray.get(llm.shutdown.remote())
ray.kill(llm)
ray.kill(train_model)

# Mirrors the first engine's settings, but loads V2 directly and has no
# weight_transfer_config. ROCm determinism overrides are applied last.
llm_v2_kwargs = {
    "model": MODEL_NAME_V2,
    "enforce_eager": True,
    "max_model_len": 8192,
    "gpu_memory_utilization": 0.75,
    "distributed_executor_backend": "ray",
    "attention_backend": ATTN_BACKEND,
    **rocm_determinism_kwargs,
}

llm_v2 = ray.remote(num_cpus=0, num_gpus=0)(MyLLM).remote(**llm_v2_kwargs)

# Re-run each request on the fresh V2 engine from its pause point. The new
# prompt is the original prompt plus the tokens generated with the old weights,
# and the V2 engine must reproduce exactly the tokens that the weight-synced
# engine produced after the sync.
# Some requests have no new-weight tokens to compare and are skipped:
# - pause_idx == -1: the request finished before the pause;
# - no tokens after pause_idx.
# Validating them would build a bogus slice and request max_tokens <= 0.
to_validate = []  # (request index, output, pause_idx)
for i, (output, pause_idx) in enumerate(results):
    n_after = len(output.outputs[0].token_ids) - pause_idx
    if pause_idx < 0 or n_after <= 0:
        print(f"  [SKIP] {PROMPTS[i]!r} (no tokens generated after the weight sync)")
        continue
    to_validate.append((i, output, pause_idx))

val_futures = [
    llm_v2.do_generate.remote(
        list(output.prompt_token_ids) + list(output.outputs[0].token_ids)[:pause_idx],
        SamplingParams(
            temperature=0, max_tokens=len(output.outputs[0].token_ids) - pause_idx
        ),
    )
    for _, output, pause_idx in to_validate
]
val_results = ray.get(val_futures)

num_pass = 0
num_total = len(to_validate)
for (i, output, pause_idx), (val_output, _) in zip(to_validate, val_results):
    expected = list(output.outputs[0].token_ids)[pause_idx:]
    actual = list(val_output.outputs[0].token_ids)

    if actual == expected:
        num_pass += 1
        print(f"  [PASS] {PROMPTS[i]!r}")
        continue

    print(f"  [FAIL] {PROMPTS[i]!r}")
    print(f"         weight-synced vLLM: {tokenizer.decode(expected)!r}")
    print(f"         V2 vLLM:           {tokenizer.decode(actual)!r}")
    for j, (e, a) in enumerate(zip(expected, actual)):
        if e != a:
            print(
                f"         first divergence at output token {j}: "
                f"expected {e} ({tokenizer.decode([e])!r}) vs "
                f"actual {a} ({tokenizer.decode([a])!r})"
            )
            break
    else:
        # No differing token in the common prefix: one side stopped earlier.
        print(
            f"         outputs differ in length: expected {len(expected)} "
            f"tokens, got {len(actual)}"
        )

ray.get(llm_v2.shutdown.remote())
ray.kill(llm_v2)

# Guard against a vacuous result (and a ZeroDivisionError) when no request
# could be validated.
assert num_total > 0, (
    "No request generated tokens after the weight sync; nothing to validate."
)

pass_rate = num_pass / num_total
print(f"\n  Result: {num_pass}/{num_total} prompts passed ({pass_rate:.0%})")
print(f"  Required: >= {MIN_PASS_RATE:.0%}")

assert pass_rate >= MIN_PASS_RATE, (
    f"Validation pass rate {pass_rate:.0%} ({num_pass}/{num_total}) "
    f"is below the required {MIN_PASS_RATE:.0%} threshold. "
    f"See failures above for details."
)
print("=" * 50)