Commit 3fb4b5fa authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.18.0' into v0.18.0-ori

parents bcf25339 89138b21
......@@ -5,6 +5,7 @@ from argparse import Namespace
from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.print_utils import print_embeddings
def parse_args():
......@@ -39,10 +40,8 @@ def main(args: Namespace):
print("\nGenerated Outputs:\n" + "-" * 60)
for prompt, output in zip(prompts, outputs):
embeds = output.outputs.embedding
embeds_trimmed = (
(str(embeds[:16])[:-1] + ", ...]") if len(embeds) > 16 else embeds
)
print(f"Prompt: {prompt!r} \nEmbeddings: {embeds_trimmed} (size={len(embeds)})")
print(f"Prompt: {prompt!r}")
print_embeddings(embeds)
print("-" * 60)
......
......@@ -5,6 +5,7 @@ from argparse import Namespace
from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.print_utils import print_embeddings
def parse_args():
......@@ -41,10 +42,8 @@ def main(args: Namespace):
print("\nGenerated Outputs:\n" + "-" * 60)
for prompt, output in zip(prompts, outputs):
rewards = output.outputs.data
rewards_trimmed = (
(str(rewards[:16])[:-1] + ", ...]") if len(rewards) > 16 else rewards
)
print(f"Prompt: {prompt!r} \nReward: {rewards_trimmed} (size={len(rewards)})")
print(f"Prompt: {prompt!r}")
print_embeddings(rewards, prefix="Reward")
print("-" * 60)
......
......@@ -201,6 +201,34 @@ def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
)
# Kimi-Audio-7B-Instruct
def run_kimi_audio(question: str, audio_count: int) -> ModelRequestData:
"""Kimi-Audio-7B-Instruct for audio transcription and understanding."""
model_name = "moonshotai/Kimi-Audio-7B-Instruct"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=2,
limit_mm_per_prompt={"audio": audio_count},
)
# Kimi-Audio uses <|im_kimia_text_blank|> as placeholder for audio features
audio_placeholder = "<|im_kimia_text_blank|>" * audio_count
# Default prompt for transcription
if not question:
question = "Please transcribe the audio"
prompt = f"{audio_placeholder}{question}"
# Stop at EOS token (151644) to prevent repetition
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
stop_token_ids=[151644],
)
# MiDashengLM
def run_midashenglm(question: str, audio_count: int):
model_name = "mispeech/midashenglm-7b"
......@@ -485,6 +513,7 @@ model_example_map = {
"glmasr": run_glmasr,
"funaudiochat": run_funaudiochat,
"granite_speech": run_granite_speech,
"kimi_audio": run_kimi_audio,
"midashenglm": run_midashenglm,
"minicpmo": run_minicpmo,
"phi4_mm": run_phi4mm,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from safetensors import safe_open
from vllm import LLM, SamplingParams
# Example: Using the custom "extract_hidden_states" speculator method and
# ExampleHiddenStatesConnector to extract and save hidden states from vllm
with tempfile.TemporaryDirectory() as tmpdirname:
llm = LLM(
model="Qwen/Qwen3-8B", # Your target model
speculative_config={
"method": "extract_hidden_states",
"num_speculative_tokens": 1,
"draft_model_config": {
"hf_config": {
"eagle_aux_hidden_state_layer_ids": [ # Target model layer indices
1,
2,
3,
4,
],
}
},
},
kv_transfer_config={
"kv_connector": "ExampleHiddenStatesConnector",
"kv_role": "kv_producer",
"kv_connector_extra_config": {
"shared_storage_path": tmpdirname,
},
},
)
prompts = ["Generate a sentence with hidden states", "Write a python function"]
sampling_params = SamplingParams(max_tokens=1)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
print("\nPrompt:", output.prompt)
print("Prompt token ids:", output.prompt_token_ids)
hidden_states_path = output.kv_transfer_params.get("hidden_states_path")
assert hidden_states_path is not None
print("Prompt hidden states path:", hidden_states_path)
with safe_open(hidden_states_path, "pt") as f:
token_ids = f.get_tensor("token_ids")
hidden_states = f.get_tensor("hidden_states")
print("Extracted token ids:", token_ids) # Matches prompt token ids
print(
"Extracted hidden states shape:", hidden_states.shape
) # [num_hidden_layers, prompt len, hidden size]
print("Extracted hidden states:", hidden_states)
......@@ -28,3 +28,4 @@ It demonstrates vLLM's ability to recover from KV load failures in both synchron
```bash
./run.sh
```
......@@ -42,6 +42,7 @@ def main():
"async_load": args.async_load,
},
kv_connector_module_path="load_recovery_example_connector",
kv_load_failure_policy="recompute",
)
out_file = (
"async_decode_recovered_output.txt"
......
# Custom Logits Processors
This directory contains examples demonstrating how to use custom logits processors with vLLM's offline inference API. Logits processors allow you to modify the model's output distribution before sampling, enabling controlled generation behaviors like token masking, constrained decoding, and custom sampling strategies.
## Scripts
### `custom.py` — Engine-level logits processor
Demonstrates how to instantiate vLLM with a custom logits processor class that operates at the batch level. The example uses a `DummyLogitsProcessor` that masks out all tokens except a specified `target_token` when passed via `SamplingParams.extra_args`.
```bash
python examples/offline_inference/logits_processor/custom.py
```
### `custom_req.py` — Request-level logits processor wrapper
Shows how to wrap a request-level logits processor (which operates on individual requests) to be compatible with vLLM's batch-level logits processing interface.
```bash
python examples/offline_inference/logits_processor/custom_req.py
```
### `custom_req_init.py` — Request-level processor with engine config
A special case of wrapping a request-level logits processor where the processor needs access to engine configuration or model metadata during initialization (e.g., vocabulary size, tokenizer info).
```bash
python examples/offline_inference/logits_processor/custom_req_init.py
```
## Key Concepts
- **Batch-level vs. request-level**: vLLM processes logits at the batch level for efficiency. If you have a per-request processor, you need to wrap it using the patterns shown in `custom_req.py` and `custom_req_init.py`.
- **`SamplingParams.extra_args`**: Use this to pass custom keyword arguments to your logits processor on a per-request basis (e.g., `target_token`).
- **`DummyLogitsProcessor`**: A reference implementation available in `vllm/test_utils.py` that can be used as a starting point for custom processors.
## Further Reading
- [vLLM Sampling Parameters](https://docs.vllm.ai/en/latest/api/inference_params.html#sampling-parameters)
- [vLLM LLM API](https://docs.vllm.ai/en/latest/api/offline_inference/llm.html)
......@@ -120,7 +120,7 @@ def main():
# Clean up the GPU memory for the next test
del engine
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
if __name__ == "__main__":
......
......@@ -7,6 +7,7 @@ import argparse
from vllm import LLM
from vllm.sampling_params import SamplingParams
from vllm.assets.image import ImageAsset
from vllm.multimodal.utils import encode_image_url
# This script is an offline demo for running Mistral-Small-3.1
#
......@@ -18,11 +19,11 @@ from vllm.assets.image import ImageAsset
# # Mistral format
# vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \
# --tokenizer-mode mistral --config-format mistral --load-format mistral \
# --limit-mm-per-prompt '{"image":4}' --max-model-len 16384
# --limit-mm-per-prompt.image 4 --max-model-len 16384
#
# # HF format
# vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \
# --limit-mm-per-prompt '{"image":4}' --max-model-len 16384
# --limit-mm-per-prompt.image 4 --max-model-len 16384
# ```
#
# - Client:
......@@ -61,9 +62,9 @@ def run_simple_demo(args: argparse.Namespace):
llm = LLM(
model=model_name,
tokenizer_mode="mistral" if args.format == "mistral" else "auto",
config_format="mistral" if args.format == "mistral" else "auto",
load_format="mistral" if args.format == "mistral" else "auto",
tokenizer_mode="mistral" if args.format == "mistral" else "hf",
config_format="mistral" if args.format == "mistral" else "hf",
load_format="mistral" if args.format == "mistral" else "hf",
limit_mm_per_prompt={"image": 1},
max_model_len=4096,
max_num_seqs=2,
......@@ -79,8 +80,10 @@ def run_simple_demo(args: argparse.Namespace):
"content": [
{"type": "text", "text": prompt},
{
"type": "image_pil",
"image_pil": ImageAsset("cherry_blossom").pil_image,
"type": "image_url",
"image_url": {
"url": encode_image_url(ImageAsset("cherry_blossom").pil_image)
},
},
],
},
......@@ -99,9 +102,9 @@ def run_advanced_demo(args: argparse.Namespace):
sampling_params = SamplingParams(max_tokens=8192, temperature=0.7)
llm = LLM(
model=model_name,
tokenizer_mode="mistral" if args.format == "mistral" else "auto",
config_format="mistral" if args.format == "mistral" else "auto",
load_format="mistral" if args.format == "mistral" else "auto",
tokenizer_mode="mistral" if args.format == "mistral" else "hf",
config_format="mistral" if args.format == "mistral" else "hf",
load_format="mistral" if args.format == "mistral" else "hf",
limit_mm_per_prompt={"image": max_img_per_msg},
max_model_len=max_img_per_msg * max_tokens_per_img,
tensor_parallel_size=2,
......
......@@ -26,14 +26,12 @@ workloads. Residual GPU activity interferes with vLLM memory profiling and
causes unexpected behavior.
"""
import os
import asyncio
import uuid
from dataclasses import asdict
import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM, AutoTokenizer
import vllm
......@@ -44,21 +42,25 @@ from vllm.distributed.weight_transfer.base import (
WeightTransferUpdateRequest,
)
from vllm.distributed.weight_transfer.nccl_engine import (
NCCLTrainerSendWeightsArgs,
NCCLWeightTransferEngine,
NCCLWeightTransferInitInfo,
NCCLWeightTransferUpdateInfo,
)
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_ip, get_open_port
from vllm.v1.executor import Executor
MODEL_NAME = "facebook/opt-125m"
MODEL_NAME_V1 = "Qwen/Qwen3-1.7B-Base"
MODEL_NAME_V2 = "Qwen/Qwen3-1.7B"
PAUSE_TOKEN_THRESHOLD = 10
ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "FLASH_ATTN"
class MyLLM(vllm.AsyncLLMEngine):
"""Configure the vLLM worker for Ray placement group execution."""
def __init__(self, **kwargs):
os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0,1"
engine_args = vllm.AsyncEngineArgs(**kwargs)
vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config)
......@@ -68,26 +70,44 @@ class MyLLM(vllm.AsyncLLMEngine):
log_requests=engine_args.enable_log_requests,
log_stats=not engine_args.disable_log_stats,
)
self._generation_paused = False
self._request_pause_flag = False
async def generate_with_retry(
async def do_generate(
self, prompt_token_ids: list[int], sampling_params: vllm.SamplingParams
) -> vllm.RequestOutput:
finish_reason = "abort"
while finish_reason == "abort":
async for request_output in self.generate(
{"prompt_token_ids": prompt_token_ids},
sampling_params,
request_id=str(uuid.uuid4()),
) -> tuple[vllm.RequestOutput, int]:
"""Generate a single request, setting the request pause flag once the
token count reaches the threshold.
Returns (output, pause_token_index). pause_token_index is the number
of tokens generated before the weight change, or -1 if no pause.
"""
pause_token_index = -1
prev_token_count = 0
async for request_output in self.generate(
{"prompt_token_ids": prompt_token_ids},
sampling_params,
request_id=str(uuid.uuid4()),
):
output = request_output
cur_token_count = len(output.outputs[0].token_ids)
if (
cur_token_count >= PAUSE_TOKEN_THRESHOLD
and not self._request_pause_flag
):
output = request_output
finish_reason = output.outputs[0].finish_reason
if finish_reason == "abort":
print(
f"ABORT, prompt_token_ids: {prompt_token_ids}, "
f"generated token_ids: {list(output.outputs[0].token_ids)}"
)
prompt_token_ids = prompt_token_ids + list(output.outputs[0].token_ids)
return output
self._request_pause_flag = True
if self._generation_paused and pause_token_index == -1:
pause_token_index = prev_token_count
prev_token_count = cur_token_count
return output, pause_token_index
async def pause_after_n_tokens(self):
"""Wait for any request to set the pause flag, then pause."""
while not self._request_pause_flag:
await asyncio.sleep(0)
await super().pause_generation(mode="keep")
await asyncio.sleep(5)
self._generation_paused = True
@ray.remote(num_gpus=1)
......@@ -95,6 +115,20 @@ class TrainModel:
"""Ray actor that wraps the training model on a dedicated GPU."""
def __init__(self, model_name: str):
from vllm.model_executor.layers.batch_invariant import (
init_batch_invariance,
)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum
# need to init all env vars for batch invariance which affect nccl ops
attn_backend = (
AttentionBackendEnum.TRITON_ATTN
if current_platform.is_rocm()
else AttentionBackendEnum.FLASH_ATTN
)
init_batch_invariance(attn_backend)
self.model = AutoModelForCausalLM.from_pretrained(
model_name, dtype=torch.bfloat16
).to("cuda:0")
......@@ -127,76 +161,106 @@ class TrainModel:
def broadcast_weights(self, packed: bool = True):
"""Broadcast weights to the inference engine."""
NCCLWeightTransferEngine.trainer_send_weights(
iterator=self.model.named_parameters(),
trainer_args = NCCLTrainerSendWeightsArgs(
group=self.model_update_group,
packed=packed,
)
NCCLWeightTransferEngine.trainer_send_weights(
iterator=self.model.named_parameters(),
trainer_args=trainer_args,
)
@torch.inference_mode()
def generate(self, token_ids: list[int], max_new_tokens: int) -> list[int]:
"""Greedy-decode max_new_tokens from the given context."""
input_ids = torch.tensor([token_ids], device="cuda:0")
output = self.model.generate(
input_ids,
max_new_tokens=max_new_tokens,
do_sample=False,
)
new_token_ids = output[0, len(token_ids) :].tolist()
return new_token_ids
# Build platform-specific env vars for Ray
ray_env_vars = {
# Prevent Ray from setting CUDA_VISIBLE_DEVICES
"RAY_EXPERIMENTAL_NOSET_CUDA_ENV_VAR": "1",
}
# Initialize Ray and set the visible devices. The vLLM engine will
# be placed on GPUs 1 and 2.
ray.init()
if current_platform.is_rocm():
# For ROCm, BATCH_INVARIANT vllm is not supported
ray_env_vars["VLLM_ROCM_USE_SKINNY_GEMM"] = "0"
else:
# Enable batch invariance for deterministic outputs on NVIDIA
ray_env_vars["VLLM_BATCH_INVARIANT"] = "1"
ray.init(runtime_env={"env_vars": ray_env_vars})
# Launch the training model actor. Ray's resource scheduler will allocate
# 1 GPU (via num_gpus=1 in the decorator), ensuring pg_inference gets different GPUs.
train_model = TrainModel.remote(MODEL_NAME)
# Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/placement-groups.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
placement_group=pg_inference,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=0,
train_model = TrainModel.remote(MODEL_NAME_V2)
rocm_determinism_kwargs = {}
if current_platform.is_rocm():
# ROCm: To minimize non-determinism, we set fixed seed, no prefix caching, and
# sequential request processing (max_num_seqs=1).
rocm_determinism_kwargs = {
"seed": 0,
"enable_prefix_caching": False,
"max_num_seqs": 1,
}
# Build platform-specific LLM kwargs
llm_kwargs = dict(
model=MODEL_NAME_V1,
enforce_eager=True,
max_model_len=8192,
distributed_executor_backend="ray",
attention_backend=ATTN_BACKEND,
gpu_memory_utilization=0.75,
weight_transfer_config=WeightTransferConfig(backend="nccl"),
)
llm_kwargs.update(rocm_determinism_kwargs)
# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.
# Note: Weight transfer APIs (init_weight_transfer_engine, update_weights)
# are now native to vLLM workers.
# Launch the vLLM inference engine.
# With data_parallel_backend="ray", vLLM's CoreEngineActorManager creates
# its own placement groups internally for each DP rank, so we must NOT
# create an outer placement group (it would reserve GPUs and hide them
# from the internal DP resource check).
llm = ray.remote(
num_cpus=0,
num_gpus=0,
scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
model=MODEL_NAME,
enforce_eager=True,
tensor_parallel_size=2,
distributed_executor_backend="ray",
load_format="dummy",
weight_transfer_config=WeightTransferConfig(backend="nccl"),
)
)(MyLLM).remote(**llm_kwargs)
# Generate text from the prompts.
prompts = [
"My name is",
PROMPTS = [
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"The largest ocean on Earth is",
"The speed of light in a vacuum is",
"The chemical formula for water is",
"The tallest mountain in the world is",
"The first person to walk on the moon was",
"The Great Wall of China was built to",
"Photosynthesis is the process by which",
"The theory of general relativity was proposed by",
"The boiling point of water at sea level is",
"The largest planet in our solar system is",
"DNA stands for deoxyribonucleic acid and it",
]
# Tokenize prompts to token IDs
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
prompt_token_ids_list = [
tokenizer.encode(prompt, add_special_tokens=False) for prompt in prompts
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_V1)
batch_prompt_token_ids = [
tokenizer.encode(prompt, add_special_tokens=False) for prompt in PROMPTS
]
sampling_params = [
SamplingParams(temperature=0, max_tokens=2),
SamplingParams(temperature=0, max_tokens=32),
SamplingParams(temperature=0, max_tokens=32),
SamplingParams(temperature=0, max_tokens=32),
]
# Set up the communication channel between the training process and the
# inference engine.
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())
world_size = 3 # 1 trainer + 2 inference workers (tensor_parallel_size=2)
world_size = 2 # 1 trainer + 1 inference worker
inference_handle = llm.init_weight_transfer_engine.remote(
WeightTransferInitRequest(
init_info=asdict(
......@@ -215,22 +279,28 @@ train_handle = train_model.init_weight_transfer_group.remote(world_size)
ray.get([train_handle, inference_handle])
generation_futures = [
llm.generate_with_retry.remote(prompt_token_ids, params)
for prompt_token_ids, params in zip(prompt_token_ids_list, sampling_params)
]
N_NEW_TOKENS = 100
finished, pending = ray.wait(generation_futures, num_returns=1)
# Collect weight metadata once
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())
# Pause generation in preparation for weight sync
ray.get(llm.pause_generation.remote(wait_for_inflight_requests=False))
# ── Phase 1: concurrent requests with weight sync ───────────────────
print(f"\n{'=' * 50}")
print(f"Prompts ({len(PROMPTS)}):")
for p in PROMPTS:
print(f" - {p!r}")
print(f"{'=' * 50}")
# Synchronize the updated weights to the inference engine using batched API.
# Collect all weight metadata from the training actor
names, dtype_names, shapes = ray.get(train_model.get_weight_metadata.remote())
sampling_params = SamplingParams(
temperature=0, max_tokens=PAUSE_TOKEN_THRESHOLD + N_NEW_TOKENS
)
gen_futures = [
llm.do_generate.remote(ptids, sampling_params) for ptids in batch_prompt_token_ids
]
ray.get(llm.pause_after_n_tokens.remote())
# Issue update_weights call with NCCL-specific update info
# packed=True enables efficient batched tensor broadcasting
inference_handle = llm.update_weights.remote(
WeightTransferUpdateRequest(
update_info=asdict(
......@@ -243,41 +313,103 @@ inference_handle = llm.update_weights.remote(
)
)
)
# Broadcast all weights from trainer using the weight transfer API
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])
# Resume generation since weight sync is complete
ray.get(llm.resume_generation.remote())
results = ray.get(gen_futures)
for i, (output, pause_idx) in enumerate(results):
all_token_ids = list(output.outputs[0].token_ids)
before_text = tokenizer.decode(all_token_ids[:pause_idx])
after_text = tokenizer.decode(all_token_ids[pause_idx:])
print(f"\n Request {i} ({PROMPTS[i]!r}):")
print(f" Old weights ({pause_idx} tokens): {before_text!r}")
n_after = len(all_token_ids) - pause_idx
print(f" New weights ({n_after} tokens): {after_text!r}")
# ── Phase 2: validate with a fresh V2 vLLM instance ────────────────
# This validation relies on batch-invariant (deterministic) generation to
# compare outputs from the weight-synced engine against a fresh V2 instance.
# On NVIDIA, batch invariance is fully supported, so we require 100% exact
# token match. On ROCm, batch invariance is not yet fully implemented
# (see https://github.com/vllm-project/vllm/issues/27433 and
# https://github.com/vllm-project/vllm/issues/33123), so residual
# non-determinism (e.g. GEMM accumulation order, missing kernel overrides)
# can cause single-token divergences that don't indicate a weight-sync
# failure. We relax the pass rate to 90% on ROCm to accommodate this; a
# real regression (broken weight transfer) would cause ~0% pass rate, not 90%+.
MIN_PASS_RATE = 1.0 if not current_platform.is_rocm() else 0.9
print(f"\n{'=' * 50}")
print("VALIDATION: comparing weight-synced vLLM with fresh V2 instance")
if current_platform.is_rocm():
print(f" (ROCm mode: requiring >= {MIN_PASS_RATE:.0%} exact match rate)")
print(f"{'=' * 50}")
ray.get(llm.shutdown.remote())
ray.kill(llm)
ray.kill(train_model)
llm_v2_kwargs = dict(
model=MODEL_NAME_V2,
enforce_eager=True,
max_model_len=8192,
gpu_memory_utilization=0.75,
distributed_executor_backend="ray",
attention_backend=ATTN_BACKEND,
)
llm_v2_kwargs.update(rocm_determinism_kwargs)
# Get outputs separately - finished completed before pause, pending were paused/resumed
finished_outputs = ray.get(finished)
pending_outputs = ray.get(pending)
# Requests that finished before the pause: all generation used original weights
print("-" * 50)
print("Requests that completed BEFORE weight change:")
print("-" * 50)
for output in finished_outputs:
prompt_text = tokenizer.decode(output.prompt_token_ids)
print(f"Prompt: {prompt_text!r}")
print(f"Generated (with original weights): {output.outputs[0].text!r}")
print("-" * 50)
# Requests that were paused mid-generation: some text before, some after weight change
print("Requests that were PAUSED and RESUMED after weight change:")
print("-" * 50)
for output in pending_outputs:
# Decode the full prompt token IDs (original + generated before pause)
full_prompt_text = tokenizer.decode(output.prompt_token_ids)
# Find the original prompt by checking which one this output started with
original_prompt = next(p for p in prompts if full_prompt_text.startswith(p))
# output.prompt_token_ids contains original prompt + tokens generated before pause
# output.outputs[0].text is what was generated after resuming with new weights
text_before_pause = full_prompt_text[len(original_prompt) :]
text_after_pause = output.outputs[0].text
print(f"Original prompt: {original_prompt!r}")
print(f"Generated before weight change: {text_before_pause!r}")
print(f"Generated after weight change: {text_after_pause!r}")
print("-" * 50)
llm_v2 = ray.remote(
num_cpus=0,
num_gpus=0,
)(MyLLM).remote(**llm_v2_kwargs)
val_futures = [
llm_v2.do_generate.remote(
list(output.prompt_token_ids) + list(output.outputs[0].token_ids)[:pause_idx],
SamplingParams(
temperature=0, max_tokens=len(output.outputs[0].token_ids) - pause_idx
),
)
for output, pause_idx in results
]
val_results = ray.get(val_futures)
num_pass = 0
num_total = len(results)
for i, ((output, pause_idx), (val_output, _)) in enumerate(zip(results, val_results)):
expected = list(output.outputs[0].token_ids)[pause_idx:]
actual = list(val_output.outputs[0].token_ids)
match = actual == expected
if match:
num_pass += 1
print(f" [PASS] {PROMPTS[i]!r}")
else:
print(f" [FAIL] {PROMPTS[i]!r}")
print(f" weight-synced vLLM: {tokenizer.decode(expected)!r}")
print(f" V2 vLLM: {tokenizer.decode(actual)!r}")
for j, (e, a) in enumerate(zip(expected, actual)):
if e != a:
print(
f" first divergence at output token {j}: "
f"expected {e} ({tokenizer.decode([e])!r}) vs "
f"actual {a} ({tokenizer.decode([a])!r})"
)
break
ray.get(llm_v2.shutdown.remote())
ray.kill(llm_v2)
pass_rate = num_pass / num_total
print(f"\n Result: {num_pass}/{num_total} prompts passed ({pass_rate:.0%})")
print(f" Required: >= {MIN_PASS_RATE:.0%}")
assert pass_rate >= MIN_PASS_RATE, (
f"Validation pass rate {pass_rate:.0%} ({num_pass}/{num_total}) "
f"is below the required {MIN_PASS_RATE:.0%} threshold. "
f"See failures above for details."
)
print("=" * 50)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray,
with IPC-based weight syncing APIs
The script colocates the training and inference workloads onto the same GPU using Ray.
The example performs the following steps:
* Request a placement group of 1 GPU.
* Place the inference model on the above GPU using the placement group.
* Place and load the training model on the same GPU using the placement group.
* Generate text from a list of prompts using the inference engine.
* Update the weights of the training model and broadcast the updated weights
to the inference engine by using CUDA IPC handles. Note that
for demonstration purposes we simply zero out the weights.
This example assumes a single-node cluster with a single GPU,
but can be extended to multiple GPUs.
"""
import os
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoModelForCausalLM
from vllm import LLM, SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.ipc_engine import (
IPCTrainerSendWeightsArgs,
IPCWeightTransferEngine,
)
class MyLLM(LLM):
"""Configure the vLLM worker for Ray placement group execution."""
def __init__(self, *args, **kwargs):
# Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
# so that vLLM can manage its own device placement within the worker.
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
# Each worker uses 0.4 GPU so that two instances fit on the same GPU.
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
os.environ["VLLM_RAY_BUNDLE_INDICES"] = "0"
# needed for ipc handle serialization
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"
super().__init__(*args, **kwargs)
# Load the OPT-125M model onto GPU 0 for the training workload.
MODEL_NAME = "facebook/opt-125m"
@ray.remote
class TrainModel:
def __init__(self, llm_handle: ray.actor.ActorHandle):
self.train_model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
)
self.train_model.to("cuda:0")
self.llm_handle = llm_handle
def init_weight_transfer(self):
# IPC backend doesn't need initialization info
ray.get(
self.llm_handle.init_weight_transfer_engine.remote(dict(init_info=dict()))
)
def broadcast_weights(self, llm_handle: ray.actor.ActorHandle):
"""Broadcast weights to the inference engine using IPC."""
self.llm_handle = llm_handle
trainer_args = IPCTrainerSendWeightsArgs(mode="ray", llm_handle=llm_handle)
IPCWeightTransferEngine.trainer_send_weights(
iterator=self.train_model.named_parameters(),
trainer_args=trainer_args,
)
ray.init()
pg_colocate = placement_group([{"GPU": 1, "CPU": 0}])
ray.get(pg_colocate.ready())
llm = ray.remote(
num_cpus=0,
num_gpus=0,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg_colocate,
placement_group_capture_child_tasks=True,
),
)(MyLLM).remote(
model=MODEL_NAME,
enforce_eager=True,
tensor_parallel_size=1,
distributed_executor_backend="ray",
gpu_memory_utilization=0.7,
weight_transfer_config=WeightTransferConfig(backend="ipc"),
load_format="dummy",
)
train_model = TrainModel.options(
num_gpus=0.1,
num_cpus=0,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg_colocate, placement_group_capture_child_tasks=True
),
).remote(llm)
# Generate text from the prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
outputs = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
ray.get(llm.sleep.remote(level=0))
ray.get(train_model.init_weight_transfer.remote())
# Synchronize the updated weights to the inference engine using batched API.
ray.get(train_model.broadcast_weights.remote(llm))
ray.get(llm.wake_up.remote(tags=["scheduling"]))
# Generate text with the updated model.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
......@@ -36,6 +36,7 @@ from transformers import AutoModelForCausalLM
from vllm import LLM, SamplingParams
from vllm.config import WeightTransferConfig
from vllm.distributed.weight_transfer.nccl_engine import (
NCCLTrainerSendWeightsArgs,
NCCLWeightTransferEngine,
)
from vllm.utils.network_utils import get_ip, get_open_port
......@@ -90,11 +91,14 @@ class TrainModel:
def broadcast_weights(self, packed: bool = True):
"""Broadcast weights to the inference engine."""
NCCLWeightTransferEngine.trainer_send_weights(
iterator=self.model.named_parameters(),
trainer_args = NCCLTrainerSendWeightsArgs(
group=self.model_update_group,
packed=packed,
)
NCCLWeightTransferEngine.trainer_send_weights(
iterator=self.model.named_parameters(),
trainer_args=trainer_args,
)
# Initialize Ray and set the visible devices. The vLLM engine will
......@@ -156,6 +160,8 @@ for output in outputs:
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
ray.get(llm.sleep.remote(level=0))
# Set up the communication channel between the training process and the
# inference engine.
master_address, master_port = ray.get(train_model.get_master_address_and_port.remote())
......@@ -197,6 +203,8 @@ inference_handle = llm.update_weights.remote(
train_handle = train_model.broadcast_weights.remote(packed=True)
ray.get([train_handle, inference_handle])
ray.get(llm.wake_up.remote(tags=["scheduling"]))
# Generate text with the updated model. The output is expected to be normal
# because the weights are updated.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use FlexKV with vLLM for prefix caching.
FlexKV is a distributed KV Store and multi-level cache management system for
ultra-large-scale LLM inference.
Requirements:
- Install FlexKV (https://github.com/taco-project/FlexKV):
1. git clone git@github.com:taco-project/FlexKV.git
2. cd FlexKV && bash build.sh
- Ensure FlexKV is compatible with your vLLM version.
Usage:
1. Run this script:
python examples/offline_inference/prefix_caching_flexkv.py \
--model /path/to/your/model
2. Arguments:
--model Path or name of the model (required)
--tp-size Tensor parallel size (default: 1)
--gpu-memory-util GPU memory utilization (default: 0.4)
3. The script will:
- Create a FlexKV configuration file.
- Set the FLEXKV_CONFIG_PATH environment variable.
- Run vLLM with FlexKVConnectorV1 enabled.
- Compare results between regular execution, vLLM's default prefix
caching, and FlexKV.
"""
import argparse
import json
import os
import time
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
# NOTE: This is just a running example. For benchmarking purpose,
# please see benchmarks/benchmark_prefix_caching.py
def parse_args():
parser = argparse.ArgumentParser(
description="Example of using FlexKV with vLLM for prefix caching."
)
parser.add_argument(
"--model",
type=str,
required=True,
help="Path or name of the model to use.",
)
parser.add_argument(
"--tp-size",
type=int,
default=1,
help="Tensor parallel size (default: 1).",
)
parser.add_argument(
"--gpu-memory-util",
type=float,
default=0.4,
help="GPU memory utilization fraction (default: 0.4).",
)
return parser.parse_args()
def main():
args = parse_args()
flexkv_config = {
"server_recv_port": f"ipc:///tmp/flexkv_test_{os.getpid()}",
"cache_config": {
"enable_cpu": True,
"num_cpu_blocks": 10240,
},
"num_log_interval_requests": 200,
}
flexkv_config_path = f"./flexkv_config_{os.getpid()}.json"
with open(flexkv_config_path, "w") as f:
json.dump(flexkv_config, f)
os.environ["FLEXKV_CONFIG_PATH"] = flexkv_config_path
try:
_run(args)
finally:
if os.path.exists(flexkv_config_path):
os.remove(flexkv_config_path)
def _run(args):
# Common prefix.
prefix = (
"You are an expert school principal, skilled in effectively managing "
"faculty and staff. Draft 10-15 questions for a potential first grade "
"Head Teacher for my K-12, all-girls', independent school that emphasizes "
"community, joyful discovery, and life-long learning. The candidate is "
"coming in for a first-round panel interview for a 8th grade Math "
"teaching role. They have 5 years of previous teaching experience "
"as an assistant teacher at a co-ed, public school with experience "
"in middle school math teaching. Based on these information, fulfill "
"the following paragraph: "
)
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
generating_prompts = [prefix + prompt for prompt in prompts]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0)
kv_transfer_config = {
"kv_connector": "FlexKVConnectorV1",
"kv_role": "kv_both",
}
# Create an LLM without prefix caching as a baseline.
regular_llm = LLM(
model=args.model,
enable_prefix_caching=False,
gpu_memory_utilization=args.gpu_memory_util,
tensor_parallel_size=args.tp_size,
)
print("Results without `enable_prefix_caching`")
# ruff: noqa: E501
# Generate texts from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, generated text, and other information.
outputs = regular_llm.generate(generating_prompts, sampling_params)
regular_generated_texts = []
# Print the outputs.
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
regular_generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
# Destroy the LLM object and free up the GPU memory.
del regular_llm
cleanup_dist_env_and_memory()
# Create an LLM with prefix caching enabled.
prefix_cached_llm = LLM(
model=args.model,
enable_prefix_caching=True,
gpu_memory_utilization=args.gpu_memory_util,
tensor_parallel_size=args.tp_size,
kv_transfer_config=kv_transfer_config,
)
# Warmup so that the shared prompt's KV cache is computed.
prefix_cached_llm.generate(generating_prompts[0], sampling_params)
# wait for offload kv task finished.
time.sleep(2)
# Generate with prefix caching.
outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)
print("Results with `enable_prefix_caching`")
cached_generated_texts = []
# Print the outputs. You should see the same outputs as before.
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
cached_generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
# Compare the results and display the speedup
generated_same = all(
regular_generated_texts[i] == cached_generated_texts[i]
for i in range(len(prompts))
)
print(f"Generated answers are the same: {generated_same}")
# wait for offload kv task finished.
time.sleep(2)
# reset prefix cache to use flexkv
prefix_cached_llm.reset_prefix_cache()
# Generate with prefix caching.
outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)
print("Results with `flexkv`")
flexkv_generated_texts = []
# Print the outputs. You should see the same outputs as before.
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
flexkv_generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
# Compare the results and display the speedup
generated_same = all(
regular_generated_texts[i] == flexkv_generated_texts[i]
for i in range(len(prompts))
)
print(f"Generated answers are the same: {generated_same}")
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment