# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates async reinforcement learning using vLLM and Ray,
with native weight syncing APIs and batch-invariant generation.

The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformer model occupies one GPU for training, and a
vLLM AsyncLLMEngine occupies another GPU for inference.

Batch invariance is enabled so that generation output is deterministic
regardless of how many requests are batched together. This is required
for the validation phase to succeed. Batch invariance currently requires
NVIDIA GPUs with compute capability 9.0 or higher:
  - H-series: H100, H200
  - B-series: B100, B200

The example performs the following steps:
* Load the training model (Qwen3-1.7B) on one GPU via a Ray actor.
* Initialize the inference engine with a base model (Qwen3-1.7B-Base)
  on a separate GPU using vLLM's AsyncLLMEngine with Ray as the
  distributed executor backend.
* Set up an NCCL-based weight transfer channel between the trainer
  and the inference engine.
* Submit generation requests for a batch of prompts.
* Pause generation once any request reaches a token threshold.
* Broadcast the training model's weights to the inference engine
  via the NCCL weight transfer engine, replacing the base weights.
* Resume generation and collect results, noting which tokens were
  generated before vs. after the weight swap.
* Validate correctness by launching a fresh vLLM instance loaded
  directly with the training model and comparing its output to the
  post-swap tokens from the weight-synced engine.

This example assumes a single-node cluster with two GPUs, but Ray
supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
workloads. Residual GPU activity interferes with vLLM memory profiling and
causes unexpected behavior.
"""

42
import asyncio
import uuid
from dataclasses import asdict

import ray
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import vllm
from vllm import SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
from vllm.distributed.weight_transfer.nccl_engine import (
    NCCLTrainerSendWeightsArgs,
    NCCLWeightTransferEngine,
    NCCLWeightTransferInitInfo,
    NCCLWeightTransferUpdateInfo,
)
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_ip, get_open_port
from vllm.v1.executor import Executor

# Checkpoint the inference engine is initially loaded with.
MODEL_NAME_V1 = "Qwen/Qwen3-1.7B-Base"
# Checkpoint held by the trainer and loaded by the fresh validation engine.
MODEL_NAME_V2 = "Qwen/Qwen3-1.7B"
# Generated-token count at which a request asks for generation to be paused.
PAUSE_TOKEN_THRESHOLD = 10
# Attention backend passed to vLLM: Triton on ROCm, FlashAttention otherwise.
ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "FLASH_ATTN"


class MyLLM(vllm.AsyncLLMEngine):
    """Configure the vLLM worker for Ray placement group execution.

    On top of the base engine, this subclass keeps two flags that coordinate
    the mid-generation pause:

    * ``_request_pause_flag`` is set by the first request whose output
      reaches ``PAUSE_TOKEN_THRESHOLD`` tokens; ``pause_after_n_tokens``
      polls it.
    * ``_generation_paused`` is set by ``pause_after_n_tokens`` after the
      pause has been issued; ``do_generate`` uses it to record where each
      request's output was split.
    """

    def __init__(self, **kwargs):
        """Build the engine from ``vllm.AsyncEngineArgs`` keyword arguments."""
        engine_args = vllm.AsyncEngineArgs(**kwargs)
        vllm_config = engine_args.create_engine_config()
        executor_class = Executor.get_class(vllm_config)
        super().__init__(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_requests=engine_args.enable_log_requests,
            log_stats=not engine_args.disable_log_stats,
        )
        self._generation_paused = False
        self._request_pause_flag = False

    async def do_generate(
        self, prompt_token_ids: list[int], sampling_params: vllm.SamplingParams
    ) -> tuple[vllm.RequestOutput, int]:
        """Generate a single request, setting the request pause flag once the
        token count reaches the threshold.

        Returns (output, pause_token_index). pause_token_index is the number
        of tokens generated before the weight change, or -1 if no pause.

        Raises:
            RuntimeError: if the engine's output stream ends without yielding
                any output for the request.
        """
        output = None
        pause_token_index = -1
        prev_token_count = 0
        async for request_output in self.generate(
            {"prompt_token_ids": prompt_token_ids},
            sampling_params,
            request_id=str(uuid.uuid4()),
        ):
            output = request_output
            cur_token_count = len(output.outputs[0].token_ids)
            # The first request to reach the threshold raises the shared flag
            # that pause_after_n_tokens is polling.
            if (
                cur_token_count >= PAUSE_TOKEN_THRESHOLD
                and not self._request_pause_flag
            ):
                self._request_pause_flag = True
            # On the first output seen after the pause was recorded, remember
            # the token count of the previous output as the split point.
            if self._generation_paused and pause_token_index == -1:
                pause_token_index = prev_token_count
            prev_token_count = cur_token_count
        if output is None:
            # Without this guard the return below would fail with an opaque
            # UnboundLocalError.
            raise RuntimeError("generate() yielded no output for the request")
        return output, pause_token_index

    async def pause_after_n_tokens(self):
        """Wait for any request to set the pause flag, then pause."""
        # sleep(0) hands control back to the event loop on every iteration so
        # the do_generate coroutines can make progress and set the flag.
        while not self._request_pause_flag:
            await asyncio.sleep(0)
        await super().pause_generation(mode="keep")
        # NOTE(review): fixed 5 s delay before the paused flag is raised --
        # presumably lets outputs produced before the pause reach do_generate
        # first, so they are not counted as post-pause; confirm.
        await asyncio.sleep(5)
        self._generation_paused = True


@ray.remote(num_gpus=1)
class TrainModel:
    """Ray actor that wraps the training model on a dedicated GPU."""

    def __init__(self, model_name: str):
        """Load ``model_name`` in bfloat16 onto ``cuda:0`` and record the
        address/port later used to set up the weight transfer group."""
        from vllm.model_executor.layers.batch_invariant import (
            init_batch_invariance,
        )

        # need to init all env vars for batch invariance which affect nccl ops
        init_batch_invariance()

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, dtype=torch.bfloat16
        ).to("cuda:0")
        self.port = get_open_port()
        self.master_address = get_ip()

    def get_master_address_and_port(self):
        """Return the ``(master_address, port)`` pair chosen in ``__init__``."""
        return self.master_address, self.port

    def get_weight_metadata(self) -> tuple[list[str], list[str], list[list[int]]]:
        """Return weight names, dtypes, and shapes for weight transfer.

        The three lists are parallel and follow ``named_parameters()`` order.
        Dtype names are the part after the dot in ``str(dtype)``, e.g.
        ``"bfloat16"`` for ``torch.bfloat16``.
        """
        names = []
        dtype_names = []
        shapes = []
        for name, p in self.model.named_parameters():
            names.append(name)
            dtype_names.append(str(p.dtype).split(".")[-1])
            shapes.append(list(p.shape))
        return names, dtype_names, shapes

    def init_weight_transfer_group(self, world_size):
        """Initialize the NCCL process group for weight transfer."""
        self.model_update_group = NCCLWeightTransferEngine.trainer_init(
            dict(
                master_address=self.master_address,
                master_port=self.port,
                world_size=world_size,
            ),
        )

    def broadcast_weights(self, packed: bool = True):
        """Broadcast weights to the inference engine.

        Sends every entry of ``named_parameters()`` over the group created by
        ``init_weight_transfer_group``, which must have been called first.
        """
        trainer_args = NCCLTrainerSendWeightsArgs(
            group=self.model_update_group,
            packed=packed,
        )
        NCCLWeightTransferEngine.trainer_send_weights(
            iterator=self.model.named_parameters(),
            trainer_args=trainer_args,
        )

    @torch.inference_mode()
    def generate(self, token_ids: list[int], max_new_tokens: int) -> list[int]:
        """Greedy-decode max_new_tokens from the given context."""
        input_ids = torch.tensor([token_ids], device="cuda:0")
        output = self.model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
        # Drop the echoed context and keep only the newly generated ids.
        new_token_ids = output[0, len(token_ids) :].tolist()
        return new_token_ids


# Build platform-specific env vars for Ray
ray_env_vars = {
    # Prevent Ray from setting CUDA_VISIBLE_DEVICES
    "RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1",
}

if current_platform.is_rocm():
    # For ROCm, BATCH_INVARIANT vllm is not supported
    ray_env_vars["VLLM_ROCM_USE_SKINNY_GEMM"] = "0"
else:
    # Enable batch invariance for deterministic outputs on NVIDIA
    ray_env_vars["VLLM_BATCH_INVARIANT"] = "1"

ray.init(runtime_env={"env_vars": ray_env_vars})

# Launch the training model actor. Ray's resource scheduler will allocate
# 1 GPU to it (via num_gpus=1 in the decorator), so the inference engine
# launched afterwards is placed on a different GPU.
train_model = TrainModel.remote(MODEL_NAME_V2)

rocm_determinism_kwargs = {}
if current_platform.is_rocm():
    # ROCm: To minimize non-determinism, we set fixed seed, no prefix caching, and
    # sequential request processing (max_num_seqs=1).
    rocm_determinism_kwargs = {
        "seed": 0,
        "enable_prefix_caching": False,
        "max_num_seqs": 1,
    }

# Build platform-specific LLM kwargs
llm_kwargs = dict(
    model=MODEL_NAME_V1,
    enforce_eager=True,
    max_model_len=8192,
    distributed_executor_backend="ray",
    attention_backend=ATTN_BACKEND,
    gpu_memory_utilization=0.75,
    weight_transfer_config=WeightTransferConfig(backend="nccl"),
)
llm_kwargs.update(rocm_determinism_kwargs)

# Launch the vLLM inference engine.
# With data_parallel_backend="ray", vLLM's CoreEngineActorManager creates
# its own placement groups internally for each DP rank, so we must NOT
# create an outer placement group (it would reserve GPUs and hide them
# from the internal DP resource check).
# NOTE(review): the kwargs above set distributed_executor_backend="ray", not
# data_parallel_backend -- confirm the rationale above still applies as written.
# The wrapper actor itself reserves no CPU or GPU (num_cpus=0, num_gpus=0).
llm = ray.remote(
    num_cpus=0,
    num_gpus=0,
)(MyLLM).remote(**llm_kwargs)

# Prompts submitted concurrently in Phase 1.
PROMPTS = [
    "The president of the United States is",
    "The capital of France is",
    "The largest ocean on Earth is",
    "The speed of light in a vacuum is",
    "The chemical formula for water is",
    "The tallest mountain in the world is",
    "The first person to walk on the moon was",
    "The Great Wall of China was built to",
    "Photosynthesis is the process by which",
    "The theory of general relativity was proposed by",
    "The boiling point of water at sea level is",
    "The largest planet in our solar system is",
    "DNA stands for deoxyribonucleic acid and it",
]

# Prompts are tokenized up front because do_generate takes token ids;
# no special tokens are added.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_V1)
batch_prompt_token_ids = [
    tokenizer.encode(prompt, add_special_tokens=False) for prompt in PROMPTS
]


# Set up the communication channel between the training process and the
# inference engine.
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())

world_size = 2  # 1 trainer + 1 inference worker
# Start the engine-side initialization first without waiting on it; it is
# awaited together with the trainer-side call below.
inference_handle = llm.init_weight_transfer_engine.remote(
    WeightTransferInitRequest(
        init_info=asdict(
            NCCLWeightTransferInitInfo(
                master_address=master_address,
                master_port=master_port,
                rank_offset=1,
                world_size=world_size,
            )
        )
    )
)

# Initialize weight transfer group on both the training actor and inference engine
train_handle = train_model.init_weight_transfer_group.remote(world_size)
ray.get([train_handle, inference_handle])


# Generation budget on top of PAUSE_TOKEN_THRESHOLD for each request.
N_NEW_TOKENS = 100

# Collect weight metadata once
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())

# ── Phase 1: concurrent requests with weight sync ───────────────────
print(f"\n{'=' * 50}")
print(f"Prompts ({len(PROMPTS)}):")
for p in PROMPTS:
    print(f"  - {p!r}")
print(f"{'=' * 50}")

sampling_params = SamplingParams(
    temperature=0, max_tokens=PAUSE_TOKEN_THRESHOLD + N_NEW_TOKENS
)

# Submit every prompt concurrently; the results are collected after the
# weight swap below.
gen_futures = [
    llm.do_generate.remote(ptids, sampling_params) for ptids in batch_prompt_token_ids
]

# Returns once some request has reached PAUSE_TOKEN_THRESHOLD tokens and the
# engine has been paused.
ray.get(llm.pause_after_n_tokens.remote())

# Start the receiving side on the engine and the sending side on the trainer,
# then wait for both. Both sides are given the same `packed` setting.
inference_handle = llm.update_weights.remote(
    WeightTransferUpdateRequest(
        update_info=asdict(
            NCCLWeightTransferUpdateInfo(
                names=names,
                dtype_names=dtype_names,
                shapes=shapes,
                packed=True,
            )
        )
    )
)
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])

ray.get(llm.resume_generation.remote())
results = ray.get(gen_futures)

for i, (output, pause_idx) in enumerate(results):
    all_token_ids = list(output.outputs[0].token_ids)
    # do_generate reports -1 when the request never observed the pause. Used
    # directly as a slice bound, -1 would mislabel the last token as "new
    # weights", so treat that case as "every token came from the old weights".
    split_idx = pause_idx if pause_idx >= 0 else len(all_token_ids)
    before_text = tokenizer.decode(all_token_ids[:split_idx])
    after_text = tokenizer.decode(all_token_ids[split_idx:])
    print(f"\n  Request {i} ({PROMPTS[i]!r}):")
    print(f"    Old weights ({split_idx} tokens): {before_text!r}")
    n_after = len(all_token_ids) - split_idx
    print(f"    New weights ({n_after} tokens): {after_text!r}")

# ── Phase 2: validate with a fresh V2 vLLM instance ────────────────
# This validation relies on batch-invariant (deterministic) generation to
# compare outputs from the weight-synced engine against a fresh V2 instance.
# On NVIDIA, batch invariance is fully supported, so we require 100% exact
# token match. On ROCm, batch invariance is not yet fully implemented
# (see https://github.com/vllm-project/vllm/issues/27433 and
# https://github.com/vllm-project/vllm/issues/33123), so residual
# non-determinism (e.g. GEMM accumulation order, missing kernel overrides)
# can cause single-token divergences that don't indicate a weight-sync
# failure. We relax the pass rate to 90% on ROCm to accommodate this; a
# real regression (broken weight transfer) would cause ~0% pass rate, not 90%+.
MIN_PASS_RATE = 1.0 if not current_platform.is_rocm() else 0.9

print(f"\n{'=' * 50}")
print("VALIDATION: comparing weight-synced vLLM with fresh V2 instance")
if current_platform.is_rocm():
    print(f"  (ROCm mode: requiring >= {MIN_PASS_RATE:.0%} exact match rate)")
print(f"{'=' * 50}")

# Tear down the Phase 1 engine and the trainer before the validation engine
# is launched.
ray.get(llm.shutdown.remote())
ray.kill(llm)
ray.kill(train_model)

llm_v2_kwargs = dict(
    model=MODEL_NAME_V2,
    enforce_eager=True,
    max_model_len=8192,
    gpu_memory_utilization=0.75,
    distributed_executor_backend="ray",
    attention_backend=ATTN_BACKEND,
)
llm_v2_kwargs.update(rocm_determinism_kwargs)

llm_v2 = ray.remote(
    num_cpus=0,
    num_gpus=0,
)(MyLLM).remote(**llm_v2_kwargs)

# Only requests that were actually split by the pause can be validated.
# pause_idx is -1 when a request never observed the pause, and a request with
# no tokens after the split has nothing to compare. Feeding either case to
# the validation engine would produce a spurious failure or a non-positive
# max_tokens, so those requests are skipped.
validatable = []
for i, (output, pause_idx) in enumerate(results):
    if 0 <= pause_idx < len(output.outputs[0].token_ids):
        validatable.append((i, output, pause_idx))
    else:
        print(f"  [SKIP] {PROMPTS[i]!r}: no tokens generated after the weight swap")

# Replay prompt + pre-swap tokens on the fresh engine and ask for exactly as
# many tokens as the weight-synced engine produced after the swap.
val_futures = [
    llm_v2.do_generate.remote(
        list(output.prompt_token_ids) + list(output.outputs[0].token_ids)[:pause_idx],
        SamplingParams(
            temperature=0, max_tokens=len(output.outputs[0].token_ids) - pause_idx
        ),
    )
    for _, output, pause_idx in validatable
]
val_results = ray.get(val_futures)

num_pass = 0
num_total = len(validatable)
for (i, output, pause_idx), (val_output, _) in zip(validatable, val_results):
    expected = list(output.outputs[0].token_ids)[pause_idx:]
    actual = list(val_output.outputs[0].token_ids)
    match = actual == expected

    if match:
        num_pass += 1
        print(f"  [PASS] {PROMPTS[i]!r}")
    else:
        print(f"  [FAIL] {PROMPTS[i]!r}")
        print(f"         weight-synced vLLM: {tokenizer.decode(expected)!r}")
        print(f"         V2 vLLM:           {tokenizer.decode(actual)!r}")
        for j, (e, a) in enumerate(zip(expected, actual)):
            if e != a:
                print(
                    f"         first divergence at output token {j}: "
                    f"expected {e} ({tokenizer.decode([e])!r}) vs "
                    f"actual {a} ({tokenizer.decode([a])!r})"
                )
                break
        else:
            # The shared prefix is identical, so the sequences differ only
            # in length.
            print(
                f"         no divergence in shared prefix; length mismatch: "
                f"expected {len(expected)} tokens vs actual {len(actual)}"
            )

ray.get(llm_v2.shutdown.remote())
ray.kill(llm_v2)

# If nothing could be validated, the run proves nothing: count it as 0%.
pass_rate = num_pass / num_total if num_total else 0.0
print(f"\n  Result: {num_pass}/{num_total} prompts passed ({pass_rate:.0%})")
print(f"  Required: >= {MIN_PASS_RATE:.0%}")

# Raise explicitly rather than using `assert`, which is removed when Python
# runs with -O and would let a failed validation pass silently.
if pass_rate < MIN_PASS_RATE:
    raise AssertionError(
        f"Validation pass rate {pass_rate:.0%} ({num_pass}/{num_total}) "
        f"is below the required {MIN_PASS_RATE:.0%} threshold. "
        f"See failures above for details."
    )
print("=" * 50)