Unverified Commit c1858b7e authored by Aaron Hao's avatar Aaron Hao Committed by GitHub
Browse files

[Feat][RL][1/2] Native Weight Syncing API: NCCL (#31943)


Signed-off-by: default avatarahao-anyscale <ahao@anyscale.com>
Signed-off-by: default avatarAaron Hao <ahao@anyscale.com>
Co-authored-by: default avatarSumanthRH <sumanthrh99@gmail.com>
parent 82914d2a
...@@ -233,6 +233,7 @@ steps: ...@@ -233,6 +233,7 @@ steps:
- tests/compile/fullgraph/test_basic_correctness.py - tests/compile/fullgraph/test_basic_correctness.py
- examples/offline_inference/rlhf.py - examples/offline_inference/rlhf.py
- examples/offline_inference/rlhf_colocate.py - examples/offline_inference/rlhf_colocate.py
- examples/offline_inference/new_weight_syncing/
- tests/examples/offline_inference/data_parallel.py - tests/examples/offline_inference/data_parallel.py
- tests/v1/distributed - tests/v1/distributed
- tests/v1/engine/test_engine_core_client.py - tests/v1/engine/test_engine_core_client.py
...@@ -268,10 +269,16 @@ steps: ...@@ -268,10 +269,16 @@ steps:
- pytest -v -s distributed/test_symm_mem_allreduce.py - pytest -v -s distributed/test_symm_mem_allreduce.py
# TODO: create a dedicated test section for multi-GPU example tests # TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests # when we have multiple distributed example tests
# OLD rlhf examples
- pushd ../examples/offline_inference - pushd ../examples/offline_inference
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
- popd - popd
# NEW rlhf examples
- pushd ../examples/offline_inference/new_weight_syncing
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
- popd
- label: Distributed Tests (8 GPUs) # 4min - label: Distributed Tests (8 GPUs) # 4min
timeout_in_minutes: 10 timeout_in_minutes: 10
......
...@@ -206,6 +206,7 @@ steps: ...@@ -206,6 +206,7 @@ steps:
- tests/compile/fullgraph/test_basic_correctness.py - tests/compile/fullgraph/test_basic_correctness.py
- examples/offline_inference/rlhf.py - examples/offline_inference/rlhf.py
- examples/offline_inference/rlhf_colocate.py - examples/offline_inference/rlhf_colocate.py
- examples/offline_inference/new_weight_syncing/
- tests/examples/offline_inference/data_parallel.py - tests/examples/offline_inference/data_parallel.py
- tests/v1/distributed - tests/v1/distributed
- tests/v1/engine/test_engine_core_client.py - tests/v1/engine/test_engine_core_client.py
...@@ -240,10 +241,16 @@ steps: ...@@ -240,10 +241,16 @@ steps:
- pytest -v -s distributed/test_symm_mem_allreduce.py - pytest -v -s distributed/test_symm_mem_allreduce.py
# TODO: create a dedicated test section for multi-GPU example tests # TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests # when we have multiple distributed example tests
# OLD rlhf examples
- pushd ../examples/offline_inference - pushd ../examples/offline_inference
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
- popd - popd
# NEW rlhf examples
- pushd ../examples/offline_inference/new_weight_syncing
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
- popd
- label: Distributed Tests (8 GPUs) # 4min - label: Distributed Tests (8 GPUs) # 4min
timeout_in_minutes: 10 timeout_in_minutes: 10
...@@ -1146,6 +1153,8 @@ steps: ...@@ -1146,6 +1153,8 @@ steps:
- pytest -v -s distributed/test_shm_broadcast.py - pytest -v -s distributed/test_shm_broadcast.py
- pytest -v -s distributed/test_shm_buffer.py - pytest -v -s distributed/test_shm_buffer.py
- pytest -v -s distributed/test_shm_storage.py - pytest -v -s distributed/test_shm_storage.py
- pytest -v -s distributed/test_packed_tensor.py
- pytest -v -s distributed/test_weight_transfer.py
- label: 2 Node Tests (4 GPUs in total) # 16min - label: 2 Node Tests (4 GPUs in total) # 16min
timeout_in_minutes: 30 timeout_in_minutes: 30
......
...@@ -62,6 +62,7 @@ steps: ...@@ -62,6 +62,7 @@ steps:
- tests/compile/fullgraph/test_basic_correctness.py - tests/compile/fullgraph/test_basic_correctness.py
- examples/offline_inference/rlhf.py - examples/offline_inference/rlhf.py
- examples/offline_inference/rlhf_colocate.py - examples/offline_inference/rlhf_colocate.py
- examples/offline_inference/new_weight_syncing/
- tests/examples/offline_inference/data_parallel.py - tests/examples/offline_inference/data_parallel.py
- tests/v1/distributed - tests/v1/distributed
- tests/v1/engine/test_engine_core_client.py - tests/v1/engine/test_engine_core_client.py
...@@ -96,9 +97,14 @@ steps: ...@@ -96,9 +97,14 @@ steps:
- pytest -v -s distributed/test_symm_mem_allreduce.py - pytest -v -s distributed/test_symm_mem_allreduce.py
# TODO: create a dedicated test section for multi-GPU example tests # TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests # when we have multiple distributed example tests
# OLD rlhf examples
- cd ../examples/offline_inference - cd ../examples/offline_inference
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
# NEW rlhf examples
- cd new_weight_syncing
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py
- label: Distributed Tests (8 GPUs)(H100) - label: Distributed Tests (8 GPUs)(H100)
timeout_in_minutes: 10 timeout_in_minutes: 10
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning using vLLM and Ray,
with native weight syncing APIs at engine instance.
The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformer model occupies one GPU for training, whereas a
2x tensor-parallel vLLM inference engine occupies two GPUs.
The example performs the following steps:
* Load the training model on one gpu (scheduled via ray)
* Initialize the inference model with dummy weights across
two gpus using vLLM's tensor parallelism and Ray placement groups.
* Generate gibberish from a list of prompts using the randomly initialized
inference engine.
* Update the weights of the training model and broadcast the updated weights
to the inference engine over NCCL using vLLM's native weight transfer API.
* Generating from the list of prompts after weight sync should result
in sensible outputs.
This example assumes a single-node cluster with three GPUs, but Ray
supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
workloads. Residual GPU activity interferes with vLLM memory profiling and
causes unexpected behavior.
"""
import os
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM
from vllm import LLM, SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.nccl_engine import (
NCCLWeightTransferEngine,
)
from vllm.utils.network_utils import get_ip, get_open_port
MODEL_NAME = "facebook/opt-125m"
# MODEL_NAME = "inference-optimization/Qwen3-0.6B-W4A16-G128"
class MyLLM(LLM):
    """``LLM`` subclass meant to be wrapped as a Ray actor inside a placement group."""

    def __init__(self, *args, **kwargs):
        """Export the bundle indices, then defer to ``LLM.__init__``.

        All positional and keyword arguments are forwarded unchanged.
        """
        # Presumably tells vLLM which placement-group bundles its workers
        # should occupy (bundles 0 and 1) -- confirm against vLLM's env docs.
        os.environ.update(VLLM_RAY_BUNDLE_INDICES="0,1")
        super().__init__(*args, **kwargs)
@ray.remote(num_gpus=1)
class TrainModel:
    """Ray actor holding the Hugging Face training model on its own GPU.

    The actor also owns the rendezvous endpoint (IP + port) that the weight
    transfer process group is created on.
    """

    def __init__(self, model_name: str):
        """Load ``model_name`` onto ``cuda:0`` and pick a rendezvous endpoint."""
        hf_model = AutoModelForCausalLM.from_pretrained(model_name)
        self.model = hf_model.to("cuda:0")

        # Endpoint later shared with the inference engine.
        self.port = get_open_port()
        self.master_address = get_ip()

    def get_master_address_and_port(self):
        """Return the ``(master_address, port)`` pair chosen at construction."""
        endpoint = (self.master_address, self.port)
        return endpoint

    def get_weight_metadata(self):
        """Describe every model parameter for the receiving side.

        Returns:
            A ``(names, dtype_names, shapes)`` tuple of three parallel lists.
            Dtype names are the part of ``str(dtype)`` after the last dot
            (e.g. ``"float32"``); shapes are plain lists of ints.
        """
        named_params = list(self.model.named_parameters())
        names = [param_name for param_name, _ in named_params]
        dtype_names = [
            str(param.dtype).rpartition(".")[2] for _, param in named_params
        ]
        shapes = [list(param.shape) for _, param in named_params]
        return names, dtype_names, shapes

    def init_weight_transfer_group(self, world_size):
        """Create the trainer side of the NCCL weight transfer group.

        Args:
            world_size: Total number of ranks in the group (trainer included).
        """
        init_info = {
            "master_address": self.master_address,
            "master_port": self.port,
            "world_size": world_size,
        }
        self.model_update_group = NCCLWeightTransferEngine.trainer_init(init_info)

    def broadcast_weights(self, packed: bool = True):
        """Send all model parameters to the inference engine.

        Args:
            packed: Forwarded to ``trainer_send_weights``; must match the
                ``packed`` value given to the engine's ``update_weights``.
        """
        weight_iter = self.model.named_parameters()
        NCCLWeightTransferEngine.trainer_send_weights(
            iterator=weight_iter,
            group=self.model_update_group,
            packed=packed,
        )
# Initialize Ray. No device visibility is set here; GPU assignment is left to
# Ray's scheduler (one GPU for the trainer actor, two for the placement group).
ray.init()

# Launch the training model actor. Ray's resource scheduler will allocate
# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs.
train_model = TrainModel.remote(MODEL_NAME)

# Create a placement group that reserves two single-GPU bundles for the vLLM
# inference engine.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/placement-groups.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
# Block until the bundles are actually reserved.
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
    placement_group=pg_inference,
    placement_group_capture_child_tasks=True,
    placement_group_bundle_index=0,
)

# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.
# The wrapping actor itself requests no CPU/GPU (num_cpus=0, num_gpus=0).
# Note: Weight transfer APIs (init_weight_transfer_engine, update_weights)
# are now native to vLLM workers.
# NOTE(review): quantization="fp8" is requested while the trainer sends the
# unquantized Hugging Face weights -- presumably the engine quantizes them on
# update; confirm this is intended for the example.
llm = ray.remote(
    num_cpus=0,
    num_gpus=0,
    scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
    model=MODEL_NAME,
    enforce_eager=True,
    tensor_parallel_size=2,
    data_parallel_size=1,
    distributed_executor_backend="ray",
    weight_transfer_config=WeightTransferConfig(backend="nccl"),
    load_format="dummy",
    quantization="fp8",
)

# Prompts used both before and after the weight sync.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# temperature=0 so the before/after outputs are comparable.
sampling_params = SamplingParams(temperature=0)

# Generate text with the initial model. The output is expected to be nonsense
# because the engine was started with load_format="dummy".
outputs = ray.get(llm.generate.remote(prompts, sampling_params))

print("-" * 50)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)

# Set up the communication channel between the training process and the
# inference engine. The trainer actor owns the rendezvous address/port.
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())
world_size = ray.get(llm.get_world_size.remote()) + 1  # +1 for the trainer

# Start the inference side of the group without waiting for it: both sides are
# launched first and only then awaited together (presumably because the NCCL
# rendezvous blocks until every rank has joined -- confirm).
inference_handle = llm.init_weight_transfer_engine.remote(
    dict(
        init_info=dict(
            master_address=master_address,
            master_port=master_port,
            rank_offset=1,
            world_size=world_size,
        )
    )
)

# Initialize weight transfer group on both the training actor and inference engine
train_handle = train_model.init_weight_transfer_group.remote(world_size)
ray.get([train_handle, inference_handle])

# Synchronize the updated weights to the inference engine using batched API.
# Collect all weight metadata from the training actor
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())

# Issue update_weights call with NCCL-specific update info
# packed=True enables efficient batched tensor broadcasting
# (the same `packed` value is passed to the trainer's broadcast below).
inference_handle = llm.update_weights.remote(
    dict(
        update_info=dict(
            names=names,
            dtype_names=dtype_names,
            shapes=shapes,
            packed=True,
        )
    )
)

# Broadcast all weights from trainer using the weight transfer API
train_handle = train_model.broadcast_weights.remote(packed=True)
# Wait for both the sender and the receiver to finish.
ray.get([train_handle, inference_handle])

# Generate text with the updated model. The output is expected to be normal
# because the weights are updated.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates async reinforcement learning using vLLM and Ray,
with native weight syncing APIs at engine instance.
The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformer model occupies one GPU for training, whereas a
2x tensor-parallel vLLM inference engine occupies two GPUs.
The example performs the following steps:
* Load the training model on one gpu (scheduled via ray)
* Initialize the inference model with dummy weights across
two gpus using vLLM's tensor parallelism and Ray placement groups.
* Generate gibberish from a list of prompts using the randomly initialized
inference engine.
* Pause generation once generation completes for one sequence
* Update the weights of the training model and broadcast the updated weights
to the inference engine over NCCL using vLLM's native weight transfer API.
* Resume generation and print out the results
This example assumes a single-node cluster with three GPUs, but Ray
supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
workloads. Residual GPU activity interferes with vLLM memory profiling and
causes unexpected behavior.
"""
import os
import uuid
from dataclasses import asdict
import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM, AutoTokenizer
import vllm
from vllm import SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest,
WeightTransferUpdateRequest,
)
from vllm.distributed.weight_transfer.nccl_engine import (
NCCLWeightTransferEngine,
NCCLWeightTransferInitInfo,
NCCLWeightTransferUpdateInfo,
)
from vllm.utils.network_utils import get_ip, get_open_port
from vllm.v1.executor import Executor
MODEL_NAME = "facebook/opt-125m"
class MyLLM(vllm.AsyncLLMEngine):
    """``AsyncLLMEngine`` subclass meant to run as a Ray actor in a placement group."""

    def __init__(self, **kwargs):
        """Build the engine from ``AsyncEngineArgs`` keyword arguments.

        Args:
            **kwargs: Passed verbatim to ``vllm.AsyncEngineArgs``.
        """
        # Presumably tells vLLM which placement-group bundles its workers
        # should occupy (bundles 0 and 1) -- confirm against vLLM's env docs.
        os.environ.update(VLLM_RAY_BUNDLE_INDICES="0,1")

        parsed_args = vllm.AsyncEngineArgs(**kwargs)
        engine_config = parsed_args.create_engine_config()
        super().__init__(
            vllm_config=engine_config,
            executor_class=Executor.get_class(engine_config),
            log_requests=parsed_args.enable_log_requests,
            log_stats=not parsed_args.disable_log_stats,
        )

    async def generate_with_retry(
        self, prompt_token_ids: list[int], sampling_params: vllm.SamplingParams
    ) -> vllm.RequestOutput:
        """Generate for one prompt, resubmitting whenever the request is aborted.

        Each attempt is submitted under a fresh random request id. When an
        attempt ends with finish reason ``"abort"``, the tokens it produced
        are appended to the prompt and a new attempt is issued with the same
        ``sampling_params``.

        Returns:
            The last ``RequestOutput`` of the first attempt that did not end
            in ``"abort"``. Its prompt therefore includes any tokens generated
            by earlier, aborted attempts.
        """
        while True:
            stream = self.generate(
                {"prompt_token_ids": prompt_token_ids},
                sampling_params,
                request_id=str(uuid.uuid4()),
            )
            # Drain the stream; only the final item is kept.
            async for latest in stream:
                output = latest

            if output.outputs[0].finish_reason != "abort":
                return output

            print(
                f"ABORT, prompt_token_ids: {prompt_token_ids}, "
                f"generated token_ids: {list(output.outputs[0].token_ids)}"
            )
            # Continue from where the aborted attempt stopped.
            prompt_token_ids = prompt_token_ids + list(output.outputs[0].token_ids)
@ray.remote(num_gpus=1)
class TrainModel:
    """Ray actor holding the Hugging Face training model on its own GPU.

    The actor also owns the rendezvous endpoint (IP + port) that the weight
    transfer process group is created on.
    """

    def __init__(self, model_name: str):
        """Load ``model_name`` in bfloat16 onto ``cuda:0`` and pick an endpoint."""
        hf_model = AutoModelForCausalLM.from_pretrained(
            model_name, dtype=torch.bfloat16
        )
        self.model = hf_model.to("cuda:0")

        # Endpoint later shared with the inference engine.
        self.port = get_open_port()
        self.master_address = get_ip()

    def get_master_address_and_port(self):
        """Return the ``(master_address, port)`` pair chosen at construction."""
        endpoint = (self.master_address, self.port)
        return endpoint

    def get_weight_metadata(self):
        """Describe every model parameter for the receiving side.

        Returns:
            A ``(names, dtype_names, shapes)`` tuple of three parallel lists.
            Dtype names are the part of ``str(dtype)`` after the last dot
            (e.g. ``"bfloat16"``); shapes are plain lists of ints.
        """
        named_params = list(self.model.named_parameters())
        names = [param_name for param_name, _ in named_params]
        dtype_names = [
            str(param.dtype).rpartition(".")[2] for _, param in named_params
        ]
        shapes = [list(param.shape) for _, param in named_params]
        return names, dtype_names, shapes

    def init_weight_transfer_group(self, world_size):
        """Create the trainer side of the NCCL weight transfer group.

        Args:
            world_size: Total number of ranks in the group (trainer included).
        """
        init_info = {
            "master_address": self.master_address,
            "master_port": self.port,
            "world_size": world_size,
        }
        self.model_update_group = NCCLWeightTransferEngine.trainer_init(init_info)

    def broadcast_weights(self, packed: bool = True):
        """Send all model parameters to the inference engine.

        Args:
            packed: Forwarded to ``trainer_send_weights``; must match the
                ``packed`` value given to the engine's ``update_weights``.
        """
        weight_iter = self.model.named_parameters()
        NCCLWeightTransferEngine.trainer_send_weights(
            iterator=weight_iter,
            group=self.model_update_group,
            packed=packed,
        )
# Initialize Ray. No device visibility is set here; GPU assignment is left to
# Ray's scheduler (one GPU for the trainer actor, two for the placement group).
ray.init()

# Launch the training model actor. Ray's resource scheduler will allocate
# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs.
train_model = TrainModel.remote(MODEL_NAME)

# Create a placement group that reserves two single-GPU bundles for the vLLM
# inference engine.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/placement-groups.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
# Block until the bundles are actually reserved.
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
    placement_group=pg_inference,
    placement_group_capture_child_tasks=True,
    placement_group_bundle_index=0,
)

# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.
# The wrapping actor itself requests no CPU/GPU (num_cpus=0, num_gpus=0).
# Note: Weight transfer APIs (init_weight_transfer_engine, update_weights)
# are now native to vLLM workers.
llm = ray.remote(
    num_cpus=0,
    num_gpus=0,
    scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
    model=MODEL_NAME,
    enforce_eager=True,
    tensor_parallel_size=2,
    distributed_executor_backend="ray",
    load_format="dummy",
    weight_transfer_config=WeightTransferConfig(backend="nccl"),
)

# Prompts to generate from.
prompts = [
    "My name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Tokenize prompts to token IDs (no special tokens, so decoding the prompt
# ids later yields text that starts with the original prompt string).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
prompt_token_ids_list = [
    tokenizer.encode(prompt, add_special_tokens=False) for prompt in prompts
]

# One SamplingParams per prompt. The first request is capped at 2 tokens so
# that it finishes well before the others and triggers the pause below.
sampling_params = [
    SamplingParams(temperature=0, max_tokens=2),
    SamplingParams(temperature=0, max_tokens=32),
    SamplingParams(temperature=0, max_tokens=32),
    SamplingParams(temperature=0, max_tokens=32),
]

# Set up the communication channel between the training process and the
# inference engine. The trainer actor owns the rendezvous address/port.
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())
world_size = 3  # 1 trainer + 2 inference workers (tensor_parallel_size=2)

# Start the inference side of the group without waiting for it: both sides are
# launched first and only then awaited together (presumably because the NCCL
# rendezvous blocks until every rank has joined -- confirm).
inference_handle = llm.init_weight_transfer_engine.remote(
    WeightTransferInitRequest(
        init_info=asdict(
            NCCLWeightTransferInitInfo(
                master_address=master_address,
                master_port=master_port,
                rank_offset=1,
                world_size=world_size,
            )
        )
    )
)

# Initialize weight transfer group on both the training actor and inference engine
train_handle = train_model.init_weight_transfer_group.remote(world_size)
ray.get([train_handle, inference_handle])

# Submit all prompts concurrently; each future resolves to a RequestOutput.
generation_futures = [
    llm.generate_with_retry.remote(prompt_token_ids, params)
    for prompt_token_ids, params in zip(prompt_token_ids_list, sampling_params)
]

# Block until the first request has finished; the rest are still in flight.
finished, pending = ray.wait(generation_futures, num_returns=1)

# Pause generation in preparation for weight sync
# (wait_for_inflight_requests=False: do not wait for running requests;
# generate_with_retry resubmits any that come back with finish reason "abort").
ray.get(llm.pause_generation.remote(wait_for_inflight_requests=False))

# Synchronize the updated weights to the inference engine using batched API.
# Collect all weight metadata from the training actor
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())

# Issue update_weights call with NCCL-specific update info
# packed=True enables efficient batched tensor broadcasting
# (the same `packed` value is passed to the trainer's broadcast below).
inference_handle = llm.update_weights.remote(
    WeightTransferUpdateRequest(
        update_info=asdict(
            NCCLWeightTransferUpdateInfo(
                names=names,
                dtype_names=dtype_names,
                shapes=shapes,
                packed=True,
            )
        )
    )
)

# Broadcast all weights from trainer using the weight transfer API
train_handle = train_model.broadcast_weights.remote(packed=True)
# Wait for both the sender and the receiver to finish.
ray.get([train_handle, inference_handle])

# Resume generation since weight sync is complete
ray.get(llm.resume_generation.remote())

# Get outputs separately - finished completed before pause, pending were paused/resumed
finished_outputs = ray.get(finished)
pending_outputs = ray.get(pending)

# Requests that finished before the pause: all generation used original weights
print("-" * 50)
print("Requests that completed BEFORE weight change:")
print("-" * 50)
for output in finished_outputs:
    prompt_text = tokenizer.decode(output.prompt_token_ids)
    print(f"Prompt: {prompt_text!r}")
    print(f"Generated (with original weights): {output.outputs[0].text!r}")
    print("-" * 50)

# Requests that were paused mid-generation: some text before, some after weight change
print("Requests that were PAUSED and RESUMED after weight change:")
print("-" * 50)
for output in pending_outputs:
    # Decode the full prompt token IDs (original + generated before pause)
    full_prompt_text = tokenizer.decode(output.prompt_token_ids)
    # Find the original prompt by checking which one this output started with
    # (raises StopIteration if no prompt matches the decoded text).
    original_prompt = next(p for p in prompts if full_prompt_text.startswith(p))
    # output.prompt_token_ids contains original prompt + tokens generated before pause
    # output.outputs[0].text is what was generated after resuming with new weights
    text_before_pause = full_prompt_text[len(original_prompt) :]
    text_after_pause = output.outputs[0].text
    print(f"Original prompt: {original_prompt!r}")
    print(f"Generated before weight change: {text_before_pause!r}")
    print(f"Generated after weight change: {text_after_pause!r}")
    print("-" * 50)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM
via HTTP API, with native weight syncing APIs.
Unlike rlhf.py which creates a vLLM instance programmatically, this script
assumes you have already started a vLLM server using `vllm serve`. It uses:
- OpenAI-compatible API for inference requests
- HTTP endpoints for weight transfer control plane
- NCCL for actual weight data transfer
Prerequisites:
Start a vLLM server with weight transfer enabled:
$ VLLM_SERVER_DEV_MODE=1 vllm serve facebook/opt-125m \
--enforce-eager \
--weight-transfer-config '{"backend": "nccl"}' \
--load-format dummy
Then run this script:
$ python rlhf_http.py
The example performs the following steps:
* Load the training model on GPU 0.
* Generate text using the vLLM server via OpenAI-compatible API. The output
is expected to be nonsense because the server is initialized with dummy weights.
* Initialize weight transfer via HTTP endpoint.
* Broadcast the real weights from the training model to the vLLM server
using NCCL.
* Generate text again to show normal output after the weight update.
"""
import requests
import torch
from openai import OpenAI
from transformers import AutoModelForCausalLM
from vllm.distributed.weight_transfer.nccl_engine import (
NCCLWeightTransferEngine,
)
from vllm.utils.network_utils import get_ip, get_open_port
BASE_URL = "http://localhost:8000"
MODEL_NAME = "facebook/opt-125m"
def generate_completions(client: OpenAI, model: str, prompts: list[str]) -> list[str]:
    """Request one greedy completion per prompt from the OpenAI-compatible API.

    Args:
        client: OpenAI client pointed at the server.
        model: Model name to send with each request.
        prompts: Prompts to complete, processed sequentially in order.

    Returns:
        The text of the first choice for each prompt, in prompt order.
    """

    def complete(prompt: str) -> str:
        # Greedy decoding (temperature=0), capped at 32 new tokens.
        reply = client.completions.create(
            model=model,
            prompt=prompt,
            max_tokens=32,
            temperature=0,
        )
        return reply.choices[0].text

    return [complete(prompt) for prompt in prompts]
def init_weight_transfer_engine(
    base_url: str,
    master_address: str,
    master_port: int,
    rank_offset: int,
    world_size: int,
) -> None:
    """POST the rendezvous parameters to ``/init_weight_transfer_engine``.

    Args:
        base_url: Server root URL, without a trailing slash.
        master_address: Rendezvous host for the weight transfer group.
        master_port: Rendezvous port for the weight transfer group.
        rank_offset: Sent as ``rank_offset`` in the ``init_info`` payload.
        world_size: Total number of ranks in the group.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    init_info = {
        "master_address": master_address,
        "master_port": master_port,
        "rank_offset": rank_offset,
        "world_size": world_size,
    }
    reply = requests.post(
        f"{base_url}/init_weight_transfer_engine",
        json={"init_info": init_info},
        timeout=60,
    )
    reply.raise_for_status()
def update_weights(
    base_url: str,
    names: list[str],
    dtype_names: list[str],
    shapes: list[list[int]],
    packed: bool = False,
) -> None:
    """POST the weight metadata to ``/update_weights``.

    Args:
        base_url: Server root URL, without a trailing slash.
        names: Parameter names, in transfer order.
        dtype_names: Dtype name of each parameter, parallel to ``names``.
        shapes: Shape of each parameter, parallel to ``names``.
        packed: Sent as ``packed`` in the ``update_info`` payload.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    update_info = {
        "names": names,
        "dtype_names": dtype_names,
        "shapes": shapes,
        "packed": packed,
    }
    # Longer timeout than the other endpoints (300 s).
    reply = requests.post(
        f"{base_url}/update_weights",
        json={"update_info": update_info},
        timeout=300,
    )
    reply.raise_for_status()
def pause_generation(base_url: str) -> None:
    """POST to ``{base_url}/pause``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    endpoint = f"{base_url}/pause"
    requests.post(endpoint, timeout=60).raise_for_status()
def resume_generation(base_url: str) -> None:
    """POST to ``{base_url}/resume``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    endpoint = f"{base_url}/resume"
    requests.post(endpoint, timeout=60).raise_for_status()
def get_world_size(base_url: str) -> int:
    """GET ``{base_url}/get_world_size`` and return its ``world_size`` field.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    reply = requests.get(f"{base_url}/get_world_size", timeout=10)
    reply.raise_for_status()
    body = reply.json()
    return body["world_size"]
def main():
    """Run the HTTP-driven weight sync demo against a running vLLM server.

    Steps: query the server's world size, load the training model, generate
    with the server's initial weights, set up the NCCL weight transfer group,
    pause generation, send the real weights, resume, and generate again.

    The HTTP calls that must overlap with trainer-side NCCL work are run on a
    worker thread through ``concurrent.futures``. Their results are awaited
    explicitly so a failed request raises here instead of being lost in the
    worker thread.

    Raises:
        requests.RequestException: If any control-plane HTTP request fails,
            including the ones issued from the worker thread.
    """
    # Function-scope import: only main() needs it.
    from concurrent.futures import ThreadPoolExecutor

    # Get the inference world size from the vLLM server
    inference_world_size = get_world_size(BASE_URL)
    world_size = inference_world_size + 1  # +1 for the trainer

    # NOTE(review): assumes the server occupies GPUs 0..inference_world_size-1
    # on this machine, so the next index is free for the trainer -- confirm.
    device = f"cuda:{inference_world_size}"
    torch.cuda.set_device(device)

    # Load the training model
    print(f"Loading training model: {MODEL_NAME}")
    train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16)
    train_model.to(device)

    # Create OpenAI client pointing to the vLLM server
    client = OpenAI(
        base_url=f"{BASE_URL}/v1",
        api_key="EMPTY",  # vLLM doesn't require an API key by default
    )

    # Test prompts
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Generate text before weight update. The output is expected to be nonsense
    # because the server is initialized with dummy weights.
    print("-" * 50)
    print("Generating text BEFORE weight update (expect nonsense):")
    print("-" * 50)
    outputs = generate_completions(client, MODEL_NAME, prompts)
    for prompt, generated_text in zip(prompts, outputs):
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)

    # Set up the communication channel between the training process and the
    # vLLM server. The trainer is rank 0, vLLM worker(s) start at rank_offset.
    master_address = get_ip()
    master_port = get_open_port()
    rank_offset = 1

    print(f"Initializing weight transfer: master={master_address}:{master_port}")

    # Collect weight metadata for the update request
    named_params = list(train_model.named_parameters())
    names = [name for name, _ in named_params]
    dtype_names = [str(p.dtype).split(".")[-1] for _, p in named_params]
    shapes = [list(p.shape) for _, p in named_params]

    with ThreadPoolExecutor(max_workers=1) as control_plane:
        # Initialize weight transfer on the vLLM server in the background:
        # the request stays open while the server joins the NCCL group, so
        # the trainer must join concurrently.
        init_future = control_plane.submit(
            init_weight_transfer_engine,
            BASE_URL,
            master_address,
            master_port,
            rank_offset,
            world_size,
        )

        # Initialize NCCL process group on trainer side
        model_update_group = NCCLWeightTransferEngine.trainer_init(
            dict(
                master_address=master_address,
                master_port=master_port,
                world_size=world_size,
            ),
        )

        # Wait for init_weight_transfer_engine to complete; re-raises any
        # error raised by the request.
        init_future.result()

        # Pause generation before weight sync
        pause_generation(BASE_URL)

        # Start the update_weights call in the background since it will block
        # waiting for NCCL broadcasts.
        # packed=True enables efficient batched tensor broadcasting
        update_future = control_plane.submit(
            update_weights,
            BASE_URL,
            names,
            dtype_names,
            shapes,
            True,  # packed=True
        )

        # Broadcast all weights from trainer to vLLM workers
        print("Broadcasting weights via NCCL...")
        NCCLWeightTransferEngine.trainer_send_weights(
            iterator=train_model.named_parameters(),
            group=model_update_group,
            packed=True,
        )

        # Wait for update_weights to complete; re-raises any error raised by
        # the request so a failed update is not reported as a success below.
        update_future.result()

    # Resume generation after weight sync
    resume_generation(BASE_URL)

    # Generate text after weight update. The output is expected to be normal
    # because the real weights are now loaded.
    print("-" * 50)
    print("Generating text AFTER weight update:")
    print("-" * 50)
    outputs_updated = generate_completions(client, MODEL_NAME, prompts)
    for prompt, generated_text in zip(prompts, outputs_updated):
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)


if __name__ == "__main__":
    main()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for packed tensor broadcasting functionality.
Unit tests for packed_broadcast_producer and packed_broadcast_consumer.
These utilities enable efficient batched tensor transfer over NCCL.
"""
import pytest
import torch
from vllm.distributed.weight_transfer.nccl_engine import NCCLWeightTransferUpdateInfo
from vllm.distributed.weight_transfer.packed_tensor import (
packed_broadcast_consumer,
packed_broadcast_producer,
)
class MockCommunicationGroup:
    """Producer-side stand-in for a communication group.

    Records a copy of every tensor handed to ``broadcast`` instead of sending
    it anywhere.
    """

    def __init__(self):
        self.device = torch.device("cuda:0")
        self.broadcast_count = 0
        self.broadcasted_tensors: list[torch.Tensor] = []

    def broadcast(self, tensor, src):
        """Store a clone of ``tensor`` and bump the call counter; ``src`` is ignored."""
        snapshot = tensor.clone()
        self.broadcasted_tensors.append(snapshot)
        self.broadcast_count += 1
class MockConsumerCommunicationGroup:
    """Consumer-side stand-in that replays a fixed sequence of tensors."""

    def __init__(self, tensors_to_return: list[torch.Tensor]):
        self.device = torch.device("cuda:0")
        self.tensors_to_return = tensors_to_return
        self.current_index = 0

    def broadcast(self, tensor, src):
        """Copy the next stored tensor into ``tensor``; ``src`` is ignored.

        Once every stored tensor has been replayed, further calls leave
        ``tensor`` untouched.
        """
        position = self.current_index
        if position >= len(self.tensors_to_return):
            return
        tensor.copy_(self.tensors_to_return[position])
        self.current_index = position + 1
def create_mock_model_params(
    num_layers: int = 3,
    dtype: torch.dtype = torch.float32,
) -> list[tuple[str, torch.Tensor]]:
    """Build random ``(name, tensor)`` pairs imitating a small model.

    Each layer ``i`` contributes ``layer{i}.weight`` of shape (10, 20)
    followed by ``layer{i}.bias`` of shape (10,), both of the given dtype.
    """
    return [
        entry
        for layer_idx in range(num_layers)
        # Weight is drawn before bias, layer by layer.
        for entry in (
            (f"layer{layer_idx}.weight", torch.randn(10, 20, dtype=dtype)),
            (f"layer{layer_idx}.bias", torch.randn(10, dtype=dtype)),
        )
    ]
def create_state_dict_info(
    params: list[tuple[str, torch.Tensor]],
) -> dict[str, tuple[tuple[int, ...], torch.dtype]]:
    """Map each parameter name to its ``(shape, dtype)`` pair.

    Shapes are converted to plain tuples of ints. If a name occurs more than
    once, the last occurrence wins.
    """
    info = {}
    for param_name, param in params:
        info[param_name] = (tuple(param.shape), param.dtype)
    return info
# --- Unit Tests: NCCLWeightTransferUpdateInfo packed field ---
class TestNCCLWeightTransferUpdateInfoPacked:
    """Checks for the ``packed`` flag on ``NCCLWeightTransferUpdateInfo``."""

    def test_packed_default_false(self):
        """Leaving ``packed`` unspecified must yield ``False``."""
        update_info = NCCLWeightTransferUpdateInfo(
            shapes=[[10, 10]],
            dtype_names=["float32"],
            names=["layer.weight"],
        )
        assert update_info.packed is False

    def test_packed_can_be_set_true(self):
        """Passing ``packed=True`` must be reflected on the instance."""
        update_info = NCCLWeightTransferUpdateInfo(
            packed=True,
            shapes=[[10, 10]],
            dtype_names=["float32"],
            names=["layer.weight"],
        )
        assert update_info.packed is True
# --- Unit Tests: packed_broadcast_producer ---
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestPackedBroadcastProducer:
    """Test packed_broadcast_producer function.

    The producer is driven with a recording mock group, so the tests inspect
    how many broadcasts were issued and how many bytes they carried.
    """

    def test_producer_broadcasts_tensors(self):
        """Test that producer broadcasts all tensors."""
        params = create_mock_model_params()
        params_cuda = [(name, tensor.cuda()) for name, tensor in params]
        mock_group = MockCommunicationGroup()

        # Use a small target size to force multiple batches
        packed_broadcast_producer(
            iterator=iter(params_cuda),
            group=mock_group,
            src=0,
            post_iter_func=lambda x: x[1],
            buffer_size_bytes=500,
        )

        # Should have broadcasted some tensors
        assert mock_group.broadcast_count > 0
        assert len(mock_group.broadcasted_tensors) > 0

    def test_producer_single_large_tensor(self):
        """Test with a single tensor larger than target size."""
        # Create a large tensor
        large_tensor = torch.randn(1000, 1000, dtype=torch.float32).cuda()
        params = [("large_weight", large_tensor)]
        mock_group = MockCommunicationGroup()

        # Small target size to force the tensor to exceed it
        packed_broadcast_producer(
            iterator=iter(params),
            group=mock_group,
            src=0,
            post_iter_func=lambda x: x[1],
            buffer_size_bytes=100,
        )

        # Should still broadcast the tensor (at least 1 broadcast)
        assert mock_group.broadcast_count >= 1
        assert len(mock_group.broadcasted_tensors) >= 1

        # Verify the total broadcasted size matches the tensor. Both sides are
        # measured in bytes so the check does not depend on the dtype of the
        # buffers handed to the group.
        expected_size = large_tensor.numel() * large_tensor.element_size()
        actual_size = sum(
            t.numel() * t.element_size() for t in mock_group.broadcasted_tensors
        )
        assert actual_size == expected_size

    def test_producer_multiple_batches(self):
        """Test that tensors are properly batched when exceeding target size."""
        # Create many small tensors
        params = [
            (f"weight_{i}", torch.randn(10, 10, dtype=torch.float32).cuda())
            for i in range(20)
        ]
        mock_group = MockCommunicationGroup()

        # Small target size to force multiple batches
        packed_broadcast_producer(
            iterator=iter(params),
            group=mock_group,
            src=0,
            post_iter_func=lambda x: x[1],
            buffer_size_bytes=2000,
        )

        # Should have multiple broadcasts
        assert mock_group.broadcast_count > 1

        # Total size should match sum of all tensors (bytes on both sides)
        expected_total = sum(t.numel() * t.element_size() for _, t in params)
        actual_total = sum(
            t.numel() * t.element_size() for t in mock_group.broadcasted_tensors
        )
        assert actual_total == expected_total

    def test_producer_empty_iterator(self):
        """Test producer handles empty iterator gracefully."""
        mock_group = MockCommunicationGroup()

        packed_broadcast_producer(
            iterator=iter([]),
            group=mock_group,
            src=0,
            post_iter_func=lambda x: x[1],
            buffer_size_bytes=1000,
        )

        # No broadcasts for empty iterator
        assert mock_group.broadcast_count == 0
# --- Unit Tests: packed_broadcast_consumer ---
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestPackedBroadcastConsumer:
    """Checks for packed_broadcast_consumer."""

    def test_consumer_receives_tensors(self):
        """Consumer rebuilds every tensor from the producer's broadcasts."""
        source_params = create_mock_model_params()
        cuda_params = [(name, tensor.cuda()) for name, tensor in source_params]
        buffer_size = 2000

        # Record what the producer would broadcast.
        recording_group = MockCommunicationGroup()
        packed_broadcast_producer(
            iterator=iter(cuda_params),
            group=recording_group,
            src=0,
            post_iter_func=lambda x: x[1],
            buffer_size_bytes=buffer_size,
        )

        # Replay the recorded broadcasts into the consumer.
        replay_group = MockConsumerCommunicationGroup(
            recording_group.broadcasted_tensors
        )
        expected_layout = create_state_dict_info(cuda_params)
        received = {}

        def post_unpack_func(tensor_list):
            # Clone so the result does not alias the consumer's buffers.
            for name, tensor in tensor_list:
                received[name] = tensor.clone()

        packed_broadcast_consumer(
            iterator=iter(expected_layout.items()),
            group=replay_group,
            src=0,
            post_unpack_func=post_unpack_func,
            buffer_size_bytes=buffer_size,
        )

        # Every parameter must have been unpacked.
        assert len(received) == len(source_params)

        # And each one must match its original.
        for name, original_tensor in cuda_params:
            assert name in received
            rebuilt = received[name]
            assert rebuilt.shape == original_tensor.shape
            assert rebuilt.dtype == original_tensor.dtype
            assert torch.allclose(rebuilt, original_tensor, rtol=1e-5, atol=1e-7)
# --- Integration Tests: Producer-Consumer Roundtrip ---
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestPackedBroadcastRoundtrip:
    """Test producer-consumer roundtrip behavior.

    Every test pushes a set of named tensors through
    packed_broadcast_producer, replays the recorded broadcasts into
    packed_broadcast_consumer, and compares the unpacked tensors with the
    originals.
    """

    @staticmethod
    def _run_roundtrip(
        params: list[tuple[str, torch.Tensor]],
        buffer_size: int,
    ) -> dict[str, torch.Tensor]:
        """Run producer then consumer and return the unpacked tensors by name.

        Args:
            params: ``(name, tensor)`` pairs fed to the producer; their shapes
                and dtypes also describe what the consumer should expect.
            buffer_size: Value passed as ``buffer_size_bytes`` to both sides.

        Returns:
            Mapping from parameter name to a clone of the unpacked tensor.
        """
        producer_group = MockCommunicationGroup()
        packed_broadcast_producer(
            iterator=iter(params),
            group=producer_group,
            src=0,
            post_iter_func=lambda x: x[1],
            buffer_size_bytes=buffer_size,
        )

        consumer_group = MockConsumerCommunicationGroup(
            producer_group.broadcasted_tensors
        )
        state_dict_info = create_state_dict_info(params)
        unpacked_tensors = {}

        def post_unpack_func(tensor_list):
            # Clone so results do not alias buffers owned by the consumer.
            for name, tensor in tensor_list:
                unpacked_tensors[name] = tensor.clone()

        packed_broadcast_consumer(
            iterator=iter(state_dict_info.items()),
            group=consumer_group,
            src=0,
            post_unpack_func=post_unpack_func,
            buffer_size_bytes=buffer_size,
        )
        return unpacked_tensors

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_roundtrip_different_dtypes(self, dtype):
        """Test roundtrip with different data types."""
        params = create_mock_model_params(num_layers=2, dtype=dtype)
        params_cuda = [(name, tensor.cuda()) for name, tensor in params]

        unpacked_tensors = self._run_roundtrip(params_cuda, buffer_size=1000)

        # Verify roundtrip preserves data
        for name, original_tensor in params_cuda:
            assert name in unpacked_tensors
            unpacked = unpacked_tensors[name]
            assert unpacked.dtype == dtype
            assert torch.allclose(unpacked, original_tensor, rtol=1e-4, atol=1e-6)

    def test_roundtrip_mixed_dtypes(self):
        """Test roundtrip with mixed data types."""
        # Create params with mixed dtypes
        params = [
            ("layer1.weight", torch.randn(10, 20, dtype=torch.float32).cuda()),
            ("layer1.bias", torch.randn(10, dtype=torch.float16).cuda()),
            ("layer2.weight", torch.randn(20, 30, dtype=torch.bfloat16).cuda()),
        ]

        unpacked_tensors = self._run_roundtrip(params, buffer_size=500)

        # Verify all params roundtrip correctly with correct dtypes
        for name, original_tensor in params:
            assert name in unpacked_tensors
            unpacked = unpacked_tensors[name]
            assert unpacked.shape == original_tensor.shape
            assert unpacked.dtype == original_tensor.dtype
            assert torch.allclose(unpacked, original_tensor, rtol=1e-4, atol=1e-6)

    @pytest.mark.parametrize("target_size", [100, 1000, 10000, 100000])
    def test_roundtrip_different_batch_sizes(self, target_size):
        """Test roundtrip with different target batch sizes."""
        params = create_mock_model_params(num_layers=5)
        params_cuda = [(name, tensor.cuda()) for name, tensor in params]

        unpacked_tensors = self._run_roundtrip(params_cuda, buffer_size=target_size)

        # Verify all params roundtrip correctly
        assert len(unpacked_tensors) == len(params)
        for name, original_tensor in params_cuda:
            assert name in unpacked_tensors
            assert torch.allclose(
                unpacked_tensors[name], original_tensor, rtol=1e-5, atol=1e-7
            )

    def test_roundtrip_non_contiguous_tensors(self):
        """Test roundtrip with non-contiguous tensors from the trainer."""
        # Create non-contiguous tensors (simulating trainer outputs)
        # Transposed tensors are non-contiguous
        weight1 = torch.randn(20, 10, dtype=torch.float32).cuda().T
        # Sliced tensors with step are non-contiguous
        weight2 = torch.randn(40, 30, dtype=torch.float16).cuda()[::2, ::2]
        # Permuted tensors are non-contiguous
        weight3 = torch.randn(5, 10, 15, dtype=torch.bfloat16).cuda().permute(2, 0, 1)

        params = [
            ("layer1.weight", weight1),
            ("layer2.weight", weight2),
            ("layer3.weight", weight3),
        ]

        # Verify tensors are indeed non-contiguous
        for name, tensor in params:
            assert not tensor.is_contiguous(), f"{name} should be non-contiguous"

        unpacked_tensors = self._run_roundtrip(params, buffer_size=500)

        # Verify all non-contiguous params roundtrip correctly
        for name, original_tensor in params:
            assert name in unpacked_tensors
            unpacked = unpacked_tensors[name]
            assert unpacked.shape == original_tensor.shape
            assert unpacked.dtype == original_tensor.dtype
            assert torch.allclose(unpacked, original_tensor, rtol=1e-4, atol=1e-6)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for weight transfer engine backends.
Unit tests for engine classes (parsing, validation, registry).
Integration test for NCCL weight transfer between processes using Ray.
"""
from unittest.mock import MagicMock
import pytest
import ray
import torch
from vllm.config.parallel import ParallelConfig
from vllm.config.weight_transfer import WeightTransferConfig
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
from vllm.distributed.weight_transfer.nccl_engine import (
NCCLWeightTransferEngine,
NCCLWeightTransferInitInfo,
NCCLWeightTransferUpdateInfo,
)
from vllm.utils.network_utils import get_open_port
def create_mock_parallel_config(
    rank: int = 0,
    world_size: int = 1,
    dp_rank: int = 0,
) -> ParallelConfig:
    """Return a MagicMock shaped like ParallelConfig with the given ranks.

    ``dp_rank`` is stored under the ``data_parallel_rank`` attribute.
    """
    mock_config = MagicMock(spec=ParallelConfig)
    # configure_mock assigns each keyword as an attribute on the mock.
    mock_config.configure_mock(
        rank=rank,
        world_size=world_size,
        data_parallel_rank=dp_rank,
    )
    return mock_config
# --- Unit Tests: NCCLWeightTransferUpdateInfo Validation ---
class TestNCCLWeightTransferUpdateInfoValidation:
    """Length-consistency checks performed by NCCLWeightTransferUpdateInfo."""

    def test_valid_update_info(self):
        """Matching list lengths are accepted and stored unchanged."""
        update_info = NCCLWeightTransferUpdateInfo(
            names=["layer.weight", "layer.bias"],
            dtype_names=["float32", "float32"],
            shapes=[[10, 10], [10]],
        )
        expected = {
            "names": ["layer.weight", "layer.bias"],
            "dtype_names": ["float32", "float32"],
            "shapes": [[10, 10], [10]],
        }
        assert update_info.names == expected["names"]
        assert update_info.dtype_names == expected["dtype_names"]
        assert update_info.shapes == expected["shapes"]

    def test_mismatched_dtype_names_raises(self):
        """Too few dtype_names for the given names raises ValueError."""
        with pytest.raises(ValueError, match="dtype_names"):
            NCCLWeightTransferUpdateInfo(
                names=["layer.weight", "layer.bias"],
                dtype_names=["float32"],  # one entry short
                shapes=[[10, 10], [10]],
            )

    def test_mismatched_shapes_raises(self):
        """Too few shapes for the given names raises ValueError."""
        with pytest.raises(ValueError, match="shapes"):
            NCCLWeightTransferUpdateInfo(
                names=["layer.weight", "layer.bias"],
                dtype_names=["float32", "float32"],
                shapes=[[10, 10]],  # one entry short
            )

    def test_empty_lists_valid(self):
        """Three empty lists form a valid (empty) update."""
        update_info = NCCLWeightTransferUpdateInfo(names=[], dtype_names=[], shapes=[])
        assert len(update_info.names) == 0
# --- Unit Tests: Engine Parsing ---
class TestNCCLEngineParsing:
    """Checks for the dict-to-dataclass parsers on NCCLWeightTransferEngine."""

    def test_parse_init_info_valid(self):
        """A complete init dict becomes an NCCLWeightTransferInitInfo."""
        engine = NCCLWeightTransferEngine(
            WeightTransferConfig(backend="nccl"), create_mock_parallel_config()
        )
        raw_init = {
            "master_address": "127.0.0.1",
            "master_port": 12345,
            "rank_offset": 1,
            "world_size": 3,
        }

        parsed = engine.parse_init_info(raw_init)

        assert isinstance(parsed, NCCLWeightTransferInitInfo)
        assert parsed.master_address == "127.0.0.1"
        assert parsed.master_port == 12345
        assert parsed.rank_offset == 1
        assert parsed.world_size == 3

    def test_parse_init_info_missing_field_raises(self):
        """An init dict lacking required keys is rejected with ValueError."""
        engine = NCCLWeightTransferEngine(
            WeightTransferConfig(backend="nccl"), create_mock_parallel_config()
        )
        # master_port, rank_offset and world_size are deliberately absent.
        incomplete_init = {"master_address": "127.0.0.1"}

        with pytest.raises(ValueError, match="Invalid init_info"):
            engine.parse_init_info(incomplete_init)

    def test_parse_update_info_valid(self):
        """A complete update dict becomes an NCCLWeightTransferUpdateInfo."""
        engine = NCCLWeightTransferEngine(
            WeightTransferConfig(backend="nccl"), create_mock_parallel_config()
        )
        raw_update = {
            "names": ["w1", "w2"],
            "dtype_names": ["float32", "bfloat16"],
            "shapes": [[100, 100], [50]],
        }

        parsed = engine.parse_update_info(raw_update)

        assert isinstance(parsed, NCCLWeightTransferUpdateInfo)
        assert parsed.names == ["w1", "w2"]
        assert parsed.dtype_names == ["float32", "bfloat16"]
        assert parsed.shapes == [[100, 100], [50]]
# --- Unit Tests: Engine Registry ---
class TestEngineRegistry:
    """Checks for WeightTransferEngineFactory lookup and registration."""

    def test_create_engine_nccl(self):
        """The ``nccl`` backend resolves to NCCLWeightTransferEngine."""
        created = WeightTransferEngineFactory.create_engine(
            WeightTransferConfig(backend="nccl"), create_mock_parallel_config()
        )
        assert isinstance(created, NCCLWeightTransferEngine)

    def test_create_engine_invalid_backend(self):
        """An unknown backend name makes the factory raise ValueError."""
        bad_config = WeightTransferConfig(backend="invalid")
        parallel_config = create_mock_parallel_config()
        with pytest.raises(ValueError, match="Invalid weight transfer backend"):
            WeightTransferEngineFactory.create_engine(bad_config, parallel_config)

    def test_register_duplicate_raises(self):
        """Registering ``nccl`` a second time raises ValueError."""
        with pytest.raises(ValueError, match="already registered"):
            WeightTransferEngineFactory.register_engine("nccl", NCCLWeightTransferEngine)
# --- Test receive_weights without init raises ---
def test_nccl_receive_weights_without_init_raises():
    """receive_weights must fail when init_transfer_engine was never called."""
    if torch.cuda.device_count() < 1:
        pytest.skip("Need at least 1 GPU for this test")

    engine = NCCLWeightTransferEngine(
        WeightTransferConfig(backend="nccl"), create_mock_parallel_config()
    )
    single_weight_update = NCCLWeightTransferUpdateInfo(
        names=["w"], dtype_names=["float32"], shapes=[[10]]
    )

    def discard_weights(weights):
        # Loader that ignores whatever it is handed.
        return None

    with pytest.raises(RuntimeError, match="not initialized"):
        engine.receive_weights(single_weight_update, discard_weights)
# --- Integration Test: NCCL Weight Transfer Between Ray Tasks ---
@ray.remote(num_gpus=1)
def trainer_broadcast_tensor(
    master_address: str,
    master_port: int,
    world_size: int,
    tensor_shape: list[int],
    tensor_dtype: str,
) -> bool:
    """Ray task acting as the trainer: broadcast an all-ones tensor over NCCL.

    Joins the stateless process group as rank 0, builds a tensor of ones with
    the requested shape and dtype on ``cuda:0``, broadcasts it from rank 0 and
    returns True once the device has synchronized.
    """
    import torch

    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

    # The trainer always takes rank 0 in the group.
    process_group = StatelessProcessGroup.create(
        host=master_address,
        port=master_port,
        rank=0,
        world_size=world_size,
    )

    # Ray sets CUDA_VISIBLE_DEVICES, so device 0 is the assigned GPU
    communicator = PyNcclCommunicator(process_group, device=0)

    payload = torch.ones(
        tensor_shape, dtype=getattr(torch, tensor_dtype), device="cuda:0"
    )
    communicator.broadcast(payload, src=0, stream=torch.cuda.current_stream())
    torch.cuda.synchronize()

    return True
@ray.remote(num_gpus=1)
def inference_receive_tensor(
    master_address: str,
    master_port: int,
    world_size: int,
    tensor_shape: list[int],
    tensor_dtype: str,
) -> dict:
    """Inference task that receives tensor via NCCLWeightTransferEngine.

    Builds an engine around a mocked ParallelConfig, joins the transfer group
    with ``rank_offset=1``, receives a single tensor named ``test.weight`` and
    checks that it has the expected shape and sums to the element count (the
    trainer side sends ones).

    Args:
        master_address: Address passed to the engine's init info.
        master_port: Port passed to the engine's init info.
        world_size: Total number of ranks in the transfer group.
        tensor_shape: Shape of the tensor expected from the trainer.
        tensor_dtype: Name of the torch dtype expected from the trainer.

    Returns:
        Dict with ``success`` (bool), ``received_shape`` (list or None) and
        ``received_sum`` (float or None).
    """
    import math
    from unittest.mock import MagicMock

    import torch

    from vllm.config.parallel import ParallelConfig
    from vllm.config.weight_transfer import WeightTransferConfig
    from vllm.distributed.weight_transfer.nccl_engine import (
        NCCLWeightTransferEngine,
        NCCLWeightTransferInitInfo,
        NCCLWeightTransferUpdateInfo,
    )

    # Create engine with mock parallel config
    config = WeightTransferConfig(backend="nccl")
    parallel_config = MagicMock(spec=ParallelConfig)
    parallel_config.rank = 0
    parallel_config.world_size = 1
    parallel_config.data_parallel_rank = 0

    engine = NCCLWeightTransferEngine(config, parallel_config)

    # Initialize the engine (joins as rank 1)
    init_info = NCCLWeightTransferInitInfo(
        master_address=master_address,
        master_port=master_port,
        rank_offset=1,  # Trainer is rank 0, we become rank 1
        world_size=world_size,
    )
    engine.init_transfer_engine(init_info)

    # Receive weights with a no-op load_weights that captures the tensor
    received_tensors = []

    def noop_load_weights(weights: list[tuple[str, torch.Tensor]]):
        for name, tensor in weights:
            # Clone tensor to keep it after engine cleans up
            received_tensors.append((name, tensor.clone()))

    update_info = NCCLWeightTransferUpdateInfo(
        names=["test.weight"],
        dtype_names=[tensor_dtype],
        shapes=[tensor_shape],
    )

    success = False
    received_shape = None
    received_sum = None

    # Shut the engine down even if the transfer or the checks raise, so a
    # failed run does not leave the initialized engine behind in the worker.
    try:
        engine.receive_weights(update_info, noop_load_weights)
        torch.cuda.synchronize()

        # Verify we received the tensor
        if len(received_tensors) == 1:
            _, tensor = received_tensors[0]
            received_shape = list(tensor.shape)
            received_sum = tensor.sum().item()
            # Check shape matches and values are all 1s (trainer sends ones):
            # a tensor of ones sums to its number of elements.
            expected_sum = float(math.prod(tensor_shape))
            success = (
                received_shape == tensor_shape
                and abs(received_sum - expected_sum) < 0.01
            )
    finally:
        engine.shutdown()

    return {
        "success": success,
        "received_shape": received_shape,
        "received_sum": received_sum,
    }
@pytest.mark.skipif(
    torch.cuda.device_count() < 2,
    reason="Need at least 2 GPUs to run NCCL weight transfer test.",
)
def test_nccl_weight_transfer_between_processes():
    """End-to-end NCCL transfer between a trainer task and an inference task.

    Launches both Ray tasks against the same address/port and checks that the
    tensor broadcast by the trainer is what NCCLWeightTransferEngine receives.
    """
    ray.init(ignore_reinit_error=True)

    # Shared arguments: address, port, world size (1 trainer + 1 inference
    # worker), and the tensor to transfer (100x100 float32 ones).
    task_args = ("127.0.0.1", get_open_port(), 2, [100, 100], "float32")

    # Start both tasks concurrently - Ray assigns GPUs automatically
    inference_future = inference_receive_tensor.remote(*task_args)
    trainer_future = trainer_broadcast_tensor.remote(*task_args)

    # Block until both tasks have finished.
    trainer_result, result = ray.get([trainer_future, inference_future])

    assert trainer_result, "Trainer should complete successfully"
    assert result["success"], (
        f"Weight transfer failed. "
        f"Received shape: {result['received_shape']}, "
        f"Received sum: {result['received_sum']}"
    )
...@@ -139,6 +139,14 @@ def test_openapi_stateless(case: schemathesis.Case): ...@@ -139,6 +139,14 @@ def test_openapi_stateless(case: schemathesis.Case):
# Skip responses API as it is meant to be stateful. # Skip responses API as it is meant to be stateful.
return return
# Skip weight transfer endpoints as they require special setup
# (weight_transfer_config) and are meant to be stateful.
if case.operation.path in (
"/init_weight_transfer_engine",
"/update_weights",
):
return
timeout = { timeout = {
# requires a longer timeout # requires a longer timeout
("POST", "/v1/chat/completions"): LONG_TIMEOUT_SECONDS, ("POST", "/v1/chat/completions"): LONG_TIMEOUT_SECONDS,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for weight transfer APIs via LLM class.
These tests use a mock weight transfer engine to verify that the API
calls the correct methods with the right arguments, without requiring
actual NCCL communication.
"""
import os
from collections.abc import Callable
from dataclasses import dataclass
from unittest.mock import patch
import pytest
import torch
from vllm import LLM
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.base import (
WeightTransferEngine,
WeightTransferInitInfo,
WeightTransferInitRequest,
WeightTransferUpdateInfo,
WeightTransferUpdateRequest,
)
from ...utils import create_new_process_for_each_test
# Use a tiny model for fast testing. The tests in this module pass it to LLM
# together with load_format="dummy".
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
# --- Mock Weight Transfer Engine ---
@dataclass
class MockInitInfo(WeightTransferInitInfo):
    """Init info for the mock engine, carrying one string parameter."""

    # Arbitrary payload the tests read back to confirm what the engine received.
    test_param: str = "test"
@dataclass
class MockUpdateInfo(WeightTransferUpdateInfo):
    """Update info for the mock engine; every field is optional."""

    # Parameter names announced by the update request.
    names: list[str] | None = None
    # Dtype name for each entry of ``names``.
    dtype_names: list[str] | None = None
    # Shape for each entry of ``names``.
    shapes: list[list[int]] | None = None
class MockWeightTransferEngine(WeightTransferEngine[MockInitInfo, MockUpdateInfo]):
    """Engine double that records which of its methods were invoked.

    All bookkeeping lives on the class rather than on instances, and is
    cleared whenever a new instance is constructed.
    """

    init_info_cls = MockInitInfo
    update_info_cls = MockUpdateInfo

    # Class-level tracking for verification across processes
    init_transfer_engine_called: bool = False
    receive_weights_called: bool = False
    shutdown_called: bool = False
    last_init_info: MockInitInfo | None = None
    last_update_info: MockUpdateInfo | None = None

    def __init__(self, config, parallel_config):
        super().__init__(config, parallel_config)
        # A fresh engine starts from a clean record.
        MockWeightTransferEngine.last_init_info = None
        MockWeightTransferEngine.last_update_info = None
        MockWeightTransferEngine.init_transfer_engine_called = False
        MockWeightTransferEngine.receive_weights_called = False
        MockWeightTransferEngine.shutdown_called = False

    def init_transfer_engine(self, init_info: MockInitInfo) -> None:
        """Record the call and remember ``init_info``."""
        MockWeightTransferEngine.last_init_info = init_info
        MockWeightTransferEngine.init_transfer_engine_called = True

    def receive_weights(
        self,
        update_info: MockUpdateInfo,
        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
    ) -> None:
        """Record the call, remember ``update_info``, then load nothing."""
        MockWeightTransferEngine.last_update_info = update_info
        MockWeightTransferEngine.receive_weights_called = True
        # No real transfer happens here; the loader is still invoked, with an
        # empty list, so the load path is exercised.
        load_weights([])

    def shutdown(self) -> None:
        """Record that shutdown was requested."""
        MockWeightTransferEngine.shutdown_called = True
def mock_create_engine(config, parallel_config):
    """Factory replacement: build a MockWeightTransferEngine from the args."""
    engine = MockWeightTransferEngine(config, parallel_config)
    return engine
# --- Tests ---
@create_new_process_for_each_test()
def test_get_world_size_tp1():
    """With tensor_parallel_size=1 the parallel config reports world_size 1."""
    if torch.cuda.device_count() < 1:
        pytest.skip("Need at least 1 GPU for this test")

    transfer_config = WeightTransferConfig(backend="nccl")
    llm = LLM(
        model=MODEL_NAME,
        enforce_eager=True,
        load_format="dummy",
        tensor_parallel_size=1,
        weight_transfer_config=transfer_config,
    )

    parallel_config = llm.llm_engine.vllm_config.parallel_config
    assert parallel_config.world_size == 1
@create_new_process_for_each_test()
def test_init_weight_transfer_engine_calls_engine():
    """LLM.init_weight_transfer_engine must reach the engine's
    init_transfer_engine with the parsed init info."""
    if torch.cuda.device_count() < 1:
        pytest.skip("Need at least 1 GPU for this test")

    # Enable insecure serialization to allow pickling functions for collective_rpc
    os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"

    factory_target = (
        "vllm.v1.worker.gpu_worker.WeightTransferEngineFactory.create_engine"
    )
    with patch(factory_target, mock_create_engine):
        llm = LLM(
            model=MODEL_NAME,
            enforce_eager=True,
            load_format="dummy",
            tensor_parallel_size=1,
            weight_transfer_config=WeightTransferConfig(backend="nccl"),
        )

        # Every worker should already hold an engine.
        def check_engine_exists(self):
            return self.weight_transfer_engine is not None

        results = llm.collective_rpc(check_engine_exists)
        assert all(results), "Weight transfer engine should be initialized"

        # Drive the API under test.
        init_request = WeightTransferInitRequest(init_info={"test_param": "hello"})
        llm.init_weight_transfer_engine(init_request)

        # Read back what the mock engine recorded on each worker.
        def check_init_called(self):
            engine = self.weight_transfer_engine
            recorded = engine.last_init_info
            recorded_param = recorded.test_param if recorded else None
            return engine.init_transfer_engine_called, recorded_param

        results = llm.collective_rpc(check_init_called)
        for called, param in results:
            assert called, "init_transfer_engine should have been called"
            assert param == "hello", f"Expected 'hello', got {param}"
@create_new_process_for_each_test()
def test_update_weights_calls_engine():
    """LLM.update_weights must reach the engine's receive_weights with the
    parsed update info."""
    if torch.cuda.device_count() < 1:
        pytest.skip("Need at least 1 GPU for this test")

    # Enable insecure serialization to allow pickling functions for collective_rpc
    os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"

    factory_target = (
        "vllm.v1.worker.gpu_worker.WeightTransferEngineFactory.create_engine"
    )
    with patch(factory_target, mock_create_engine):
        llm = LLM(
            model=MODEL_NAME,
            enforce_eager=True,
            load_format="dummy",
            tensor_parallel_size=1,
            weight_transfer_config=WeightTransferConfig(backend="nccl"),
        )

        # The transfer engine is initialized before any update is sent.
        llm.init_weight_transfer_engine(
            WeightTransferInitRequest(init_info={"test_param": "init"})
        )

        # Payload for the update request.
        test_names = ["layer.weight", "layer.bias"]
        test_dtypes = ["float32", "float32"]
        test_shapes = [[10, 10], [10]]

        update_request = WeightTransferUpdateRequest(
            update_info={
                "names": test_names,
                "dtype_names": test_dtypes,
                "shapes": test_shapes,
            }
        )
        llm.update_weights(update_request)

        # Read back what the mock engine recorded on each worker.
        def check_update_called(self):
            engine = self.weight_transfer_engine
            if engine.receive_weights_called:
                info = engine.last_update_info
                return (True, info.names, info.dtype_names, info.shapes)
            return False, None, None, None

        results = llm.collective_rpc(check_update_called)
        for called, names, dtypes, shapes in results:
            assert called, "receive_weights should have been called"
            assert names == test_names
            assert dtypes == test_dtypes
            assert shapes == test_shapes
@create_new_process_for_each_test()
def test_full_weight_transfer_flow():
    """Run init followed by update and check that both reached the engine."""
    if torch.cuda.device_count() < 1:
        pytest.skip("Need at least 1 GPU for this test")

    # Enable insecure serialization to allow pickling functions for collective_rpc
    os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"

    factory_target = (
        "vllm.v1.worker.gpu_worker.WeightTransferEngineFactory.create_engine"
    )
    with patch(factory_target, mock_create_engine):
        llm = LLM(
            model=MODEL_NAME,
            enforce_eager=True,
            load_format="dummy",
            tensor_parallel_size=1,
            weight_transfer_config=WeightTransferConfig(backend="nccl"),
        )

        # Step 1: Initialize
        init_request = WeightTransferInitRequest(init_info={"test_param": "flow_test"})
        llm.init_weight_transfer_engine(init_request)

        # Step 2: Update weights
        update_request = WeightTransferUpdateRequest(
            update_info={
                "names": ["test.weight"],
                "dtype_names": ["bfloat16"],
                "shapes": [[100, 100]],
            }
        )
        llm.update_weights(update_request)

        # Collect the mock engine's record from each worker.
        def check_flow(self):
            engine = self.weight_transfer_engine
            init_info = engine.last_init_info
            update_info = engine.last_update_info
            summary = {
                "init_called": engine.init_transfer_engine_called,
                "update_called": engine.receive_weights_called,
                "init_param": init_info.test_param if init_info else None,
                "update_names": update_info.names if update_info else None,
            }
            return summary

        results = llm.collective_rpc(check_flow)
        for result in results:
            assert result["init_called"], "init_transfer_engine should be called"
            assert result["update_called"], "receive_weights should be called"
            assert result["init_param"] == "flow_test"
            assert result["update_names"] == ["test.weight"]
@create_new_process_for_each_test()
def test_weight_transfer_config_backend():
    """The backend given to LLM is visible on the resulting vllm_config."""
    if torch.cuda.device_count() < 1:
        pytest.skip("Need at least 1 GPU for this test")

    # Only the nccl backend is exercised here.
    requested_config = WeightTransferConfig(backend="nccl")
    llm = LLM(
        model=MODEL_NAME,
        enforce_eager=True,
        load_format="dummy",
        tensor_parallel_size=1,
        weight_transfer_config=requested_config,
    )

    config = llm.llm_engine.vllm_config.weight_transfer_config
    assert config.backend == "nccl"
...@@ -47,6 +47,7 @@ from vllm.config.vllm import ( ...@@ -47,6 +47,7 @@ from vllm.config.vllm import (
get_layers_from_vllm_config, get_layers_from_vllm_config,
set_current_vllm_config, set_current_vllm_config,
) )
from vllm.config.weight_transfer import WeightTransferConfig
# __all__ should only contain classes and functions. # __all__ should only contain classes and functions.
# Types and globals should be imported from their respective modules. # Types and globals should be imported from their respective modules.
...@@ -111,4 +112,5 @@ __all__ = [ ...@@ -111,4 +112,5 @@ __all__ = [
"get_current_vllm_config_or_none", "get_current_vllm_config_or_none",
"set_current_vllm_config", "set_current_vllm_config",
"get_layers_from_vllm_config", "get_layers_from_vllm_config",
"WeightTransferConfig",
] ]
...@@ -42,6 +42,7 @@ from .scheduler import SchedulerConfig ...@@ -42,6 +42,7 @@ from .scheduler import SchedulerConfig
from .speculative import EagleModelTypes, SpeculativeConfig from .speculative import EagleModelTypes, SpeculativeConfig
from .structured_outputs import StructuredOutputsConfig from .structured_outputs import StructuredOutputsConfig
from .utils import SupportsHash, config, replace from .utils import SupportsHash, config, replace
from .weight_transfer import WeightTransferConfig
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -255,6 +256,9 @@ class VllmConfig: ...@@ -255,6 +256,9 @@ class VllmConfig:
performance. -02 is used by default. See OptimizationLevel for full performance. -02 is used by default. See OptimizationLevel for full
description.""" description."""
weight_transfer_config: WeightTransferConfig | None = None
"""The configurations for weight transfer during RL training."""
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Literal
from vllm.config.utils import config
@config
@dataclass
class WeightTransferConfig:
    """Settings that control how weights are transferred during RL training."""

    backend: Literal["nccl"] = "nccl"
    """The backend to use for weight transfer."""
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Weight transfer engines for syncing model weights from trainers
to inference workers.
"""
from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory
# Names exported by this package; only the engine factory is re-exported here.
__all__ = [
    "WeightTransferEngineFactory",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Base class for weight transfer engines."""
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import KW_ONLY, dataclass, field
from typing import Any, Generic, TypeVar
import torch
from vllm.config.parallel import ParallelConfig
from vllm.config.weight_transfer import WeightTransferConfig
# Type parameters for WeightTransferEngine. Each is bound (via a forward
# reference) to the matching base info dataclass defined below.
TInitInfo = TypeVar("TInitInfo", bound="WeightTransferInitInfo")
TUpdateInfo = TypeVar("TUpdateInfo", bound="WeightTransferUpdateInfo")
# Base protocols for backend-specific dataclasses
@dataclass
class WeightTransferInitInfo(ABC):  # noqa: B024
    """Common ancestor for backend-specific initialization payloads.

    Declares no fields of its own; each backend subclasses it and adds the
    parameters its transport needs.
    """
@dataclass
class WeightTransferUpdateInfo(ABC): # noqa: B024
"""Base class for backend-specific weight update info."""
_: KW_ONLY
is_checkpoint_format: bool = True
"""Set to True if weights are in checkpoint/original model format and need
layerwise processing. Set to False if weights have already been processed
into kernel format (repacking, renaming, etc.)."""
# API-level request classes (accept dicts for backend-agnostic serialization)
@dataclass
class WeightTransferInitRequest:
    """API-level weight transfer initialization request.

    Carries the initialization parameters as a plain dict rather than a typed
    backend dataclass; WeightTransferEngine.parse_init_info turns such a dict
    into the backend's typed init info.
    """

    # Raw backend-specific parameters. default_factory gives every instance
    # its own empty dict instead of a shared mutable default.
    init_info: dict[str, Any] = field(default_factory=dict)
@dataclass
class WeightTransferUpdateRequest:
    """API-level weight update request.

    Carries the update parameters as a plain dict rather than a typed backend
    dataclass; WeightTransferEngine.parse_update_info turns such a dict into
    the backend's typed update info.
    """

    # Raw backend-specific parameters. default_factory gives every instance
    # its own empty dict instead of a shared mutable default.
    update_info: dict[str, Any] = field(default_factory=dict)
class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]):
    """
    Abstract transport layer that moves model weights from a trainer into
    inference workers.

    Keeping the transport separate from the worker implementation lets
    different backends (NCCL, CUDA IPC[TODO], RDMA[TODO]) be plugged in.

    Concrete engines must set two class attributes:
        init_info_cls: dataclass type holding backend-specific init parameters
        update_info_cls: dataclass type holding backend-specific update
            parameters
    """

    # Bound by each concrete backend.
    init_info_cls: type[TInitInfo]
    update_info_cls: type[TUpdateInfo]

    def __init__(
        self, config: WeightTransferConfig, parallel_config: ParallelConfig
    ) -> None:
        """
        Store the configuration objects the engine works from.

        Args:
            config: Weight transfer configuration
            parallel_config: Parallel setup of the inference engine
        """
        self.config = config
        self.parallel_config = parallel_config

    def parse_init_info(self, init_dict: dict[str, Any]) -> TInitInfo:
        """
        Build this backend's typed init info from a plain dict.

        Args:
            init_dict: Backend-specific initialization parameters, passed to
                `init_info_cls` as keyword arguments

        Returns:
            An instance of `init_info_cls`

        Raises:
            ValueError: If constructing `init_info_cls` raises TypeError
                (e.g. missing or unexpected keys)
        """
        try:
            typed_info = self.init_info_cls(**init_dict)
        except TypeError as e:
            raise ValueError(
                f"Invalid init_info for {self.__class__.__name__}: {e}"
            ) from e
        else:
            return typed_info

    def parse_update_info(self, update_dict: dict[str, Any]) -> TUpdateInfo:
        """
        Build this backend's typed update info from a plain dict.

        Args:
            update_dict: Backend-specific update parameters, passed to
                `update_info_cls` as keyword arguments

        Returns:
            An instance of `update_info_cls`

        Raises:
            ValueError: If constructing `update_info_cls` raises TypeError
                (e.g. missing or unexpected keys)
        """
        try:
            typed_info = self.update_info_cls(**update_dict)
        except TypeError as e:
            raise ValueError(
                f"Invalid update_info for {self.__class__.__name__}: {e}"
            ) from e
        else:
            return typed_info

    @abstractmethod
    def init_transfer_engine(self, init_info: TInitInfo) -> None:
        """
        Set up the transport used for weight transfer.

        Intended to be called once, at the beginning of training.

        Args:
            init_info: Backend-specific initialization info
        """
        raise NotImplementedError

    @abstractmethod
    def receive_weights(
        self,
        update_info: TUpdateInfo,
        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
    ) -> None:
        """
        Pull weights from the trainer and hand them to the model as they arrive.

        Args:
            update_info: Backend-specific update info containing parameter
                metadata and any backend-specific data
            load_weights: Callable that loads a list of (name, tensor) pairs
                into the model. Meant to be invoked incrementally so the full
                set of weights never has to be resident at once.
        """
        raise NotImplementedError

    @abstractmethod
    def shutdown(self) -> None:
        """
        Release the engine's resources.

        Intended to be called when the worker is shutting down.
        """
        raise NotImplementedError
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Factory for weight transfer engines with lazy loading."""
import importlib
from collections.abc import Callable
from typing import TYPE_CHECKING
from vllm.distributed.weight_transfer.base import WeightTransferEngine
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.config.parallel import ParallelConfig
from vllm.config.weight_transfer import WeightTransferConfig
logger = init_logger(__name__)
class WeightTransferEngineFactory:
    """Registry-backed factory for weight transfer engines.

    Each registry entry is a zero-argument loader returning the engine class:
    - Lazy loading: an engine module is imported only when its loader runs
    - Extensibility: custom engines can be registered at runtime
    - Centralized registration: built-in engines are registered in one place
    """

    # backend name -> loader that returns the engine class
    _registry: dict[str, Callable[[], type[WeightTransferEngine]]] = {}

    @classmethod
    def register_engine(
        cls,
        name: str,
        module_path_or_cls: str | type[WeightTransferEngine],
        class_name: str | None = None,
    ) -> None:
        """Register an engine, either lazily by path or directly by class.

        Two calling conventions are accepted:
            1. Lazy loading: register_engine(name, module_path, class_name)
            2. Direct class: register_engine(name, engine_cls)

        Args:
            name: Key to register the engine under (e.g., "nccl")
            module_path_or_cls: A module path string for lazy loading, or the
                engine class itself
            class_name: Attribute name of the engine class inside the module;
                required when a module path string is given

        Raises:
            ValueError: If `name` is already registered, or a module path is
                given without `class_name`
        """
        if name in cls._registry:
            raise ValueError(f"Weight transfer engine '{name}' is already registered.")

        if not isinstance(module_path_or_cls, str):
            # Direct class registration: the loader just hands the class back.
            engine_cls = module_path_or_cls

            def direct_loader() -> type[WeightTransferEngine]:
                return engine_cls

            cls._registry[name] = direct_loader
            return

        # Lazy loading path: defer the import until the loader is invoked.
        if class_name is None:
            raise ValueError(
                "class_name is required when registering with module path"
            )
        module_path = module_path_or_cls

        def lazy_loader() -> type[WeightTransferEngine]:
            return getattr(importlib.import_module(module_path), class_name)

        cls._registry[name] = lazy_loader

    @classmethod
    def create_engine(
        cls,
        config: "WeightTransferConfig",
        parallel_config: "ParallelConfig",
    ) -> WeightTransferEngine:
        """Instantiate the engine registered for `config.backend`.

        Args:
            config: Weight transfer configuration containing the backend name
            parallel_config: Parallel configuration passed to the engine

        Returns:
            A newly constructed weight transfer engine

        Raises:
            ValueError: If the backend is not registered
        """
        backend = config.backend
        loader = cls._registry.get(backend)
        if loader is None:
            available = list(cls._registry.keys())
            raise ValueError(
                f"Invalid weight transfer backend: {backend}. "
                f"Available engines: {available}"
            )

        # Running the loader is what triggers the import for lazy entries.
        engine_cls = loader()
        logger.info(
            "Creating weight transfer engine: %s",
            engine_cls.__name__,
        )
        return engine_cls(config, parallel_config)
# Register built-in weight transfer engines here.
# Registration should be centralized to ensure lazy loading -
# engine modules are only imported when actually used.
# The "nccl" entry is registered by module path plus class name, so the engine
# module is imported by the factory's loader rather than at this point.
WeightTransferEngineFactory.register_engine(
    "nccl",
    "vllm.distributed.weight_transfer.nccl_engine",
    "NCCLWeightTransferEngine",
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""NCCL-based weight transfer engine."""
from collections.abc import Callable, Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import torch
if TYPE_CHECKING:
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.config.parallel import ParallelConfig
from vllm.config.weight_transfer import WeightTransferConfig
from vllm.distributed.weight_transfer.base import (
WeightTransferEngine,
WeightTransferInitInfo,
WeightTransferUpdateInfo,
)
from vllm.distributed.weight_transfer.packed_tensor import (
DEFAULT_PACKED_BUFFER_SIZE_BYTES,
DEFAULT_PACKED_NUM_BUFFERS,
packed_broadcast_consumer,
)
@dataclass
class NCCLWeightTransferInitInfo(WeightTransferInitInfo):
    """Initialization info for NCCL weight transfer backend.

    All four fields are required (no defaults).
    """

    # Host and port handed to StatelessProcessGroup.create when the
    # trainer/worker process group is built.
    master_address: str
    master_port: int
    # Added to the worker's own rank to obtain its rank in the
    # trainer/worker group (trainer_init always uses rank 0 for the trainer).
    rank_offset: int
    # Total number of ranks in the trainer/worker group.
    world_size: int
@dataclass
class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo):
    """Update info for NCCL weight transfer backend.

    Describes the tensors that are about to be broadcast: `names`,
    `dtype_names` and `shapes` are parallel lists with one entry per tensor.
    The metadata is validated on construction so that malformed input is
    rejected before any broadcast is issued.
    """

    # Parameter names, one per tensor.
    names: list[str]
    # Names of `torch` dtype attributes (e.g. "bfloat16"), one per tensor.
    dtype_names: list[str]
    # Tensor shapes, one per tensor.
    shapes: list[list[int]]
    packed: bool = False
    """Whether to use packed tensor broadcasting for efficiency.
    When True, multiple tensors are batched together before broadcasting
    to reduce NCCL communication overhead."""
    packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
    """Size in bytes for each packed tensor buffer. Default is 1GB.
    Both producer and consumer must use the same value."""
    packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS
    """Number of buffers for double/triple buffering during packed transfer.
    Both producer and consumer must use the same value."""

    def __post_init__(self):
        """Validate the per-parameter metadata.

        Raises:
            ValueError: If `dtype_names` or `shapes` differs in length from
                `names`, or if an entry of `dtype_names` is not the name of a
                `torch.dtype`.
        """
        num_params = len(self.names)
        if len(self.dtype_names) != num_params:
            raise ValueError(
                f"`dtype_names` should be of the same size as `names`: "
                f"got {len(self.dtype_names)} and {len(self.names)}"
            )
        if len(self.shapes) != num_params:
            raise ValueError(
                f"`shapes` should be of the same size as `names`: "
                f"got {len(self.shapes)} and {len(self.names)}"
            )
        # The receiver resolves each dtype with getattr(torch, dtype_name)
        # while it is already looping over the weights, so a bad name would
        # otherwise surface only partway through a transfer. Reject it here.
        for name, dtype_name in zip(self.names, self.dtype_names):
            resolved = (
                getattr(torch, dtype_name, None)
                if isinstance(dtype_name, str)
                else None
            )
            if not isinstance(resolved, torch.dtype):
                raise ValueError(
                    f"Invalid dtype name {dtype_name!r} for parameter {name!r}: "
                    "expected the name of a `torch.dtype` attribute "
                    "(e.g. 'bfloat16')"
                )
class NCCLWeightTransferEngine(
    WeightTransferEngine[NCCLWeightTransferInitInfo, NCCLWeightTransferUpdateInfo]
):
    """
    NCCL-backed weight transfer between a trainer and inference workers.

    Weights travel via NCCL broadcast from the trainer (rank 0) to every
    inference worker in a dedicated process group.
    """

    # Backend-specific dataclass types used by parse_init_info/parse_update_info
    init_info_cls = NCCLWeightTransferInitInfo
    update_info_cls = NCCLWeightTransferUpdateInfo

    def __init__(
        self, config: WeightTransferConfig, parallel_config: ParallelConfig
    ) -> None:
        """
        Create the engine; the communicator is built later.

        Args:
            config: The configuration for the weight transfer engine
            parallel_config: The configuration for the parallel setup
        """
        super().__init__(config, parallel_config)
        # Stays None until init_transfer_engine() has run.
        self.model_update_group: PyNcclCommunicator | None = None

    def init_transfer_engine(self, init_info: NCCLWeightTransferInitInfo) -> None:
        """
        Join the NCCL process group shared with the trainer.

        Args:
            init_info: NCCL initialization info containing master address, port,
                rank offset, and world size
        """
        parallel = self.parallel_config
        # Ranks must be unique across data-parallel replicas, so each replica
        # is shifted by the size of one replica (TP * PP).
        dp_rank = parallel.data_parallel_rank
        ranks_per_replica = parallel.world_size
        local_rank = parallel.rank
        rank = dp_rank * ranks_per_replica + local_rank + init_info.rank_offset

        # Stateless group: independent of torch.distributed's global group.
        communicator = NCCLWeightTransferEngine._stateless_init_process_group(
            init_info.master_address,
            init_info.master_port,
            rank,
            init_info.world_size,
            torch.cuda.current_device(),
        )
        self.model_update_group = communicator

    def receive_weights(
        self,
        update_info: NCCLWeightTransferUpdateInfo,
        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
    ) -> None:
        """
        Receive broadcast weights from rank 0 and load them as they arrive.

        With `update_info.packed` set, tensors are received in packed batches;
        otherwise each tensor is received with its own broadcast.

        Args:
            update_info: NCCL update info containing parameter names, dtypes,
                shapes, and packed flag
            load_weights: Callable that loads weights into the model. Invoked
                once per tensor (unpacked) or once per batch (packed).

        Raises:
            RuntimeError: If init_transfer_engine() has not been called
        """
        group = self.model_update_group
        if group is None:
            raise RuntimeError(
                "NCCL weight transfer not initialized. "
                "Call init_transfer_engine() first."
            )

        specs = zip(update_info.names, update_info.dtype_names, update_info.shapes)

        if update_info.packed:
            # Lazily yields (name, (shape, dtype)) for the packed consumer.
            metadata = (
                (name, (shape, getattr(torch, dtype_name)))
                for name, dtype_name, shape in specs
            )
            packed_broadcast_consumer(
                iterator=metadata,
                group=group,
                src=0,
                post_unpack_func=load_weights,
                buffer_size_bytes=update_info.packed_buffer_size_bytes,
                num_buffers=update_info.packed_num_buffers,
            )
            return

        # Simple path: one receive buffer and one broadcast per tensor.
        for name, dtype_name, shape in specs:
            weight = torch.empty(
                shape, dtype=getattr(torch, dtype_name), device="cuda"
            )
            group.broadcast(weight, src=0, stream=torch.cuda.current_stream())
            load_weights([(name, weight)])
            # Drop the local reference before allocating the next buffer.
            del weight

    def shutdown(self) -> None:
        """Drop the reference to the communicator, if one was created."""
        if self.model_update_group is None:
            return
        self.model_update_group = None

    @staticmethod
    def trainer_send_weights(
        iterator: Iterator[tuple[str, torch.Tensor]],
        group: Any,
        src: int = 0,
        post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor]
        | None = None,
        packed: bool = False,
        stream: torch.cuda.Stream | None = None,
        packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
        packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
    ) -> None:
        """Broadcast weights from trainer to vLLM workers.

        Args:
            iterator: Iterator of model parameters. Returns (name, tensor) tuples
            group: Process group (PyNcclCommunicator)
            src: Source rank (default 0, trainer is typically rank 0)
            post_iter_func: Optional function to apply to each (name, tensor)
                pair before broadcasting. If None, extracts just the tensor.
            packed: Whether to use packed tensor broadcasting for efficiency.
                When True, multiple tensors are batched together before
                broadcasting to reduce NCCL communication overhead.
            stream: CUDA stream to use for broadcasting if packed is False;
                the current stream is used when it is not given.
                If packed is True, new streams will be created for each buffer.
            packed_buffer_size_bytes: Size in bytes for each packed tensor buffer.
                Must match the value used in NCCLWeightTransferUpdateInfo.
            packed_num_buffers: Number of buffers for double/triple buffering.
                Must match the value used in NCCLWeightTransferUpdateInfo.

        Example:
            >>> from vllm.distributed.weight_transfer.nccl_engine import (
            ...     NCCLWeightTransferEngine,
            ... )
            >>> param_iter = ((n, p) for n, p in model.named_parameters())
            >>> NCCLWeightTransferEngine.trainer_send_weights(
            ...     param_iter, group, packed=True
            ... )
        """
        # Default transform: keep only the tensor of each (name, tensor) pair.
        to_tensor = (
            post_iter_func if post_iter_func is not None else (lambda pair: pair[1])
        )

        if not packed:
            # Simple path: one broadcast per tensor.
            for item in iterator:
                tensor = to_tensor(item)
                group.broadcast(
                    tensor, src=src, stream=stream or torch.cuda.current_stream()
                )
            return

        # Packed path: batch tensors into large buffers before broadcasting.
        from vllm.distributed.weight_transfer.packed_tensor import (
            packed_broadcast_producer,
        )

        packed_broadcast_producer(
            iterator=iterator,
            group=group,
            src=src,
            post_iter_func=to_tensor,
            buffer_size_bytes=packed_buffer_size_bytes,
            num_buffers=packed_num_buffers,
        )

    @staticmethod
    def trainer_init(
        init_info: NCCLWeightTransferInitInfo | dict,
    ) -> "PyNcclCommunicator":
        """
        Create the trainer-side NCCL communicator for weight transfer.

        The trainer is always rank 0 in the process group. Uses the current
        CUDA device (torch.cuda.current_device()).

        Args:
            init_info: Either an NCCLWeightTransferInitInfo object or a dict
                with keys:
                - master_address: str
                - master_port: int
                - world_size: int

        Returns:
            PyNcclCommunicator for weight transfer.

        Example:
            >>> from vllm.distributed.weight_transfer.nccl_engine import (
            ...     NCCLWeightTransferEngine,
            ... )
            >>> group = NCCLWeightTransferEngine.trainer_init(
            ...     dict(
            ...         master_address=master_address,
            ...         master_port=master_port,
            ...         world_size=world_size,
            ...     ),
            ... )
        """
        if isinstance(init_info, dict):
            master_address, master_port, world_size = (
                init_info[key]
                for key in ("master_address", "master_port", "world_size")
            )
        else:
            # NCCLWeightTransferInitInfo object
            master_address, master_port, world_size = (
                init_info.master_address,
                init_info.master_port,
                init_info.world_size,
            )

        trainer_rank = 0  # the trainer is always rank 0
        return NCCLWeightTransferEngine._stateless_init_process_group(
            master_address,
            master_port,
            trainer_rank,
            world_size,
            torch.cuda.current_device(),
        )

    @staticmethod
    def _stateless_init_process_group(
        master_address, master_port, rank, world_size, device
    ):
        """
        Build a PyNcclCommunicator on top of a `StatelessProcessGroup`.

        vLLM provides `StatelessProcessGroup` to create a process group
        without considering the global process group in torch.distributed.
        It is recommended to create `StatelessProcessGroup`, and then initialize
        the data-plane communication (NCCL) between external (train processes)
        and vLLM workers.
        """
        # Imported lazily, only when a group is actually created.
        from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
        from vllm.distributed.utils import StatelessProcessGroup

        process_group = StatelessProcessGroup.create(
            host=master_address, port=master_port, rank=rank, world_size=world_size
        )
        return PyNcclCommunicator(process_group, device=device)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Packed tensor utilities for efficient weight transfer."""
import math
from collections.abc import Callable, Iterator
from typing import Any
import torch
# Default values for packed tensor configuration.
# These are imported by NCCLWeightTransferUpdateInfo and trainer_send_weights.
# They are also the keyword defaults of the producer and consumer in this
# module, whose docstrings require both sides to use the same values.
DEFAULT_PACKED_BUFFER_SIZE_BYTES = 1024 * 1024 * 1024  # 1GB
DEFAULT_PACKED_NUM_BUFFERS = 2  # two slots, i.e. double buffering
def packed_broadcast_producer(
    iterator: Iterator[tuple[str, torch.Tensor]],
    group: Any,
    src: int,
    post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor],
    buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
    num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
) -> None:
    """Broadcast tensors in a packed manner from trainer to workers.

    Tensors drawn from `iterator` are flattened to uint8, concatenated into
    one tensor per batch, and sent with a single `group.broadcast` call per
    batch. Batches rotate through `num_buffers` slots, each paired with its
    own CUDA stream. Nothing is broadcast if `iterator` is empty.

    Args:
        iterator: Iterator of model parameters. Returns a tuple of (name, tensor)
        group: Process group (PyNcclCommunicator)
        src: Source rank (0 in current implementation)
        post_iter_func: Function to apply to each (name, tensor) pair before
                       packing, should return a tensor
        buffer_size_bytes: Size in bytes for each packed tensor buffer.
                          Both producer and consumer must use the same value.
        num_buffers: Number of buffers for double/triple buffering.
                    Both producer and consumer must use the same value.
    """
    target_packed_tensor_size = buffer_size_bytes

    # One stream and one set of bookkeeping entries per buffer slot.
    streams = [torch.cuda.Stream() for _ in range(num_buffers)]
    buffer_idx = 0

    packing_tensor_list: list[list[torch.Tensor]] = [[] for _ in range(num_buffers)]
    packing_tensor_sizes: list[int] = [0 for _ in range(num_buffers)]
    packed_tensors: list[torch.Tensor] = [
        torch.empty(0, dtype=torch.uint8, device="cuda") for _ in range(num_buffers)
    ]

    while True:
        # Synchronize the current stream: wait for whatever was last queued on
        # this slot's stream before its bookkeeping is overwritten below.
        streams[buffer_idx].synchronize()
        # Start tasks for the new buffer in a new stream
        with torch.cuda.stream(streams[buffer_idx]):
            # NOTE(review): this try also encloses post_iter_func() and
            # group.broadcast(); a StopIteration escaping from either would be
            # handled as end-of-input by the except clause below.
            try:
                # Initialize the packing tensor list and sizes
                packing_tensor_list[buffer_idx] = []
                packing_tensor_sizes[buffer_idx] = 0
                # Pack the tensors
                while True:
                    # Apply post processing and convert to linearized uint8 tensor
                    tensor = (
                        post_iter_func(next(iterator))
                        .contiguous()
                        .view(torch.uint8)
                        .view(-1)
                    )
                    packing_tensor_list[buffer_idx].append(tensor)
                    # numel() of the uint8 view, i.e. the tensor's size in bytes
                    packing_tensor_sizes[buffer_idx] += tensor.numel()
                    # The batch closes only once it exceeds the target, so the
                    # tensor that crosses the threshold is still part of it.
                    if packing_tensor_sizes[buffer_idx] > target_packed_tensor_size:
                        break
                # Pack the tensors and call broadcast collective
                packed_tensors[buffer_idx] = torch.cat(
                    packing_tensor_list[buffer_idx], dim=0
                )
                # No explicit stream is passed; presumably the communicator
                # uses the stream made current above — TODO confirm.
                group.broadcast(packed_tensors[buffer_idx], src=src)
                # Move to the next buffer
                buffer_idx = (buffer_idx + 1) % num_buffers
            except StopIteration:
                # Do the last broadcast if there are remaining tensors
                if len(packing_tensor_list[buffer_idx]) > 0:
                    packed_tensors[buffer_idx] = torch.cat(
                        packing_tensor_list[buffer_idx], dim=0
                    )
                    group.broadcast(packed_tensors[buffer_idx], src=src)
                # NOTE(review): the function returns here without a final
                # synchronize on the streams — confirm callers do not rely on
                # the broadcasts having completed.
                break
def packed_broadcast_consumer(
    iterator: Iterator[tuple[str, tuple[list[int], torch.dtype]]],
    group: Any,
    src: int,
    post_unpack_func: Callable[[list[tuple[str, torch.Tensor]]], None],
    buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
    num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
) -> None:
    """Consume packed tensors and unpack them into a list of tensors.

    Metadata drawn from `iterator` is grouped into batches by byte size. For
    each batch a uint8 receive buffer of the batch's total size is allocated,
    filled by one `group.broadcast` call, split back into named tensors and
    handed to `post_unpack_func`. Batches rotate through `num_buffers` slots,
    each paired with its own CUDA stream. Nothing is received if `iterator` is
    empty.

    Args:
        iterator: Iterator of parameter metadata. Returns (name, (shape, dtype))
        group: Process group (PyNcclCommunicator)
        src: Source rank (0 in current implementation)
        post_unpack_func: Function to apply to each list of (name, tensor) after
                         unpacking
        buffer_size_bytes: Size in bytes for each packed tensor buffer.
                          Both producer and consumer must use the same value.
        num_buffers: Number of buffers for double/triple buffering.
                    Both producer and consumer must use the same value.
    """

    def unpack_tensor(
        packed_tensor: torch.Tensor,
        names: list[str],
        shapes: list[list[int]],
        dtypes: list[torch.dtype],
        tensor_sizes: list[int],
    ) -> list[tuple[str, torch.Tensor]]:
        """Unpack a single tensor into a list of tensors.

        Args:
            packed_tensor: The packed torch.uint8 tensor to unpack
            names: List of tensor names
            shapes: List of tensor shapes
            dtypes: List of tensor dtypes
            tensor_sizes: List of tensor sizes in bytes

        Returns:
            unpacked List[(name, tensor)]
        """
        # Cut the flat byte buffer into per-tensor byte ranges.
        unpacked_tensors = packed_tensor.split(tensor_sizes)

        # Reinterpret each byte range with its dtype, then restore its shape.
        # NOTE(review): view(dtype) on a slice presumably requires the slice's
        # byte offset to suit the target dtype — confirm that batches mixing
        # dtypes of different item sizes are safe.
        unpacked_list = [
            (name, tensor.contiguous().view(dtype).view(*shape))
            for name, shape, dtype, tensor in zip(
                names, shapes, dtypes, unpacked_tensors
            )
        ]

        return unpacked_list

    target_packed_tensor_size = buffer_size_bytes

    # One stream and one set of bookkeeping entries per buffer slot.
    streams = [torch.cuda.Stream() for _ in range(num_buffers)]
    buffer_idx = 0

    # Per slot: (name, shape, dtype, size in bytes) for each tensor in the batch
    packing_tensor_meta_data: list[list[tuple[str, list[int], torch.dtype, int]]] = [
        [] for _ in range(num_buffers)
    ]
    packing_tensor_sizes: list[int] = [0 for _ in range(num_buffers)]
    packed_tensors: list[torch.Tensor] = [
        torch.empty(0, dtype=torch.uint8, device="cuda") for _ in range(num_buffers)
    ]

    while True:
        # Synchronize the current stream: wait for whatever was last queued on
        # this slot's stream before its bookkeeping is overwritten below.
        streams[buffer_idx].synchronize()
        with torch.cuda.stream(streams[buffer_idx]):
            # Initialize the packing tensor meta data
            packing_tensor_meta_data[buffer_idx] = []
            packing_tensor_sizes[buffer_idx] = 0
            # NOTE(review): this try also encloses group.broadcast() and
            # post_unpack_func(); a StopIteration escaping from either would be
            # handled as end-of-input by the except clause below.
            try:
                # Form a packed tensor
                while True:
                    name, (shape, dtype) = next(iterator)
                    # Size in bytes: element count times bytes per element
                    tensor_size = math.prod(shape) * dtype.itemsize
                    packing_tensor_meta_data[buffer_idx].append(
                        (name, shape, dtype, tensor_size)
                    )
                    packing_tensor_sizes[buffer_idx] += tensor_size
                    # Same closing rule as the producer: the tensor that
                    # pushes the batch past the target is still part of it.
                    if packing_tensor_sizes[buffer_idx] > target_packed_tensor_size:
                        break
                # Create a packed tensor and broadcast it
                packed_tensors[buffer_idx] = torch.empty(
                    packing_tensor_sizes[buffer_idx], dtype=torch.uint8, device="cuda"
                )
                # No explicit stream is passed; presumably the communicator
                # uses the stream made current above — TODO confirm.
                group.broadcast(packed_tensors[buffer_idx], src=src)
                # Load the packed tensor into the model
                names, shapes, dtypes, tensor_sizes = zip(
                    *packing_tensor_meta_data[buffer_idx]
                )
                post_unpack_func(
                    unpack_tensor(
                        packed_tensors[buffer_idx],
                        list(names),
                        list(shapes),
                        list(dtypes),
                        list(tensor_sizes),
                    )
                )
                # Move to the next buffer
                buffer_idx = (buffer_idx + 1) % num_buffers
            except StopIteration:
                # Do the last broadcast if there are remaining tensors
                if len(packing_tensor_meta_data[buffer_idx]) > 0:
                    # Create a packed tensor and broadcast it
                    packed_tensors[buffer_idx] = torch.empty(
                        packing_tensor_sizes[buffer_idx],
                        dtype=torch.uint8,
                        device="cuda",
                    )
                    group.broadcast(packed_tensors[buffer_idx], src=src)
                    # Load the packed tensor into the model
                    names, shapes, dtypes, tensor_sizes = zip(
                        *packing_tensor_meta_data[buffer_idx]
                    )
                    post_unpack_func(
                        unpack_tensor(
                            packed_tensors[buffer_idx],
                            list(names),
                            list(shapes),
                            list(dtypes),
                            list(tensor_sizes),
                        )
                    )
                # NOTE(review): the function returns here without a final
                # synchronize on the streams — confirm callers do not rely on
                # the loads having completed.
                break
...@@ -54,6 +54,7 @@ from vllm.config import ( ...@@ -54,6 +54,7 @@ from vllm.config import (
SpeculativeConfig, SpeculativeConfig,
StructuredOutputsConfig, StructuredOutputsConfig,
VllmConfig, VllmConfig,
WeightTransferConfig,
get_attr_docs, get_attr_docs,
) )
from vllm.config.cache import ( from vllm.config.cache import (
...@@ -581,6 +582,11 @@ class EngineArgs: ...@@ -581,6 +582,11 @@ class EngineArgs:
kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend
tokens_only: bool = False tokens_only: bool = False
weight_transfer_config: WeightTransferConfig | None = None
"""Configuration for weight transfer during RL training.
Accepts a JSON string or dict with backend-specific options.
Example: '{"backend": "nccl"}'"""
def __post_init__(self): def __post_init__(self):
# support `EngineArgs(compilation_config={...})` # support `EngineArgs(compilation_config={...})`
# without having to manually construct a # without having to manually construct a
...@@ -591,6 +597,10 @@ class EngineArgs: ...@@ -591,6 +597,10 @@ class EngineArgs:
self.attention_config = AttentionConfig(**self.attention_config) self.attention_config = AttentionConfig(**self.attention_config)
if isinstance(self.eplb_config, dict): if isinstance(self.eplb_config, dict):
self.eplb_config = EPLBConfig(**self.eplb_config) self.eplb_config = EPLBConfig(**self.eplb_config)
if isinstance(self.weight_transfer_config, dict):
self.weight_transfer_config = WeightTransferConfig(
**self.weight_transfer_config
)
# Setup plugins # Setup plugins
from vllm.plugins import load_general_plugins from vllm.plugins import load_general_plugins
...@@ -1189,6 +1199,9 @@ class EngineArgs: ...@@ -1189,6 +1199,9 @@ class EngineArgs:
vllm_group.add_argument( vllm_group.add_argument(
"--optimization-level", **vllm_kwargs["optimization_level"] "--optimization-level", **vllm_kwargs["optimization_level"]
) )
vllm_group.add_argument(
"--weight-transfer-config", **vllm_kwargs["weight_transfer_config"]
)
# Other arguments # Other arguments
parser.add_argument( parser.add_argument(
...@@ -1765,6 +1778,7 @@ class EngineArgs: ...@@ -1765,6 +1778,7 @@ class EngineArgs:
profiler_config=self.profiler_config, profiler_config=self.profiler_config,
additional_config=self.additional_config, additional_config=self.additional_config,
optimization_level=self.optimization_level, optimization_level=self.optimization_level,
weight_transfer_config=self.weight_transfer_config,
) )
return config return config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment