Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
596ed1f0
Unverified
Commit
596ed1f0
authored
Feb 23, 2026
by
Aaron Hao
Committed by
GitHub
Feb 23, 2026
Browse files
[RL] Validation for pause_mode='keep' (#34992)
Signed-off-by:
ahao-anyscale
<
ahao@anyscale.com
>
parent
b8d8b7e9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
182 additions
and
106 deletions
+182
-106
.buildkite/test_areas/distributed.yaml
.buildkite/test_areas/distributed.yaml
+1
-1
examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py
...fline_inference/new_weight_syncing/rlhf_async_new_apis.py
+181
-105
No files found.
.buildkite/test_areas/distributed.yaml
View file @
596ed1f0
...
@@ -104,7 +104,6 @@ steps:
...
@@ -104,7 +104,6 @@ steps:
# NEW rlhf examples
# NEW rlhf examples
-
cd new_weight_syncing
-
cd new_weight_syncing
-
VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
-
VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
-
VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
-
label
:
Distributed Tests (8 GPUs)(H100)
-
label
:
Distributed Tests (8 GPUs)(H100)
timeout_in_minutes
:
10
timeout_in_minutes
:
10
...
@@ -146,6 +145,7 @@ steps:
...
@@ -146,6 +145,7 @@ steps:
num_devices
:
2
num_devices
:
2
commands
:
commands
:
-
pytest -v -s tests/distributed/test_context_parallel.py
-
pytest -v -s tests/distributed/test_context_parallel.py
-
cd examples/offline_inference/new_weight_syncing && VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
-
VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model=Qwen/Qwen1.5-MoE-A2.7B -tp=1 -dp=2 --max-model-len=2048 --all2all-backend=deepep_high_throughput
-
VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model=Qwen/Qwen1.5-MoE-A2.7B -tp=1 -dp=2 --max-model-len=2048 --all2all-backend=deepep_high_throughput
-
pytest -v -s tests/v1/distributed/test_dbo.py
-
pytest -v -s tests/v1/distributed/test_dbo.py
...
...
examples/offline_inference/new_weight_syncing/rlhf_async_new_apis.py
View file @
596ed1f0
...
@@ -26,14 +26,12 @@ workloads. Residual GPU activity interferes with vLLM memory profiling and
...
@@ -26,14 +26,12 @@ workloads. Residual GPU activity interferes with vLLM memory profiling and
causes unexpected behavior.
causes unexpected behavior.
"""
"""
import
o
s
import
asynci
o
import
uuid
import
uuid
from
dataclasses
import
asdict
from
dataclasses
import
asdict
import
ray
import
ray
import
torch
import
torch
from
ray.util.placement_group
import
placement_group
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
import
vllm
import
vllm
...
@@ -51,14 +49,15 @@ from vllm.distributed.weight_transfer.nccl_engine import (
...
@@ -51,14 +49,15 @@ from vllm.distributed.weight_transfer.nccl_engine import (
from
vllm.utils.network_utils
import
get_ip
,
get_open_port
from
vllm.utils.network_utils
import
get_ip
,
get_open_port
from
vllm.v1.executor
import
Executor
from
vllm.v1.executor
import
Executor
MODEL_NAME
=
"facebook/opt-125m"
MODEL_NAME_V1
=
"Qwen/Qwen3-1.7B-Base"
MODEL_NAME_V2
=
"Qwen/Qwen3-1.7B"
PAUSE_TOKEN_THRESHOLD
=
10
class
MyLLM
(
vllm
.
AsyncLLMEngine
):
class
MyLLM
(
vllm
.
AsyncLLMEngine
):
"""Configure the vLLM worker for Ray placement group execution."""
"""Configure the vLLM worker for Ray placement group execution."""
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
**
kwargs
):
os
.
environ
[
"VLLM_RAY_BUNDLE_INDICES"
]
=
"0,1"
engine_args
=
vllm
.
AsyncEngineArgs
(
**
kwargs
)
engine_args
=
vllm
.
AsyncEngineArgs
(
**
kwargs
)
vllm_config
=
engine_args
.
create_engine_config
()
vllm_config
=
engine_args
.
create_engine_config
()
executor_class
=
Executor
.
get_class
(
vllm_config
)
executor_class
=
Executor
.
get_class
(
vllm_config
)
...
@@ -68,26 +67,44 @@ class MyLLM(vllm.AsyncLLMEngine):
...
@@ -68,26 +67,44 @@ class MyLLM(vllm.AsyncLLMEngine):
log_requests
=
engine_args
.
enable_log_requests
,
log_requests
=
engine_args
.
enable_log_requests
,
log_stats
=
not
engine_args
.
disable_log_stats
,
log_stats
=
not
engine_args
.
disable_log_stats
,
)
)
self
.
_generation_paused
=
False
self
.
_request_pause_flag
=
False
async
def
generate
_with_retry
(
async
def
do_
generate
(
self
,
prompt_token_ids
:
list
[
int
],
sampling_params
:
vllm
.
SamplingParams
self
,
prompt_token_ids
:
list
[
int
],
sampling_params
:
vllm
.
SamplingParams
)
->
vllm
.
RequestOutput
:
)
->
tuple
[
vllm
.
RequestOutput
,
int
]:
finish_reason
=
"abort"
"""Generate a single request, setting the request pause flag once the
while
finish_reason
==
"abort"
:
token count reaches the threshold.
async
for
request_output
in
self
.
generate
(
{
"prompt_token_ids"
:
prompt_token_ids
},
Returns (output, pause_token_index). pause_token_index is the number
sampling_params
,
of tokens generated before the weight change, or -1 if no pause.
request_id
=
str
(
uuid
.
uuid4
()),
"""
pause_token_index
=
-
1
prev_token_count
=
0
async
for
request_output
in
self
.
generate
(
{
"prompt_token_ids"
:
prompt_token_ids
},
sampling_params
,
request_id
=
str
(
uuid
.
uuid4
()),
):
output
=
request_output
cur_token_count
=
len
(
output
.
outputs
[
0
].
token_ids
)
if
(
cur_token_count
>=
PAUSE_TOKEN_THRESHOLD
and
not
self
.
_request_pause_flag
):
):
output
=
request_output
self
.
_request_pause_flag
=
True
finish_reason
=
output
.
outputs
[
0
].
finish_reason
if
self
.
_generation_paused
and
pause_token_index
==
-
1
:
if
finish_reason
==
"abort"
:
pause_token_index
=
prev_token_count
print
(
prev_token_count
=
cur_token_count
f
"ABORT, prompt_token_ids:
{
prompt_token_ids
}
, "
return
output
,
pause_token_index
f
"generated token_ids:
{
list
(
output
.
outputs
[
0
].
token_ids
)
}
"
)
async
def
pause_after_n_tokens
(
self
):
prompt_token_ids
=
prompt_token_ids
+
list
(
output
.
outputs
[
0
].
token_ids
)
"""Wait for any request to set the pause flag, then pause."""
return
output
while
not
self
.
_request_pause_flag
:
await
asyncio
.
sleep
(
0
)
await
super
().
pause_generation
(
mode
=
"keep"
)
await
asyncio
.
sleep
(
0.2
)
self
.
_generation_paused
=
True
@
ray
.
remote
(
num_gpus
=
1
)
@
ray
.
remote
(
num_gpus
=
1
)
...
@@ -95,6 +112,14 @@ class TrainModel:
...
@@ -95,6 +112,14 @@ class TrainModel:
"""Ray actor that wraps the training model on a dedicated GPU."""
"""Ray actor that wraps the training model on a dedicated GPU."""
def
__init__
(
self
,
model_name
:
str
):
def
__init__
(
self
,
model_name
:
str
):
from
vllm.model_executor.layers.batch_invariant
import
(
init_batch_invariance
,
)
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
# need to init all env vars for batch invariance which affect nccl ops
init_batch_invariance
(
AttentionBackendEnum
.
FLASH_ATTN
)
self
.
model
=
AutoModelForCausalLM
.
from_pretrained
(
self
.
model
=
AutoModelForCausalLM
.
from_pretrained
(
model_name
,
dtype
=
torch
.
bfloat16
model_name
,
dtype
=
torch
.
bfloat16
).
to
(
"cuda:0"
)
).
to
(
"cuda:0"
)
...
@@ -133,70 +158,80 @@ class TrainModel:
...
@@ -133,70 +158,80 @@ class TrainModel:
packed
=
packed
,
packed
=
packed
,
)
)
@
torch
.
inference_mode
()
# Initialize Ray and set the visible devices. The vLLM engine will
def
generate
(
self
,
token_ids
:
list
[
int
],
max_new_tokens
:
int
)
->
list
[
int
]:
# be placed on GPUs 1 and 2.
"""Greedy-decode max_new_tokens from the given context."""
ray
.
init
()
input_ids
=
torch
.
tensor
([
token_ids
],
device
=
"cuda:0"
)
output
=
self
.
model
.
generate
(
input_ids
,
max_new_tokens
=
max_new_tokens
,
do_sample
=
False
,
)
new_token_ids
=
output
[
0
,
len
(
token_ids
)
:].
tolist
()
return
new_token_ids
ray
.
init
(
runtime_env
=
{
"env_vars"
:
{
# enable batch invariance for deterministic outputs
"VLLM_BATCH_INVARIANT"
:
"1"
,
# prevent ray from setting CUDA_VISIBLE_DEVICES
"RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR"
:
"1"
,
}
}
)
# Launch the training model actor. Ray's resource scheduler will allocate
# Launch the training model actor. Ray's resource scheduler will allocate
# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs.
# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs.
train_model
=
TrainModel
.
remote
(
MODEL_NAME
)
train_model
=
TrainModel
.
remote
(
MODEL_NAME_V2
)
# Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/placement-groups.html
pg_inference
=
placement_group
([{
"GPU"
:
1
,
"CPU"
:
0
}]
*
2
)
ray
.
get
(
pg_inference
.
ready
())
scheduling_inference
=
PlacementGroupSchedulingStrategy
(
placement_group
=
pg_inference
,
placement_group_capture_child_tasks
=
True
,
placement_group_bundle_index
=
0
,
)
# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.
# start-up latency.
# Note: Weight transfer APIs (init_weight_transfer_engine, update_weights)
# With data_parallel_backend="ray", vLLM's CoreEngineActorManager creates
# are now native to vLLM workers.
# its own placement groups internally for each DP rank, so we must NOT
# create an outer placement group (it would reserve GPUs and hide them
# from the internal DP resource check).
llm
=
ray
.
remote
(
llm
=
ray
.
remote
(
num_cpus
=
0
,
num_cpus
=
0
,
num_gpus
=
0
,
num_gpus
=
0
,
scheduling_strategy
=
scheduling_inference
,
)(
MyLLM
).
remote
(
)(
MyLLM
).
remote
(
model
=
MODEL_NAME
,
model
=
MODEL_NAME
_V1
,
enforce_eager
=
True
,
enforce_eager
=
True
,
tensor_parallel_size
=
2
,
max_model_len
=
819
2
,
distributed_executor_backend
=
"ray"
,
distributed_executor_backend
=
"ray"
,
load_format
=
"dummy"
,
attention_backend
=
"FLASH_ATTN"
,
gpu_memory_utilization
=
0.75
,
weight_transfer_config
=
WeightTransferConfig
(
backend
=
"nccl"
),
weight_transfer_config
=
WeightTransferConfig
(
backend
=
"nccl"
),
)
)
# Generate text from the prompts.
PROMPTS
=
[
prompts
=
[
"My name is"
,
"The president of the United States is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The capital of France is"
,
"The future of AI is"
,
"The largest ocean on Earth is"
,
"The speed of light in a vacuum is"
,
"The chemical formula for water is"
,
"The tallest mountain in the world is"
,
"The first person to walk on the moon was"
,
"The Great Wall of China was built to"
,
"Photosynthesis is the process by which"
,
"The theory of general relativity was proposed by"
,
"The boiling point of water at sea level is"
,
"The largest planet in our solar system is"
,
"DNA stands for deoxyribonucleic acid and it"
,
]
]
# Tokenize prompts to token IDs
tokenizer
=
AutoTokenizer
.
from_pretrained
(
MODEL_NAME_V1
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
MODEL_NAME
)
batch_prompt_token_ids
=
[
prompt_token_ids_list
=
[
tokenizer
.
encode
(
prompt
,
add_special_tokens
=
False
)
for
prompt
in
PROMPTS
tokenizer
.
encode
(
prompt
,
add_special_tokens
=
False
)
for
prompt
in
prompts
]
]
sampling_params
=
[
SamplingParams
(
temperature
=
0
,
max_tokens
=
2
),
SamplingParams
(
temperature
=
0
,
max_tokens
=
32
),
SamplingParams
(
temperature
=
0
,
max_tokens
=
32
),
SamplingParams
(
temperature
=
0
,
max_tokens
=
32
),
]
# Set up the communication channel between the training process and the
# Set up the communication channel between the training process and the
# inference engine.
# inference engine.
master_address
,
master_port
=
ray
.
get
(
train_model
.
get_master_address_and_port
.
remote
())
master_address
,
master_port
=
ray
.
get
(
train_model
.
get_master_address_and_port
.
remote
())
world_size
=
3
# 1 trainer +
2
inference worker
s (tensor_parallel_size=2)
world_size
=
2
# 1 trainer +
1
inference worker
inference_handle
=
llm
.
init_weight_transfer_engine
.
remote
(
inference_handle
=
llm
.
init_weight_transfer_engine
.
remote
(
WeightTransferInitRequest
(
WeightTransferInitRequest
(
init_info
=
asdict
(
init_info
=
asdict
(
...
@@ -215,22 +250,28 @@ train_handle = train_model.init_weight_transfer_group.remote(world_size)
...
@@ -215,22 +250,28 @@ train_handle = train_model.init_weight_transfer_group.remote(world_size)
ray
.
get
([
train_handle
,
inference_handle
])
ray
.
get
([
train_handle
,
inference_handle
])
generation_futures
=
[
N_NEW_TOKENS
=
100
llm
.
generate_with_retry
.
remote
(
prompt_token_ids
,
params
)
for
prompt_token_ids
,
params
in
zip
(
prompt_token_ids_list
,
sampling_params
)
]
finished
,
pending
=
ray
.
wait
(
generation_futures
,
num_returns
=
1
)
# Collect weight metadata once
names
,
dtype_names
,
shapes
=
ray
.
get
(
train_model
.
get_weight_metadata
.
remote
())
# Pause generation in preparation for weight sync
# ── Phase 1: concurrent requests with weight sync ───────────────────
ray
.
get
(
llm
.
pause_generation
.
remote
(
wait_for_inflight_requests
=
False
))
print
(
f
"
\n
{
'='
*
50
}
"
)
print
(
f
"Prompts (
{
len
(
PROMPTS
)
}
):"
)
for
p
in
PROMPTS
:
print
(
f
" -
{
p
!
r
}
"
)
print
(
f
"
{
'='
*
50
}
"
)
# Synchronize the updated weights to the inference engine using batched API.
sampling_params
=
SamplingParams
(
# Collect all weight metadata from the training actor
temperature
=
0
,
max_tokens
=
PAUSE_TOKEN_THRESHOLD
+
N_NEW_TOKENS
names
,
dtype_names
,
shapes
=
ray
.
get
(
train_model
.
get_weight_metadata
.
remote
())
)
gen_futures
=
[
llm
.
do_generate
.
remote
(
ptids
,
sampling_params
)
for
ptids
in
batch_prompt_token_ids
]
ray
.
get
(
llm
.
pause_after_n_tokens
.
remote
())
# Issue update_weights call with NCCL-specific update info
# packed=True enables efficient batched tensor broadcasting
inference_handle
=
llm
.
update_weights
.
remote
(
inference_handle
=
llm
.
update_weights
.
remote
(
WeightTransferUpdateRequest
(
WeightTransferUpdateRequest
(
update_info
=
asdict
(
update_info
=
asdict
(
...
@@ -243,41 +284,76 @@ inference_handle = llm.update_weights.remote(
...
@@ -243,41 +284,76 @@ inference_handle = llm.update_weights.remote(
)
)
)
)
)
)
# Broadcast all weights from trainer using the weight transfer API
train_handle
=
train_model
.
broadcast_weights
.
remote
(
packed
=
True
)
train_handle
=
train_model
.
broadcast_weights
.
remote
(
packed
=
True
)
ray
.
get
([
train_handle
,
inference_handle
])
ray
.
get
([
train_handle
,
inference_handle
])
# Resume generation since weight sync is complete
ray
.
get
(
llm
.
resume_generation
.
remote
())
ray
.
get
(
llm
.
resume_generation
.
remote
())
results
=
ray
.
get
(
gen_futures
)
for
i
,
(
output
,
pause_idx
)
in
enumerate
(
results
):
all_token_ids
=
list
(
output
.
outputs
[
0
].
token_ids
)
before_text
=
tokenizer
.
decode
(
all_token_ids
[:
pause_idx
])
after_text
=
tokenizer
.
decode
(
all_token_ids
[
pause_idx
:])
print
(
f
"
\n
Request
{
i
}
(
{
PROMPTS
[
i
]
!
r
}
):"
)
print
(
f
" Old weights (
{
pause_idx
}
tokens):
{
before_text
!
r
}
"
)
n_after
=
len
(
all_token_ids
)
-
pause_idx
print
(
f
" New weights (
{
n_after
}
tokens):
{
after_text
!
r
}
"
)
# ── Phase 2: validate with a fresh V2 vLLM instance ────────────────
print
(
f
"
\n
{
'='
*
50
}
"
)
print
(
"VALIDATION: comparing weight-synced vLLM with fresh V2 instance"
)
print
(
f
"
{
'='
*
50
}
"
)
ray
.
get
(
llm
.
shutdown
.
remote
())
ray
.
kill
(
llm
)
ray
.
kill
(
train_model
)
llm_v2
=
ray
.
remote
(
num_cpus
=
0
,
num_gpus
=
0
,
)(
MyLLM
).
remote
(
model
=
MODEL_NAME_V2
,
enforce_eager
=
True
,
max_model_len
=
8192
,
gpu_memory_utilization
=
0.75
,
distributed_executor_backend
=
"ray"
,
attention_backend
=
"FLASH_ATTN"
,
)
val_futures
=
[
llm_v2
.
do_generate
.
remote
(
list
(
output
.
prompt_token_ids
)
+
list
(
output
.
outputs
[
0
].
token_ids
)[:
pause_idx
],
SamplingParams
(
temperature
=
0
,
max_tokens
=
len
(
output
.
outputs
[
0
].
token_ids
)
-
pause_idx
),
)
for
output
,
pause_idx
in
results
]
val_results
=
ray
.
get
(
val_futures
)
all_pass
=
True
for
i
,
((
output
,
pause_idx
),
(
val_output
,
_
))
in
enumerate
(
zip
(
results
,
val_results
)):
expected
=
list
(
output
.
outputs
[
0
].
token_ids
)[
pause_idx
:]
actual
=
list
(
val_output
.
outputs
[
0
].
token_ids
)
match
=
actual
==
expected
if
match
:
print
(
f
" [PASS]
{
PROMPTS
[
i
]
!
r
}
"
)
else
:
all_pass
=
False
print
(
f
" [FAIL]
{
PROMPTS
[
i
]
!
r
}
"
)
print
(
f
" weight-synced vLLM:
{
tokenizer
.
decode
(
expected
)
!
r
}
"
)
print
(
f
" V2 vLLM:
{
tokenizer
.
decode
(
actual
)
!
r
}
"
)
for
j
,
(
e
,
a
)
in
enumerate
(
zip
(
expected
,
actual
)):
if
e
!=
a
:
print
(
f
" first divergence at output token
{
j
}
: "
f
"expected
{
e
}
(
{
tokenizer
.
decode
([
e
])
!
r
}
) vs "
f
"actual
{
a
}
(
{
tokenizer
.
decode
([
a
])
!
r
}
)"
)
break
# Get outputs separately - finished completed before pause, pending were paused/resumed
ray
.
get
(
llm_v2
.
shutdown
.
remote
())
finished_outputs
=
ray
.
get
(
finished
)
ray
.
kill
(
llm_v2
)
pending_outputs
=
ray
.
get
(
pending
)
assert
all_pass
,
"Some prompts failed validation, see above for details"
print
(
"="
*
50
)
# Requests that finished before the pause: all generation used original weights
print
(
"-"
*
50
)
print
(
"Requests that completed BEFORE weight change:"
)
print
(
"-"
*
50
)
for
output
in
finished_outputs
:
prompt_text
=
tokenizer
.
decode
(
output
.
prompt_token_ids
)
print
(
f
"Prompt:
{
prompt_text
!
r
}
"
)
print
(
f
"Generated (with original weights):
{
output
.
outputs
[
0
].
text
!
r
}
"
)
print
(
"-"
*
50
)
# Requests that were paused mid-generation: some text before, some after weight change
print
(
"Requests that were PAUSED and RESUMED after weight change:"
)
print
(
"-"
*
50
)
for
output
in
pending_outputs
:
# Decode the full prompt token IDs (original + generated before pause)
full_prompt_text
=
tokenizer
.
decode
(
output
.
prompt_token_ids
)
# Find the original prompt by checking which one this output started with
original_prompt
=
next
(
p
for
p
in
prompts
if
full_prompt_text
.
startswith
(
p
))
# output.prompt_token_ids contains original prompt + tokens generated before pause
# output.outputs[0].text is what was generated after resuming with new weights
text_before_pause
=
full_prompt_text
[
len
(
original_prompt
)
:]
text_after_pause
=
output
.
outputs
[
0
].
text
print
(
f
"Original prompt:
{
original_prompt
!
r
}
"
)
print
(
f
"Generated before weight change:
{
text_before_pause
!
r
}
"
)
print
(
f
"Generated after weight change:
{
text_after_pause
!
r
}
"
)
print
(
"-"
*
50
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment