Unverified Commit f7da9cdf authored by Andreas Karatzas's avatar Andreas Karatzas Committed by GitHub
Browse files

[ROCm][CI] Support async weight transfer example with platform-aware determinism (#35710)


Signed-off-by: default avatarAndreas Karatzas <akaratza@amd.com>
parent f22ff295
...@@ -1339,6 +1339,7 @@ steps: ...@@ -1339,6 +1339,7 @@ steps:
- tests/v1/entrypoints/openai/test_multi_api_servers.py - tests/v1/entrypoints/openai/test_multi_api_servers.py
- tests/v1/shutdown - tests/v1/shutdown
- tests/v1/worker/test_worker_memory_snapshot.py - tests/v1/worker/test_worker_memory_snapshot.py
- examples/offline_inference/new_weight_syncing/
commands: commands:
# Work around HIP bug tracked here: https://github.com/ROCm/hip/issues/3876 # Work around HIP bug tracked here: https://github.com/ROCm/hip/issues/3876
# TODO: Remove when the bug is fixed in a future ROCm release # TODO: Remove when the bug is fixed in a future ROCm release
...@@ -1970,8 +1971,10 @@ steps: ...@@ -1970,8 +1971,10 @@ steps:
- label: Distributed Tests (4 GPUs) # 35min - label: Distributed Tests (4 GPUs) # 35min
timeout_in_minutes: 50 timeout_in_minutes: 50
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi355_4 agent_pool: mi355_4
optional: true
# grade: Blocking
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
num_gpus: 4 num_gpus: 4
source_file_dependencies: source_file_dependencies:
...@@ -2025,7 +2028,8 @@ steps: ...@@ -2025,7 +2028,8 @@ steps:
- popd - popd
# NEW rlhf examples # NEW rlhf examples
- pushd ../examples/offline_inference/new_weight_syncing - pushd ../examples/offline_inference/new_weight_syncing
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
- popd - popd
...@@ -2989,8 +2993,10 @@ steps: ...@@ -2989,8 +2993,10 @@ steps:
- label: Distributed Tests (2 GPUs) # 68min - label: Distributed Tests (2 GPUs) # 68min
timeout_in_minutes: 90 timeout_in_minutes: 90
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi355_2 agent_pool: mi355_2
optional: true
# grade: Blocking
working_dir: "/vllm-workspace/tests" working_dir: "/vllm-workspace/tests"
num_gpus: 2 num_gpus: 2
source_file_dependencies: source_file_dependencies:
......
...@@ -47,12 +47,14 @@ from vllm.distributed.weight_transfer.nccl_engine import ( ...@@ -47,12 +47,14 @@ from vllm.distributed.weight_transfer.nccl_engine import (
NCCLWeightTransferInitInfo, NCCLWeightTransferInitInfo,
NCCLWeightTransferUpdateInfo, NCCLWeightTransferUpdateInfo,
) )
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_ip, get_open_port from vllm.utils.network_utils import get_ip, get_open_port
from vllm.v1.executor import Executor from vllm.v1.executor import Executor
MODEL_NAME_V1 = "Qwen/Qwen3-1.7B-Base" MODEL_NAME_V1 = "Qwen/Qwen3-1.7B-Base"
MODEL_NAME_V2 = "Qwen/Qwen3-1.7B" MODEL_NAME_V2 = "Qwen/Qwen3-1.7B"
PAUSE_TOKEN_THRESHOLD = 10 PAUSE_TOKEN_THRESHOLD = 10
ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "FLASH_ATTN"
class MyLLM(vllm.AsyncLLMEngine): class MyLLM(vllm.AsyncLLMEngine):
...@@ -116,10 +118,16 @@ class TrainModel: ...@@ -116,10 +118,16 @@ class TrainModel:
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
init_batch_invariance, init_batch_invariance,
) )
from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
# need to init all env vars for batch invariance which affect nccl ops # need to init all env vars for batch invariance which affect nccl ops
init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) attn_backend = (
AttentionBackendEnum.TRITON_ATTN
if current_platform.is_rocm()
else AttentionBackendEnum.FLASH_ATTN
)
init_batch_invariance(attn_backend)
self.model = AutoModelForCausalLM.from_pretrained( self.model = AutoModelForCausalLM.from_pretrained(
model_name, dtype=torch.bfloat16 model_name, dtype=torch.bfloat16
...@@ -175,39 +183,56 @@ class TrainModel: ...@@ -175,39 +183,56 @@ class TrainModel:
return new_token_ids return new_token_ids
ray.init( # Build platform-specific env vars for Ray
runtime_env={ ray_env_vars = {
"env_vars": { # Prevent Ray from setting CUDA_VISIBLE_DEVICES
# enable batch invariance for deterministic outputs
"VLLM_BATCH_INVARIANT": "1",
# prevent ray from setting CUDA_VISIBLE_DEVICES
"RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1", "RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1",
} }
}
) if current_platform.is_rocm():
# For ROCm, BATCH_INVARIANT vllm is not supported
ray_env_vars["VLLM_ROCM_USE_SKINNY_GEMM"] = "0"
else:
# Enable batch invariance for deterministic outputs on NVIDIA
ray_env_vars["VLLM_BATCH_INVARIANT"] = "1"
ray.init(runtime_env={"env_vars": ray_env_vars})
# Launch the training model actor. Ray's resource scheduler will allocate # Launch the training model actor. Ray's resource scheduler will allocate
# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs. # 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs.
train_model = TrainModel.remote(MODEL_NAME_V2) train_model = TrainModel.remote(MODEL_NAME_V2)
# Launch the vLLM inference engine. The `enforce_eager` flag reduces rocm_determinism_kwargs = {}
# start-up latency. if current_platform.is_rocm():
# With data_parallel_backend="ray", vLLM's CoreEngineActorManager creates # ROCm: To minimize non-determinism, we set fixed seed, no prefix caching, and
# its own placement groups internally for each DP rank, so we must NOT # sequential request processing (max_num_seqs=1).
# create an outer placement group (it would reserve GPUs and hide them rocm_determinism_kwargs = {
# from the internal DP resource check). "seed": 0,
llm = ray.remote( "enable_prefix_caching": False,
num_cpus=0, "max_num_seqs": 1,
num_gpus=0, }
)(MyLLM).remote(
# Build platform-specific LLM kwargs
llm_kwargs = dict(
model=MODEL_NAME_V1, model=MODEL_NAME_V1,
enforce_eager=True, enforce_eager=True,
max_model_len=8192, max_model_len=8192,
distributed_executor_backend="ray", distributed_executor_backend="ray",
attention_backend="FLASH_ATTN", attention_backend=ATTN_BACKEND,
gpu_memory_utilization=0.75, gpu_memory_utilization=0.75,
weight_transfer_config=WeightTransferConfig(backend="nccl"), weight_transfer_config=WeightTransferConfig(backend="nccl"),
) )
llm_kwargs.update(rocm_determinism_kwargs)
# Launch the vLLM inference engine.
# With data_parallel_backend="ray", vLLM's CoreEngineActorManager creates
# its own placement groups internally for each DP rank, so we must NOT
# create an outer placement group (it would reserve GPUs and hide them
# from the internal DP resource check).
llm = ray.remote(
num_cpus=0,
num_gpus=0,
)(MyLLM).remote(**llm_kwargs)
PROMPTS = [ PROMPTS = [
"The president of the United States is", "The president of the United States is",
...@@ -304,25 +329,42 @@ for i, (output, pause_idx) in enumerate(results): ...@@ -304,25 +329,42 @@ for i, (output, pause_idx) in enumerate(results):
print(f" New weights ({n_after} tokens): {after_text!r}") print(f" New weights ({n_after} tokens): {after_text!r}")
# ── Phase 2: validate with a fresh V2 vLLM instance ──────────────── # ── Phase 2: validate with a fresh V2 vLLM instance ────────────────
# This validation relies on batch-invariant (deterministic) generation to
# compare outputs from the weight-synced engine against a fresh V2 instance.
# On NVIDIA, batch invariance is fully supported, so we require 100% exact
# token match. On ROCm, batch invariance is not yet fully implemented
# (see https://github.com/vllm-project/vllm/issues/27433 and
# https://github.com/vllm-project/vllm/issues/33123), so residual
# non-determinism (e.g. GEMM accumulation order, missing kernel overrides)
# can cause single-token divergences that don't indicate a weight-sync
# failure. We relax the pass rate to 90% on ROCm to accommodate this; a
# real regression (broken weight transfer) would cause ~0% pass rate, not 90%+.
MIN_PASS_RATE = 1.0 if not current_platform.is_rocm() else 0.9
print(f"\n{'=' * 50}") print(f"\n{'=' * 50}")
print("VALIDATION: comparing weight-synced vLLM with fresh V2 instance") print("VALIDATION: comparing weight-synced vLLM with fresh V2 instance")
if current_platform.is_rocm():
print(f" (ROCm mode: requiring >= {MIN_PASS_RATE:.0%} exact match rate)")
print(f"{'=' * 50}") print(f"{'=' * 50}")
ray.get(llm.shutdown.remote()) ray.get(llm.shutdown.remote())
ray.kill(llm) ray.kill(llm)
ray.kill(train_model) ray.kill(train_model)
llm_v2 = ray.remote( llm_v2_kwargs = dict(
num_cpus=0,
num_gpus=0,
)(MyLLM).remote(
model=MODEL_NAME_V2, model=MODEL_NAME_V2,
enforce_eager=True, enforce_eager=True,
max_model_len=8192, max_model_len=8192,
gpu_memory_utilization=0.75, gpu_memory_utilization=0.75,
distributed_executor_backend="ray", distributed_executor_backend="ray",
attention_backend="FLASH_ATTN", attention_backend=ATTN_BACKEND,
) )
llm_v2_kwargs.update(rocm_determinism_kwargs)
llm_v2 = ray.remote(
num_cpus=0,
num_gpus=0,
)(MyLLM).remote(**llm_v2_kwargs)
val_futures = [ val_futures = [
llm_v2.do_generate.remote( llm_v2.do_generate.remote(
...@@ -335,16 +377,17 @@ val_futures = [ ...@@ -335,16 +377,17 @@ val_futures = [
] ]
val_results = ray.get(val_futures) val_results = ray.get(val_futures)
all_pass = True num_pass = 0
num_total = len(results)
for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_results)): for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_results)):
expected = list(output.outputs[0].token_ids)[pause_idx:] expected = list(output.outputs[0].token_ids)[pause_idx:]
actual = list(val_output.outputs[0].token_ids) actual = list(val_output.outputs[0].token_ids)
match = actual == expected match = actual == expected
if match: if match:
num_pass += 1
print(f" [PASS] {PROMPTS[i]!r}") print(f" [PASS] {PROMPTS[i]!r}")
else: else:
all_pass = False
print(f" [FAIL] {PROMPTS[i]!r}") print(f" [FAIL] {PROMPTS[i]!r}")
print(f" weight-synced vLLM: {tokenizer.decode(expected)!r}") print(f" weight-synced vLLM: {tokenizer.decode(expected)!r}")
print(f" V2 vLLM: {tokenizer.decode(actual)!r}") print(f" V2 vLLM: {tokenizer.decode(actual)!r}")
...@@ -359,5 +402,14 @@ for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_resu ...@@ -359,5 +402,14 @@ for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_resu
ray.get(llm_v2.shutdown.remote()) ray.get(llm_v2.shutdown.remote())
ray.kill(llm_v2) ray.kill(llm_v2)
assert all_pass, "Some prompts failed validation, see above for details"
pass_rate = num_pass / num_total
print(f"\n Result: {num_pass}/{num_total} prompts passed ({pass_rate:.0%})")
print(f" Required: >= {MIN_PASS_RATE:.0%}")
assert pass_rate >= MIN_PASS_RATE, (
f"Validation pass rate {pass_rate:.0%} ({num_pass}/{num_total}) "
f"is below the required {MIN_PASS_RATE:.0%} threshold. "
f"See failures above for details."
)
print("=" * 50) print("=" * 50)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment