Commit 3fb4b5fa authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.18.0' into v0.18.0-ori

parents bcf25339 89138b21
...@@ -37,6 +37,12 @@ class BlockStored(KVCacheEvent): ...@@ -37,6 +37,12 @@ class BlockStored(KVCacheEvent):
medium: str | None medium: str | None
lora_name: str | None lora_name: str | None
extra_keys: list[tuple[Any, ...] | None] | None = None
"""Extra keys used in block hash computation, one entry per block in
block_hashes. Each entry contains MM identifiers, LoRA name, cache_salt,
prompt embeddings data, etc. for that specific block.
"""
class BlockRemoved(KVCacheEvent): class BlockRemoved(KVCacheEvent):
block_hashes: list[ExternalBlockHash] block_hashes: list[ExternalBlockHash]
......
...@@ -57,8 +57,7 @@ case "$subcommand" in ...@@ -57,8 +57,7 @@ case "$subcommand" in
# Retry until the worker node connects to the head node or the timeout expires. # Retry until the worker node connects to the head node or the timeout expires.
for (( i=0; i < $ray_init_timeout; i+=5 )); do for (( i=0; i < $ray_init_timeout; i+=5 )); do
ray start --address=$ray_address:$ray_port --block "${start_params[@]}" if ray start --address="$ray_address":"$ray_port" --block "${start_params[@]}"; then
if [ $? -eq 0 ]; then
echo "Worker: Ray runtime started with head address $ray_address:$ray_port" echo "Worker: Ray runtime started with head address $ray_address:$ray_port"
exit 0 exit 0
fi fi
...@@ -95,12 +94,12 @@ case "$subcommand" in ...@@ -95,12 +94,12 @@ case "$subcommand" in
fi fi
# Start the Ray head node. # Start the Ray head node.
ray start --head --port=$ray_port "${start_params[@]}" ray start --head --port="$ray_port" "${start_params[@]}"
# Poll Ray until every worker node is active. # Poll Ray until every worker node is active.
for (( i=0; i < $ray_init_timeout; i+=5 )); do for (( i=0; i < $ray_init_timeout; i+=5 )); do
active_nodes=`python3 -c 'import ray; ray.init(); print(sum(node["Alive"] for node in ray.nodes()))'` active_nodes=$(python3 -c 'import ray; ray.init(); print(sum(node["Alive"] for node in ray.nodes()))')
if [ $active_nodes -eq $ray_cluster_size ]; then if [ "$active_nodes" -eq "$ray_cluster_size" ]; then
echo "All ray workers are active and the ray cluster is initialized successfully." echo "All ray workers are active and the ray cluster is initialized successfully."
exit 0 exit 0
fi fi
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM
via HTTP API, with IPC-based weight syncing APIs.
Unlike rlhf_nccl.py which uses NCCL and can use separate GPUs, this script
uses CUDA IPC which requires the training model and vLLM server to be on the
same GPU. Memory must be carefully managed to fit both models.
Unlike rlhf.py which creates a vLLM instance programmatically, this script
assumes you have already started a vLLM server using `vllm serve`. It uses:
- OpenAI-compatible API for inference requests
- HTTP endpoints for weight transfer control plane
- CUDA IPC for actual weight data transfer
Prerequisites:
Start a vLLM server with weight transfer enabled and reduced GPU memory
utilization to leave room for the training model:
$ VLLM_SERVER_DEV_MODE=1 VLLM_ALLOW_INSECURE_SERIALIZATION=1 \
vllm serve facebook/opt-125m --enforce-eager \
--weight-transfer-config '{"backend": "ipc"}' \
--load-format dummy \
--gpu-memory-utilization 0.5
Then run this script:
$ python rlhf_http_ipc.py
The example performs the following steps:
* Load the training model on GPU 0 (same GPU as the vLLM server).
* Generate text using the vLLM server via OpenAI-compatible API. The output
is expected to be nonsense because the server is initialized with dummy weights.
* Initialize weight transfer via HTTP endpoint (no-op for IPC).
* Broadcast the real weights from the training model to the vLLM server
using CUDA IPC handles.
* Generate text again to show normal output after the weight update.
"""
import os
import requests
import torch
from openai import OpenAI
from transformers import AutoModelForCausalLM
from vllm.distributed.weight_transfer.ipc_engine import (
IPCTrainerSendWeightsArgs,
IPCWeightTransferEngine,
)
BASE_URL = "http://localhost:8000"
MODEL_NAME = "facebook/opt-125m"
# Enable insecure serialization for IPC handle serialization
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"
def generate_completions(client: OpenAI, model: str, prompts: list[str]) -> list[str]:
"""Generate completions using the OpenAI-compatible API."""
results = []
for prompt in prompts:
response = client.completions.create(
model=model,
prompt=prompt,
max_tokens=32,
temperature=0,
)
results.append(response.choices[0].text)
return results
def init_weight_transfer_engine(base_url: str) -> None:
"""Initialize weight transfer via HTTP endpoint (no-op for IPC)."""
url = f"{base_url}/init_weight_transfer_engine"
payload = {"init_info": dict()}
response = requests.post(url, json=payload, timeout=60)
response.raise_for_status()
def pause_generation(base_url: str) -> None:
"""Pause generation via HTTP endpoint."""
url = f"{base_url}/pause"
response = requests.post(url, timeout=60)
response.raise_for_status()
def resume_generation(base_url: str) -> None:
"""Resume generation via HTTP endpoint."""
url = f"{base_url}/resume"
response = requests.post(url, timeout=60)
response.raise_for_status()
def get_world_size(base_url: str) -> int:
"""Get world size from the vLLM server."""
url = f"{base_url}/get_world_size"
response = requests.get(url, timeout=10)
response.raise_for_status()
return response.json()["world_size"]
def main():
# IPC requires the training model to be on the same GPU as the vLLM server
# The server should be started on GPU 0 with reduced memory utilization
device = "cuda:0"
torch.accelerator.set_device_index(device)
# Load the training model on the same GPU as the server
# Use bfloat16 to reduce memory footprint
print(f"Loading training model: {MODEL_NAME} on {device}")
print(
"Note: Ensure the vLLM server was started with --gpu-memory-utilization 0.5 "
"or lower to leave room for the training model."
)
train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16)
train_model.to(device)
train_model.eval() # Set to eval mode to save memory
# Create OpenAI client pointing to the vLLM server
client = OpenAI(
base_url=f"{BASE_URL}/v1",
api_key="EMPTY", # vLLM doesn't require an API key by default
)
# Test prompts
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Generate text before weight update. The output is expected to be nonsense
# because the server is initialized with dummy weights.
print("-" * 50)
print("Generating text BEFORE weight update (expect nonsense):")
print("-" * 50)
outputs = generate_completions(client, MODEL_NAME, prompts)
for prompt, generated_text in zip(prompts, outputs):
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
print("Initializing weight transfer (IPC backend)...")
# Initialize weight transfer on vLLM server (no-op for IPC, but still required)
init_weight_transfer_engine(BASE_URL)
# Pause generation before weight sync
pause_generation(BASE_URL)
# Broadcast weights via IPC handles using HTTP mode
print("Broadcasting weights via CUDA IPC (HTTP)...")
trainer_args = IPCTrainerSendWeightsArgs(mode="http", url=BASE_URL)
IPCWeightTransferEngine.trainer_send_weights(
iterator=train_model.named_parameters(),
trainer_args=trainer_args,
)
# Resume generation after weight sync
resume_generation(BASE_URL)
# Generate text after weight update. The output is expected to be normal
# because the real weights are now loaded.
print("-" * 50)
print("Generating text AFTER weight update:")
print("-" * 50)
outputs_updated = generate_completions(client, MODEL_NAME, prompts)
for prompt, generated_text in zip(prompts, outputs_updated):
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
# Note: The training model and IPC handles remain in memory.
# In a real RLHF training loop, you would update the training model
# and create new IPC handles for each weight update.
if __name__ == "__main__":
main()
...@@ -39,6 +39,7 @@ from openai import OpenAI ...@@ -39,6 +39,7 @@ from openai import OpenAI
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
from vllm.distributed.weight_transfer.nccl_engine import ( from vllm.distributed.weight_transfer.nccl_engine import (
NCCLTrainerSendWeightsArgs,
NCCLWeightTransferEngine, NCCLWeightTransferEngine,
) )
from vllm.utils.network_utils import get_ip, get_open_port from vllm.utils.network_utils import get_ip, get_open_port
...@@ -130,7 +131,7 @@ def main(): ...@@ -130,7 +131,7 @@ def main():
inference_world_size = get_world_size(BASE_URL) inference_world_size = get_world_size(BASE_URL)
world_size = inference_world_size + 1 # +1 for the trainer world_size = inference_world_size + 1 # +1 for the trainer
device = f"cuda:{inference_world_size}" device = f"cuda:{inference_world_size}"
torch.cuda.set_device(device) torch.accelerator.set_device_index(device)
# Load the training model # Load the training model
print(f"Loading training model: {MODEL_NAME}") print(f"Loading training model: {MODEL_NAME}")
...@@ -214,11 +215,14 @@ def main(): ...@@ -214,11 +215,14 @@ def main():
# Broadcast all weights from trainer to vLLM workers # Broadcast all weights from trainer to vLLM workers
print("Broadcasting weights via NCCL...") print("Broadcasting weights via NCCL...")
NCCLWeightTransferEngine.trainer_send_weights( trainer_args = NCCLTrainerSendWeightsArgs(
iterator=train_model.named_parameters(),
group=model_update_group, group=model_update_group,
packed=True, packed=True,
) )
NCCLWeightTransferEngine.trainer_send_weights(
iterator=train_model.named_parameters(),
trainer_args=trainer_args,
)
# Wait for update_weights to complete # Wait for update_weights to complete
update_thread.join() update_thread.join()
......
...@@ -10,7 +10,7 @@ vllm serve llava-hf/llava-1.5-7b-hf ...@@ -10,7 +10,7 @@ vllm serve llava-hf/llava-1.5-7b-hf
(multi-image inference with Phi-3.5-vision-instruct) (multi-image inference with Phi-3.5-vision-instruct)
vllm serve microsoft/Phi-3.5-vision-instruct --runner generate \ vllm serve microsoft/Phi-3.5-vision-instruct --runner generate \
--trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}' --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt.image 2
(audio inference with Ultravox) (audio inference with Ultravox)
vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b \ vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b \
......
...@@ -26,7 +26,9 @@ from openai import AsyncOpenAI, OpenAI ...@@ -26,7 +26,9 @@ from openai import AsyncOpenAI, OpenAI
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
def sync_openai(audio_path: str, client: OpenAI, model: str): def sync_openai(
audio_path: str, client: OpenAI, model: str, *, repetition_penalty: float = 1.3
):
""" """
Perform synchronous transcription using OpenAI-compatible API. Perform synchronous transcription using OpenAI-compatible API.
""" """
...@@ -40,7 +42,7 @@ def sync_openai(audio_path: str, client: OpenAI, model: str): ...@@ -40,7 +42,7 @@ def sync_openai(audio_path: str, client: OpenAI, model: str):
# Additional sampling params not provided by OpenAI API. # Additional sampling params not provided by OpenAI API.
extra_body=dict( extra_body=dict(
seed=4419, seed=4419,
repetition_penalty=1.3, repetition_penalty=repetition_penalty,
), ),
) )
print("transcription result [sync]:", transcription.text) print("transcription result [sync]:", transcription.text)
...@@ -129,7 +131,12 @@ def main(args): ...@@ -129,7 +131,12 @@ def main(args):
print(f"Using model: {model}") print(f"Using model: {model}")
# Run the synchronous function # Run the synchronous function
sync_openai(args.audio_path if args.audio_path else mary_had_lamb, client, model) sync_openai(
audio_path=args.audio_path if args.audio_path else mary_had_lamb,
client=client,
model=model,
repetition_penalty=args.repetition_penalty,
)
# Run the asynchronous function # Run the asynchronous function
if "openai" in model: if "openai" in model:
...@@ -161,5 +168,11 @@ if __name__ == "__main__": ...@@ -161,5 +168,11 @@ if __name__ == "__main__":
default=None, default=None,
help="The path to the audio file to transcribe.", help="The path to the audio file to transcribe.",
) )
parser.add_argument(
"--repetition_penalty",
type=float,
default=1.3,
help="repetition penalty",
)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
# Setup OpenTelemetry POC # Setup OpenTelemetry POC
1. Install OpenTelemetry packages: > **Note:** The core OpenTelemetry packages (`opentelemetry-sdk`, `opentelemetry-api`, `opentelemetry-exporter-otlp`, `opentelemetry-semantic-conventions-ai`) are bundled with vLLM. Manual installation is not required.
```bash
pip install \
'opentelemetry-sdk>=1.26.0,<1.27.0' \
'opentelemetry-api>=1.26.0,<1.27.0' \
'opentelemetry-exporter-otlp>=1.26.0,<1.27.0' \
'opentelemetry-semantic-conventions-ai>=0.4.1,<0.5.0'
```
1. Start Jaeger in a docker container: 1. Start Jaeger in a docker container:
......
...@@ -22,11 +22,10 @@ check_hf_token() { ...@@ -22,11 +22,10 @@ check_hf_token() {
check_num_gpus() { check_num_gpus() {
# can you check if the number of GPUs are >=2 via nvidia-smi/rocm-smi? # can you check if the number of GPUs are >=2 via nvidia-smi/rocm-smi?
which rocm-smi > /dev/null 2>&1 if ! which rocm-smi > /dev/null 2>&1; then
if [ $? -ne 0 ]; then
num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
else else
num_gpus=$(rocm-smi --showid | grep Instinct | wc -l) num_gpus=$(rocm-smi --showid | grep -c Instinct)
fi fi
if [ "$num_gpus" -lt 2 ]; then if [ "$num_gpus" -lt 2 ]; then
...@@ -39,8 +38,7 @@ check_num_gpus() { ...@@ -39,8 +38,7 @@ check_num_gpus() {
ensure_python_library_installed() { ensure_python_library_installed() {
echo "Checking if $1 is installed..." echo "Checking if $1 is installed..."
python3 -c "import $1" > /dev/null 2>&1 if ! python3 -c "import $1" > /dev/null 2>&1; then
if [ $? -ne 0 ]; then
if [ "$1" == "nixl" ]; then if [ "$1" == "nixl" ]; then
echo "$1 is not installed. Please refer to https://github.com/ai-dynamo/nixl for installation." echo "$1 is not installed. Please refer to https://github.com/ai-dynamo/nixl for installation."
else else
...@@ -102,12 +100,12 @@ main() { ...@@ -102,12 +100,12 @@ main() {
bash disagg_vllm_launcher.sh prefiller \ bash disagg_vllm_launcher.sh prefiller \
> >(tee prefiller.log) 2>&1 & > >(tee prefiller.log) 2>&1 &
prefiller_pid=$! prefiller_pid=$!
PIDS+=($prefiller_pid) PIDS+=("$prefiller_pid")
bash disagg_vllm_launcher.sh decoder \ bash disagg_vllm_launcher.sh decoder \
> >(tee decoder.log) 2>&1 & > >(tee decoder.log) 2>&1 &
decoder_pid=$! decoder_pid=$!
PIDS+=($decoder_pid) PIDS+=("$decoder_pid")
python3 disagg_proxy_server.py \ python3 disagg_proxy_server.py \
--host localhost \ --host localhost \
...@@ -118,7 +116,7 @@ main() { ...@@ -118,7 +116,7 @@ main() {
--decoder-port 8200 \ --decoder-port 8200 \
> >(tee proxy.log) 2>&1 & > >(tee proxy.log) 2>&1 &
proxy_pid=$! proxy_pid=$!
PIDS+=($proxy_pid) PIDS+=("$proxy_pid")
wait_for_server 8100 wait_for_server 8100
wait_for_server 8200 wait_for_server 8200
...@@ -128,7 +126,7 @@ main() { ...@@ -128,7 +126,7 @@ main() {
# begin benchmark # begin benchmark
cd ../../../../benchmarks/ cd ../../../../benchmarks/
vllm bench serve --port 9000 --seed $(date +%s) \ vllm bench serve --port 9000 --seed "$(date +%s)" \
--model meta-llama/Llama-3.1-8B-Instruct \ --model meta-llama/Llama-3.1-8B-Instruct \
--dataset-name random --random-input-len 7500 --random-output-len 200 \ --dataset-name random --random-input-len 7500 --random-output-len 200 \
--num-prompts 200 --burstiness 100 --request-rate 3.6 | tee benchmark.log --num-prompts 200 --burstiness 100 --request-rate 3.6 | tee benchmark.log
......
...@@ -34,7 +34,7 @@ if [[ $1 == "prefiller" ]]; then ...@@ -34,7 +34,7 @@ if [[ $1 == "prefiller" ]]; then
VLLM_ENABLE_V1_MULTIPROCESSING=1 \ VLLM_ENABLE_V1_MULTIPROCESSING=1 \
VLLM_WORKER_MULTIPROC_METHOD=spawn \ VLLM_WORKER_MULTIPROC_METHOD=spawn \
CUDA_VISIBLE_DEVICES=0 \ CUDA_VISIBLE_DEVICES=0 \
vllm serve $MODEL \ vllm serve "$MODEL" \
--port 8100 \ --port 8100 \
--enforce-eager \ --enforce-eager \
--kv-transfer-config \ --kv-transfer-config \
...@@ -51,7 +51,7 @@ elif [[ $1 == "decoder" ]]; then ...@@ -51,7 +51,7 @@ elif [[ $1 == "decoder" ]]; then
VLLM_ENABLE_V1_MULTIPROCESSING=1 \ VLLM_ENABLE_V1_MULTIPROCESSING=1 \
VLLM_WORKER_MULTIPROC_METHOD=spawn \ VLLM_WORKER_MULTIPROC_METHOD=spawn \
CUDA_VISIBLE_DEVICES=1 \ CUDA_VISIBLE_DEVICES=1 \
vllm serve $MODEL \ vllm serve "$MODEL" \
--port 8200 \ --port 8200 \
--enforce-eager \ --enforce-eager \
--kv-transfer-config \ --kv-transfer-config \
......
...@@ -7,8 +7,8 @@ NOTE: ...@@ -7,8 +7,8 @@ NOTE:
vllm serve muziyongshixin/Qwen2.5-VL-7B-for-VideoCls \ vllm serve muziyongshixin/Qwen2.5-VL-7B-for-VideoCls \
--runner pooling \ --runner pooling \
--max-model-len 5000 \ --max-model-len 5000 \
--limit-mm-per-prompt '{"video": 1}' \ --limit-mm-per-prompt.video 1 \
--hf-overrides '{"text_config": {"architectures": ["Qwen2_5_VLForSequenceClassification"]}}' --hf-overrides '{"architectures": ["Qwen2_5_VLForSequenceClassification"]}'
""" """
import argparse import argparse
......
...@@ -34,7 +34,7 @@ python client.py ...@@ -34,7 +34,7 @@ python client.py
## 📁 Files ## 📁 Files
| File | Description | | File | Description |
|------|-------------| | ---- | ----------- |
| `service.sh` | Server startup script with chunked processing enabled | | `service.sh` | Server startup script with chunked processing enabled |
| `client.py` | Comprehensive test client for long text embedding | | `client.py` | Comprehensive test client for long text embedding |
...@@ -61,7 +61,7 @@ The key parameters for chunked processing are in the `--pooler-config`: ...@@ -61,7 +61,7 @@ The key parameters for chunked processing are in the `--pooler-config`:
Chunked processing uses **MEAN aggregation** for cross-chunk combination when input exceeds the model's native maximum length: Chunked processing uses **MEAN aggregation** for cross-chunk combination when input exceeds the model's native maximum length:
| Component | Behavior | Description | | Component | Behavior | Description |
|-----------|----------|-------------| | --------- | -------- | ----------- |
| **Within chunks** | Model's native pooling | Uses the model's configured pooling strategy | | **Within chunks** | Model's native pooling | Uses the model's configured pooling strategy |
| **Cross-chunk aggregation** | Always MEAN | Weighted averaging based on chunk token counts | | **Cross-chunk aggregation** | Always MEAN | Weighted averaging based on chunk token counts |
| **Performance** | Optimal | All chunks processed for complete semantic coverage | | **Performance** | Optimal | All chunks processed for complete semantic coverage |
...@@ -69,7 +69,7 @@ Chunked processing uses **MEAN aggregation** for cross-chunk combination when in ...@@ -69,7 +69,7 @@ Chunked processing uses **MEAN aggregation** for cross-chunk combination when in
### Environment Variables ### Environment Variables
| Variable | Default | Description | | Variable | Default | Description |
|----------|---------|-------------| | -------- | ------- | ----------- |
| `MODEL_NAME` | `intfloat/multilingual-e5-large` | Embedding model to use (supports multiple models) | | `MODEL_NAME` | `intfloat/multilingual-e5-large` | Embedding model to use (supports multiple models) |
| `PORT` | `31090` | Server port | | `PORT` | `31090` | Server port |
| `GPU_COUNT` | `1` | Number of GPUs to use | | `GPU_COUNT` | `1` | Number of GPUs to use |
...@@ -106,7 +106,7 @@ With `MAX_EMBED_LEN=3072000`, you can process: ...@@ -106,7 +106,7 @@ With `MAX_EMBED_LEN=3072000`, you can process:
### Chunked Processing Performance ### Chunked Processing Performance
| Aspect | Behavior | Performance | | Aspect | Behavior | Performance |
|--------|----------|-------------| | ------ | -------- | ----------- |
| **Chunk Processing** | All chunks processed with native pooling | Consistent with input length | | **Chunk Processing** | All chunks processed with native pooling | Consistent with input length |
| **Cross-chunk Aggregation** | MEAN weighted averaging | Minimal overhead | | **Cross-chunk Aggregation** | MEAN weighted averaging | Minimal overhead |
| **Memory Usage** | Proportional to number of chunks | Moderate, scalable | | **Memory Usage** | Proportional to number of chunks | Moderate, scalable |
......
...@@ -103,7 +103,7 @@ vllm serve "$MODEL_NAME" \ ...@@ -103,7 +103,7 @@ vllm serve "$MODEL_NAME" \
--tensor-parallel-size "$GPU_COUNT" \ --tensor-parallel-size "$GPU_COUNT" \
--enforce-eager \ --enforce-eager \
--pooler-config "$POOLER_CONFIG" \ --pooler-config "$POOLER_CONFIG" \
--served-model-name ${MODEL_CODE} \ --served-model-name "${MODEL_CODE}" \
--api-key "$API_KEY" \ --api-key "$API_KEY" \
--trust-remote-code \ --trust-remote-code \
--port "$PORT" \ --port "$PORT" \
......
{%- if messages | length > 1 -%}
{{ raise_exception('Embedding models should only embed one message at a time') }}
{%- endif -%}
{% set vars = namespace(prefix='', images=[], texts=[]) %}
{%- for message in messages -%}
{%- if message['role'] == 'query' -%}
{%- set vars.prefix = 'query: ' %}
{%- elif message['role'] == 'document' -%}
{%- set vars.prefix = 'passage: ' %}
{%- endif -%}
{%- for content in message['content'] -%}
{%- if content['type'] == 'text' -%}
{%- set vars.texts = vars.texts + [content['text']] %}
{%- elif content['type'] == 'image' -%}
{%- set vars.images = vars.images + ['<image> '] %}
{%- endif -%}
{%- endfor -%}
{%- endfor -%}
{{- bos_token }}{{ vars.prefix }}{{ (vars.images + vars.texts) | join('') }}
...@@ -20,15 +20,17 @@ def main(): ...@@ -20,15 +20,17 @@ def main():
torch.set_default_dtype(torch.float16) torch.set_default_dtype(torch.float16)
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501 image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501
img_prompt = dict( img_data = dict(
data=image_url, data=image_url,
data_format="url", data_format="url",
image_format="tiff", image_format="tiff",
out_data_format="b64_json", out_data_format="b64_json",
) )
prompt = dict(data=img_data)
llm = LLM( llm = LLM(
model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM", model="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
skip_tokenizer_init=True, skip_tokenizer_init=True,
trust_remote_code=True, trust_remote_code=True,
enforce_eager=True, enforce_eager=True,
...@@ -41,7 +43,7 @@ def main(): ...@@ -41,7 +43,7 @@ def main():
enable_mm_embeds=True, enable_mm_embeds=True,
) )
pooler_output = llm.encode(img_prompt, pooling_task="plugin") pooler_output = llm.encode(prompt, pooling_task="plugin")
output = pooler_output[0].outputs output = pooler_output[0].outputs
print(output) print(output)
......
...@@ -391,7 +391,7 @@ if __name__ == "__main__": ...@@ -391,7 +391,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--model", "--model",
type=str, type=str,
default="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM", default="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
help="Path to a checkpoint file to load from.", help="Path to a checkpoint file to load from.",
) )
parser.add_argument( parser.add_argument(
......
...@@ -14,9 +14,7 @@ import requests ...@@ -14,9 +14,7 @@ import requests
# - install TerraTorch v1.1 (or later): # - install TerraTorch v1.1 (or later):
# pip install terratorch>=v1.1 # pip install terratorch>=v1.1
# - start vllm in serving mode with the below args # - start vllm in serving mode with the below args
# --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM' # --model='ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11'
# --model-impl terratorch
# --trust-remote-code
# --skip-tokenizer-init --enforce-eager # --skip-tokenizer-init --enforce-eager
# --io-processor-plugin terratorch_segmentation # --io-processor-plugin terratorch_segmentation
# --enable-mm-embeds # --enable-mm-embeds
...@@ -34,7 +32,7 @@ def main(): ...@@ -34,7 +32,7 @@ def main():
"out_data_format": "b64_json", "out_data_format": "b64_json",
}, },
"priority": 0, "priority": 0,
"model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM", "model": "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
} }
ret = requests.post(server_endpoint, json=request_payload_url) ret = requests.post(server_endpoint, json=request_payload_url)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" """
Example of using ColBERT late interaction model for reranking. Example of using ColBERT late interaction models for reranking and scoring.
ColBERT (Contextualized Late Interaction over BERT) uses per-token embeddings ColBERT (Contextualized Late Interaction over BERT) uses per-token embeddings
and MaxSim scoring for document reranking, providing better accuracy than and MaxSim scoring for document reranking, providing better accuracy than
single-vector models while being more efficient than cross-encoders. single-vector models while being more efficient than cross-encoders.
Start the server with: vLLM supports ColBERT with multiple encoder backbones. Start the server
with one of the following:
# BERT backbone (works out of the box)
vllm serve answerdotai/answerai-colbert-small-v1 vllm serve answerdotai/answerai-colbert-small-v1
# ModernBERT backbone
vllm serve lightonai/GTE-ModernColBERT-v1 \
--hf-overrides '{"architectures": ["ColBERTModernBertModel"]}'
# Jina XLM-RoBERTa backbone
vllm serve jinaai/jina-colbert-v2 \
--hf-overrides '{"architectures": ["ColBERTJinaRobertaModel"]}' \
--trust-remote-code
Then run this script: Then run this script:
python colbert_rerank_online.py python colbert_rerank_online.py
""" """
...@@ -18,39 +30,62 @@ import json ...@@ -18,39 +30,62 @@ import json
import requests import requests
url = "http://127.0.0.1:8000/rerank" # Change this to match the model you started the server with
MODEL = "answerdotai/answerai-colbert-small-v1"
BASE_URL = "http://127.0.0.1:8000"
headers = {"accept": "application/json", "Content-Type": "application/json"} headers = {"accept": "application/json", "Content-Type": "application/json"}
data = { documents = [
"model": "answerdotai/answerai-colbert-small-v1", "Machine learning is a subset of artificial intelligence.",
"query": "What is machine learning?", "Python is a programming language.",
"documents": [ "Deep learning uses neural networks for complex tasks.",
"Machine learning is a subset of artificial intelligence.", "The weather today is sunny.",
"Python is a programming language.", ]
"Deep learning uses neural networks for complex tasks.",
"The weather today is sunny.",
], def rerank_example():
} """Use the /rerank endpoint to rank documents by query relevance."""
print("=== Rerank Example ===")
data = {
"model": MODEL,
"query": "What is machine learning?",
"documents": documents,
}
response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=data)
result = response.json()
print(json.dumps(result, indent=2))
print("\nRanked documents (most relevant first):")
for item in result["results"]:
doc_idx = item["index"]
score = item["relevance_score"]
print(f" Score {score:.4f}: {documents[doc_idx]}")
def score_example():
"""Use the /score endpoint for pairwise query-document scoring."""
print("\n=== Score Example ===")
data = {
"model": MODEL,
"text_1": "What is machine learning?",
"text_2": [
"Machine learning is a subset of AI.",
"The weather is sunny.",
],
}
response = requests.post(f"{BASE_URL}/score", headers=headers, json=data)
result = response.json()
print(json.dumps(result, indent=2))
def main(): def main():
response = requests.post(url, headers=headers, json=data) rerank_example()
score_example()
if response.status_code == 200:
print("ColBERT Rerank Request successful!")
result = response.json()
print(json.dumps(result, indent=2))
# Show ranked results
print("\nRanked documents (most relevant first):")
for item in result["results"]:
doc_idx = item["index"]
score = item["relevance_score"]
print(f" Score {score:.4f}: {data['documents'][doc_idx]}")
else:
print(f"Request failed with status code: {response.status_code}")
print(response.text)
if __name__ == "__main__": if __name__ == "__main__":
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example of using ColModernVBERT late interaction model for reranking.
ColModernVBERT is a multi-modal ColBERT-style model combining a SigLIP
vision encoder with a ModernBERT text encoder. It produces per-token
embeddings and uses MaxSim scoring for retrieval and reranking.
Supports both text and image inputs.
Start the server with:
vllm serve ModernVBERT/colmodernvbert-merged --max-model-len 8192
Then run this script:
python colmodernvbert_rerank_online.py
"""
import requests
MODEL = "ModernVBERT/colmodernvbert-merged"
BASE_URL = "http://127.0.0.1:8000"
headers = {"accept": "application/json", "Content-Type": "application/json"}
IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/PNG_transparency_demonstration_1.png/300px-PNG_transparency_demonstration_1.png" # noqa: E501
def rerank_text():
"""Text-only reranking via /rerank endpoint."""
print("=" * 60)
print("1. Text reranking (/rerank)")
print("=" * 60)
data = {
"model": MODEL,
"query": "What is machine learning?",
"documents": [
"Machine learning is a subset of artificial intelligence.",
"Python is a programming language.",
"Deep learning uses neural networks for complex tasks.",
"The weather today is sunny.",
],
}
response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=data)
if response.status_code == 200:
result = response.json()
print("\n Ranked documents (most relevant first):")
for item in result["results"]:
doc_idx = item["index"]
score = item["relevance_score"]
print(f" [{score:.4f}] {data['documents'][doc_idx]}")
else:
print(f" Request failed: {response.status_code}")
print(f" {response.text[:300]}")
def score_text():
"""Text-only scoring via /score endpoint."""
print()
print("=" * 60)
print("2. Text scoring (/score)")
print("=" * 60)
query = "What is the capital of France?"
documents = [
"The capital of France is Paris.",
"Berlin is the capital of Germany.",
"Python is a programming language.",
]
data = {
"model": MODEL,
"text_1": query,
"text_2": documents,
}
response = requests.post(f"{BASE_URL}/score", headers=headers, json=data)
if response.status_code == 200:
result = response.json()
print(f"\n Query: {query}\n")
for item in result["data"]:
idx = item["index"]
score = item["score"]
print(f" Doc {idx} (score={score:.4f}): {documents[idx]}")
else:
print(f" Request failed: {response.status_code}")
print(f" {response.text[:300]}")
def score_text_top_n():
"""Text reranking with top_n filtering via /rerank endpoint."""
print()
print("=" * 60)
print("3. Text reranking with top_n=2 (/rerank)")
print("=" * 60)
data = {
"model": MODEL,
"query": "What is the capital of France?",
"documents": [
"The capital of France is Paris.",
"Berlin is the capital of Germany.",
"Python is a programming language.",
"The Eiffel Tower is in Paris.",
],
"top_n": 2,
}
response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=data)
if response.status_code == 200:
result = response.json()
print(f"\n Top {data['top_n']} results:")
for item in result["results"]:
doc_idx = item["index"]
score = item["relevance_score"]
print(f" [{score:.4f}] {data['documents'][doc_idx]}")
else:
print(f" Request failed: {response.status_code}")
print(f" {response.text[:300]}")
def rerank_multimodal():
"""Multimodal reranking with text and image documents via /rerank."""
print()
print("=" * 60)
print("4. Multimodal reranking: text query vs image document (/rerank)")
print("=" * 60)
data = {
"model": MODEL,
"query": "A colorful logo with transparency",
"documents": [
{"content": [{"type": "image_url", "image_url": {"url": IMAGE_URL}}]},
"Python is a programming language.",
"The weather today is sunny.",
],
}
response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=data)
if response.status_code == 200:
result = response.json()
print("\n Ranked documents (most relevant first):")
labels = ["[image]", "Python doc", "Weather doc"]
for item in result["results"]:
doc_idx = item["index"]
score = item["relevance_score"]
print(f" [{score:.4f}] {labels[doc_idx]}")
else:
print(f" Request failed: {response.status_code}")
print(f" {response.text[:300]}")
def main():
rerank_text()
score_text()
score_text_top_n()
rerank_multimodal()
if __name__ == "__main__":
main()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""
Example of using ColQwen3 late interaction model for reranking and scoring.
ColQwen3 is a multi-modal ColBERT-style model based on Qwen3-VL.
It produces per-token embeddings and uses MaxSim scoring for retrieval
and reranking. Supports both text and image inputs.
Start the server with:
vllm serve TomoroAI/tomoro-colqwen3-embed-4b --max-model-len 50000
Then run this script:
python colqwen3_rerank_online.py
"""
import base64
from io import BytesIO
import requests
from PIL import Image
MODEL = "TomoroAI/tomoro-colqwen3-embed-4b"
BASE_URL = "http://127.0.0.1:8000"
headers = {"accept": "application/json", "Content-Type": "application/json"}
# ── Image helpers ──────────────────────────────────────────
def load_image(url: str) -> Image.Image:
"""Download an image from URL (handles Wikimedia 403)."""
for hdrs in (
{},
{"User-Agent": "Mozilla/5.0 (compatible; ColQwen3-demo/1.0)"},
):
resp = requests.get(url, headers=hdrs, timeout=15)
if resp.status_code == 403:
continue
resp.raise_for_status()
return Image.open(BytesIO(resp.content)).convert("RGB")
raise RuntimeError(f"Could not fetch image from {url}")
def encode_image_base64(image: Image.Image) -> str:
"""Encode a PIL image to a base64 data URI."""
buf = BytesIO()
image.save(buf, format="PNG")
return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()
def make_image_content(image_url: str, text: str = "Describe the image.") -> dict:
"""Build a ScoreMultiModalParam dict from an image URL."""
image = load_image(image_url)
return {
"content": [
{
"type": "image_url",
"image_url": {"url": encode_image_base64(image)},
},
{"type": "text", "text": text},
]
}
# ── Sample image URLs ─────────────────────────────────────
IMAGE_URLS = {
"beijing": "https://upload.wikimedia.org/wikipedia/commons/6/61/Beijing_skyline_at_night.JPG",
"london": "https://upload.wikimedia.org/wikipedia/commons/4/49/London_skyline.jpg",
"singapore": "https://upload.wikimedia.org/wikipedia/commons/2/27/Singapore_skyline_2022.jpg",
}
# ── Text-only examples ────────────────────────────────────
def rerank_text():
"""Text-only reranking via /rerank endpoint."""
print("=" * 60)
print("1. Text reranking (/rerank)")
print("=" * 60)
data = {
"model": MODEL,
"query": "What is machine learning?",
"documents": [
"Machine learning is a subset of artificial intelligence.",
"Python is a programming language.",
"Deep learning uses neural networks for complex tasks.",
"The weather today is sunny.",
],
}
response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=data)
if response.status_code == 200:
result = response.json()
print("\n Ranked documents (most relevant first):")
for item in result["results"]:
doc_idx = item["index"]
score = item["relevance_score"]
print(f" [{score:.4f}] {data['documents'][doc_idx]}")
else:
print(f" Request failed: {response.status_code}")
print(f" {response.text[:300]}")
def score_text():
"""Text-only scoring via /score endpoint."""
print()
print("=" * 60)
print("2. Text scoring (/score)")
print("=" * 60)
query = "What is the capital of France?"
documents = [
"The capital of France is Paris.",
"Berlin is the capital of Germany.",
"Python is a programming language.",
]
data = {
"model": MODEL,
"text_1": query,
"text_2": documents,
}
response = requests.post(f"{BASE_URL}/score", headers=headers, json=data)
if response.status_code == 200:
result = response.json()
print(f"\n Query: {query}\n")
for item in result["data"]:
idx = item["index"]
score = item["score"]
print(f" Doc {idx} (score={score:.4f}): {documents[idx]}")
else:
print(f" Request failed: {response.status_code}")
print(f" {response.text[:300]}")
def score_text_top_n():
"""Text reranking with top_n filtering via /rerank endpoint."""
print()
print("=" * 60)
print("3. Text reranking with top_n=2 (/rerank)")
print("=" * 60)
data = {
"model": MODEL,
"query": "What is the capital of France?",
"documents": [
"The capital of France is Paris.",
"Berlin is the capital of Germany.",
"Python is a programming language.",
"The Eiffel Tower is in Paris.",
],
"top_n": 2,
}
response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=data)
if response.status_code == 200:
result = response.json()
print(f"\n Top {data['top_n']} results:")
for item in result["results"]:
doc_idx = item["index"]
score = item["relevance_score"]
print(f" [{score:.4f}] {data['documents'][doc_idx]}")
else:
print(f" Request failed: {response.status_code}")
print(f" {response.text[:300]}")
# ── Multi-modal examples (text query × image documents) ──
def score_text_vs_images():
"""Score a text query against image documents via /score."""
print()
print("=" * 60)
print("4. Multi-modal scoring: text query vs image docs (/score)")
print("=" * 60)
query = "Retrieve the city of Beijing"
labels = list(IMAGE_URLS.keys())
print(f"\n Loading {len(labels)} images...")
image_contents = [make_image_content(IMAGE_URLS[name]) for name in labels]
data = {
"model": MODEL,
"data_1": query,
"data_2": image_contents,
}
response = requests.post(f"{BASE_URL}/score", headers=headers, json=data)
if response.status_code == 200:
result = response.json()
print(f'\n Query: "{query}"\n')
for item in result["data"]:
idx = item["index"]
print(f" Doc {idx} [{labels[idx]}] score={item['score']:.4f}")
else:
print(f" Request failed: {response.status_code}")
print(f" {response.text[:300]}")
def rerank_text_vs_images():
"""Rerank image documents by a text query via /rerank."""
print()
print("=" * 60)
print("5. Multi-modal reranking: text query vs image docs (/rerank)")
print("=" * 60)
query = "Retrieve the city of London"
labels = list(IMAGE_URLS.keys())
print(f"\n Loading {len(labels)} images...")
image_contents = [make_image_content(IMAGE_URLS[name]) for name in labels]
data = {
"model": MODEL,
"query": query,
"documents": image_contents,
"top_n": 2,
}
response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=data)
if response.status_code == 200:
result = response.json()
print(f'\n Query: "{query}"')
print(f" Top {data['top_n']} results:\n")
for item in result["results"]:
idx = item["index"]
print(f" [{item['relevance_score']:.4f}] {labels[idx]}")
else:
print(f" Request failed: {response.status_code}")
print(f" {response.text[:300]}")
# ── Main ──────────────────────────────────────────────────
def main():
# Text-only
rerank_text()
score_text()
score_text_top_n()
# Multi-modal (text query × image documents)
score_text_vs_images()
rerank_text_vs_images()
if __name__ == "__main__":
main()
{%- set query_msg = (messages | selectattr('role', 'equalto', 'query') | list | first) -%}
{%- set doc_msg = (messages | selectattr('role', 'equalto', 'document') | list | first) -%}
{%- set q = query_msg['content'] -%}
{%- set d = doc_msg['content'] -%}
{# If the doc contains <image> anywhere, hoist a single <image> to the front #}
{%- set has_image = ("<image>" in d) -%}
{%- set d_clean = d | replace("<image>", "") -%}
{%- set q_clean = q | replace("<image>", "") -%}
{%- if has_image -%}<image>{{ " " }}{%- endif -%}
question:{{ q_clean }}{{ " " }}
{{ " " }}
{{ " " }}passage:{{ d_clean }}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment