Unverified Commit 596ed1f0 authored by Aaron Hao's avatar Aaron Hao Committed by GitHub
Browse files

[RL] Validation for pause_mode='keep' (#34992)


Signed-off-by: default avatarahao-anyscale <ahao@anyscale.com>
parent b8d8b7e9
...@@ -104,7 +104,6 @@ steps: ...@@ -104,7 +104,6 @@ steps:
# NEW rlhf examples # NEW rlhf examples
- cd new_weight_syncing - cd new_weight_syncing
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
- label: Distributed Tests (8 GPUs)(H100) - label: Distributed Tests (8 GPUs)(H100)
timeout_in_minutes: 10 timeout_in_minutes: 10
...@@ -146,6 +145,7 @@ steps: ...@@ -146,6 +145,7 @@ steps:
num_devices: 2 num_devices: 2
commands: commands:
- pytest -v -s tests/distributed/test_context_parallel.py - pytest -v -s tests/distributed/test_context_parallel.py
- cd examples/offline_inference/new_weight_syncing && VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
- VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model=Qwen/Qwen1.5-MoE-A2.7B -tp=1 -dp=2 --max-model-len=2048 --all2all-backend=deepep_high_throughput - VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model=Qwen/Qwen1.5-MoE-A2.7B -tp=1 -dp=2 --max-model-len=2048 --all2all-backend=deepep_high_throughput
- pytest -v -s tests/v1/distributed/test_dbo.py - pytest -v -s tests/v1/distributed/test_dbo.py
......
...@@ -26,14 +26,12 @@ workloads. Residual GPU activity interferes with vLLM memory profiling and ...@@ -26,14 +26,12 @@ workloads. Residual GPU activity interferes with vLLM memory profiling and
causes unexpected behavior. causes unexpected behavior.
""" """
import os import asyncio
import uuid import uuid
from dataclasses import asdict from dataclasses import asdict
import ray import ray
import torch import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
import vllm import vllm
...@@ -51,14 +49,15 @@ from vllm.distributed.weight_transfer.nccl_engine import ( ...@@ -51,14 +49,15 @@ from vllm.distributed.weight_transfer.nccl_engine import (
from vllm.utils.network_utils import get_ip, get_open_port from vllm.utils.network_utils import get_ip, get_open_port
from vllm.v1.executor import Executor from vllm.v1.executor import Executor
MODEL_NAME = "facebook/opt-125m" MODEL_NAME_V1 = "Qwen/Qwen3-1.7B-Base"
MODEL_NAME_V2 = "Qwen/Qwen3-1.7B"
PAUSE_TOKEN_THRESHOLD = 10
class MyLLM(vllm.AsyncLLMEngine): class MyLLM(vllm.AsyncLLMEngine):
"""Configure the vLLM worker for Ray placement group execution.""" """Configure the vLLM worker for Ray placement group execution."""
def __init__(self, **kwargs): def __init__(self, **kwargs):
os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0,1"
engine_args = vllm.AsyncEngineArgs(**kwargs) engine_args = vllm.AsyncEngineArgs(**kwargs)
vllm_config = engine_args.create_engine_config() vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config) executor_class = Executor.get_class(vllm_config)
...@@ -68,26 +67,44 @@ class MyLLM(vllm.AsyncLLMEngine): ...@@ -68,26 +67,44 @@ class MyLLM(vllm.AsyncLLMEngine):
log_requests=engine_args.enable_log_requests, log_requests=engine_args.enable_log_requests,
log_stats=not engine_args.disable_log_stats, log_stats=not engine_args.disable_log_stats,
) )
self._generation_paused = False
self._request_pause_flag = False
async def generate_with_retry( async def do_generate(
self, prompt_token_ids: list[int], sampling_params: vllm.SamplingParams self, prompt_token_ids: list[int], sampling_params: vllm.SamplingParams
) -> vllm.RequestOutput: ) -> tuple[vllm.RequestOutput, int]:
finish_reason = "abort" """Generate a single request, setting the request pause flag once the
while finish_reason == "abort": token count reaches the threshold.
async for request_output in self.generate(
{"prompt_token_ids": prompt_token_ids}, Returns (output, pause_token_index). pause_token_index is the number
sampling_params, of tokens generated before the weight change, or -1 if no pause.
request_id=str(uuid.uuid4()), """
pause_token_index = -1
prev_token_count = 0
async for request_output in self.generate(
{"prompt_token_ids": prompt_token_ids},
sampling_params,
request_id=str(uuid.uuid4()),
):
output = request_output
cur_token_count = len(output.outputs[0].token_ids)
if (
cur_token_count >= PAUSE_TOKEN_THRESHOLD
and not self._request_pause_flag
): ):
output = request_output self._request_pause_flag = True
finish_reason = output.outputs[0].finish_reason if self._generation_paused and pause_token_index == -1:
if finish_reason == "abort": pause_token_index = prev_token_count
print( prev_token_count = cur_token_count
f"ABORT, prompt_token_ids: {prompt_token_ids}, " return output, pause_token_index
f"generated token_ids: {list(output.outputs[0].token_ids)}"
) async def pause_after_n_tokens(self):
prompt_token_ids = prompt_token_ids + list(output.outputs[0].token_ids) """Wait for any request to set the pause flag, then pause."""
return output while not self._request_pause_flag:
await asyncio.sleep(0)
await super().pause_generation(mode="keep")
await asyncio.sleep(0.2)
self._generation_paused = True
@ray.remote(num_gpus=1) @ray.remote(num_gpus=1)
...@@ -95,6 +112,14 @@ class TrainModel: ...@@ -95,6 +112,14 @@ class TrainModel:
"""Ray actor that wraps the training model on a dedicated GPU.""" """Ray actor that wraps the training model on a dedicated GPU."""
def __init__(self, model_name: str): def __init__(self, model_name: str):
from vllm.model_executor.layers.batch_invariant import (
init_batch_invariance,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
# need to init all env vars for batch invariance which affect nccl ops
init_batch_invariance(AttentionBackendEnum.FLASH_ATTN)
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
model_name, dtype=torch.bfloat16 model_name, dtype=torch.bfloat16
).to("cuda:0") ).to("cuda:0")
...@@ -133,70 +158,80 @@ class TrainModel: ...@@ -133,70 +158,80 @@ class TrainModel:
packed=packed, packed=packed,
) )
@torch.inference_mode()
# Initialize Ray and set the visible devices. The vLLM engine will def generate(self, token_ids: list[int], max_new_tokens: int) -> list[int]:
# be placed on GPUs 1 and 2. """Greedy-decode max_new_tokens from the given context."""
ray.init() input_ids = torch.tensor([token_ids], device="cuda:0")
output = self.model.generate(
input_ids,
max_new_tokens=max_new_tokens,
do_sample=False,
)
new_token_ids = output[0, len(token_ids) :].tolist()
return new_token_ids
ray.init(
runtime_env={
"env_vars": {
# enable batch invariance for deterministic outputs
"VLLM_BATCH_INVARIANT": "1",
# prevent ray from setting CUDA_VISIBLE_DEVICES
"RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1",
}
}
)
# Launch the training model actor. Ray's resource scheduler will allocate # Launch the training model actor. Ray's resource scheduler will allocate
# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs. # 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs.
train_model = TrainModel.remote(MODEL_NAME) train_model = TrainModel.remote(MODEL_NAME_V2)
# Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/placement-groups.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
placement_group=pg_inference,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=0,
)
# Launch the vLLM inference engine. The `enforce_eager` flag reduces # Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency. # start-up latency.
# Note: Weight transfer APIs (init_weight_transfer_engine, update_weights) # With data_parallel_backend="ray", vLLM's CoreEngineActorManager creates
# are now native to vLLM workers. # its own placement groups internally for each DP rank, so we must NOT
# create an outer placement group (it would reserve GPUs and hide them
# from the internal DP resource check).
llm = ray.remote( llm = ray.remote(
num_cpus=0, num_cpus=0,
num_gpus=0, num_gpus=0,
scheduling_strategy=scheduling_inference,
)(MyLLM).remote( )(MyLLM).remote(
model=MODEL_NAME, model=MODEL_NAME_V1,
enforce_eager=True, enforce_eager=True,
tensor_parallel_size=2, max_model_len=8192,
distributed_executor_backend="ray", distributed_executor_backend="ray",
load_format="dummy", attention_backend="FLASH_ATTN",
gpu_memory_utilization=0.75,
weight_transfer_config=WeightTransferConfig(backend="nccl"), weight_transfer_config=WeightTransferConfig(backend="nccl"),
) )
# Generate text from the prompts. PROMPTS = [
prompts = [
"My name is",
"The president of the United States is", "The president of the United States is",
"The capital of France is", "The capital of France is",
"The future of AI is", "The largest ocean on Earth is",
"The speed of light in a vacuum is",
"The chemical formula for water is",
"The tallest mountain in the world is",
"The first person to walk on the moon was",
"The Great Wall of China was built to",
"Photosynthesis is the process by which",
"The theory of general relativity was proposed by",
"The boiling point of water at sea level is",
"The largest planet in our solar system is",
"DNA stands for deoxyribonucleic acid and it",
] ]
# Tokenize prompts to token IDs tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_V1)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) batch_prompt_token_ids = [
prompt_token_ids_list = [ tokenizer.encode(prompt, add_special_tokens=False) for prompt in PROMPTS
tokenizer.encode(prompt, add_special_tokens=False) for prompt in prompts
] ]
sampling_params = [
SamplingParams(temperature=0, max_tokens=2),
SamplingParams(temperature=0, max_tokens=32),
SamplingParams(temperature=0, max_tokens=32),
SamplingParams(temperature=0, max_tokens=32),
]
# Set up the communication channel between the training process and the # Set up the communication channel between the training process and the
# inference engine. # inference engine.
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote()) master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())
world_size = 3 # 1 trainer + 2 inference workers (tensor_parallel_size=2) world_size = 2 # 1 trainer + 1 inference worker
inference_handle = llm.init_weight_transfer_engine.remote( inference_handle = llm.init_weight_transfer_engine.remote(
WeightTransferInitRequest( WeightTransferInitRequest(
init_info=asdict( init_info=asdict(
...@@ -215,22 +250,28 @@ train_handle = train_model.init_weight_transfer_group.remote(world_size) ...@@ -215,22 +250,28 @@ train_handle = train_model.init_weight_transfer_group.remote(world_size)
ray.get([train_handle, inference_handle]) ray.get([train_handle, inference_handle])
generation_futures = [ N_NEW_TOKENS = 100
llm.generate_with_retry.remote(prompt_token_ids, params)
for prompt_token_ids, params in zip(prompt_token_ids_list, sampling_params)
]
finished, pending = ray.wait(generation_futures, num_returns=1) # Collect weight metadata once
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())
# Pause generation in preparation for weight sync # ── Phase 1: concurrent requests with weight sync ───────────────────
ray.get(llm.pause_generation.remote(wait_for_inflight_requests=False)) print(f"\n{'=' * 50}")
print(f"Prompts ({len(PROMPTS)}):")
for p in PROMPTS:
print(f" - {p!r}")
print(f"{'=' * 50}")
# Synchronize the updated weights to the inference engine using batched API. sampling_params = SamplingParams(
# Collect all weight metadata from the training actor temperature=0, max_tokens=PAUSE_TOKEN_THRESHOLD + N_NEW_TOKENS
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote()) )
gen_futures = [
llm.do_generate.remote(ptids, sampling_params) for ptids in batch_prompt_token_ids
]
ray.get(llm.pause_after_n_tokens.remote())
# Issue update_weights call with NCCL-specific update info
# packed=True enables efficient batched tensor broadcasting
inference_handle = llm.update_weights.remote( inference_handle = llm.update_weights.remote(
WeightTransferUpdateRequest( WeightTransferUpdateRequest(
update_info=asdict( update_info=asdict(
...@@ -243,41 +284,76 @@ inference_handle = llm.update_weights.remote( ...@@ -243,41 +284,76 @@ inference_handle = llm.update_weights.remote(
) )
) )
) )
# Broadcast all weights from trainer using the weight transfer API
train_handle = train_model.broadcast_weights.remote(packed=True) train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle]) ray.get([train_handle, inference_handle])
# Resume generation since weight sync is complete
ray.get(llm.resume_generation.remote()) ray.get(llm.resume_generation.remote())
results = ray.get(gen_futures)
for i, (output, pause_idx) in enumerate(results):
all_token_ids = list(output.outputs[0].token_ids)
before_text = tokenizer.decode(all_token_ids[:pause_idx])
after_text = tokenizer.decode(all_token_ids[pause_idx:])
print(f"\n Request {i} ({PROMPTS[i]!r}):")
print(f" Old weights ({pause_idx} tokens): {before_text!r}")
n_after = len(all_token_ids) - pause_idx
print(f" New weights ({n_after} tokens): {after_text!r}")
# ── Phase 2: validate with a fresh V2 vLLM instance ────────────────
print(f"\n{'=' * 50}")
print("VALIDATION: comparing weight-synced vLLM with fresh V2 instance")
print(f"{'=' * 50}")
ray.get(llm.shutdown.remote())
ray.kill(llm)
ray.kill(train_model)
llm_v2 = ray.remote(
num_cpus=0,
num_gpus=0,
)(MyLLM).remote(
model=MODEL_NAME_V2,
enforce_eager=True,
max_model_len=8192,
gpu_memory_utilization=0.75,
distributed_executor_backend="ray",
attention_backend="FLASH_ATTN",
)
val_futures = [
llm_v2.do_generate.remote(
list(output.prompt_token_ids) + list(output.outputs[0].token_ids)[:pause_idx],
SamplingParams(
temperature=0, max_tokens=len(output.outputs[0].token_ids) - pause_idx
),
)
for output, pause_idx in results
]
val_results = ray.get(val_futures)
all_pass = True
for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_results)):
expected = list(output.outputs[0].token_ids)[pause_idx:]
actual = list(val_output.outputs[0].token_ids)
match = actual == expected
if match:
print(f" [PASS] {PROMPTS[i]!r}")
else:
all_pass = False
print(f" [FAIL] {PROMPTS[i]!r}")
print(f" weight-synced vLLM: {tokenizer.decode(expected)!r}")
print(f" V2 vLLM: {tokenizer.decode(actual)!r}")
for j, (e, a) in enumerate(zip(expected, actual)):
if e != a:
print(
f" first divergence at output token {j}: "
f"expected {e} ({tokenizer.decode([e])!r}) vs "
f"actual {a} ({tokenizer.decode([a])!r})"
)
break
# Get outputs separately - finished completed before pause, pending were paused/resumed ray.get(llm_v2.shutdown.remote())
finished_outputs = ray.get(finished) ray.kill(llm_v2)
pending_outputs = ray.get(pending) assert all_pass, "Some prompts failed validation, see above for details"
print("=" * 50)
# Requests that finished before the pause: all generation used original weights
print("-" * 50)
print("Requests that completed BEFORE weight change:")
print("-" * 50)
for output in finished_outputs:
prompt_text = tokenizer.decode(output.prompt_token_ids)
print(f"Prompt: {prompt_text!r}")
print(f"Generated (with original weights): {output.outputs[0].text!r}")
print("-" * 50)
# Requests that were paused mid-generation: some text before, some after weight change
print("Requests that were PAUSED and RESUMED after weight change:")
print("-" * 50)
for output in pending_outputs:
# Decode the full prompt token IDs (original + generated before pause)
full_prompt_text = tokenizer.decode(output.prompt_token_ids)
# Find the original prompt by checking which one this output started with
original_prompt = next(p for p in prompts if full_prompt_text.startswith(p))
# output.prompt_token_ids contains original prompt + tokens generated before pause
# output.outputs[0].text is what was generated after resuming with new weights
text_before_pause = full_prompt_text[len(original_prompt) :]
text_after_pause = output.outputs[0].text
print(f"Original prompt: {original_prompt!r}")
print(f"Generated before weight change: {text_before_pause!r}")
print(f"Generated after weight change: {text_after_pause!r}")
print("-" * 50)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment