Commit dcb5624a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-dev

parents 55880ca2 ba41cc90
...@@ -228,7 +228,7 @@ First, launch the OpenAI-compatible server: ...@@ -228,7 +228,7 @@ First, launch the OpenAI-compatible server:
```bash ```bash
vllm serve microsoft/Phi-3.5-vision-instruct --task generate \ vllm serve microsoft/Phi-3.5-vision-instruct --task generate \
--trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2 --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}'
``` ```
Then, you can use the OpenAI client as follows: Then, you can use the OpenAI client as follows:
......
...@@ -28,6 +28,8 @@ Please refer to the above pages for more details about each API. ...@@ -28,6 +28,8 @@ Please refer to the above pages for more details about each API.
[API Reference](/api/offline_inference/index) [API Reference](/api/offline_inference/index)
::: :::
(configuration-options)=
## Configuration Options ## Configuration Options
This section lists the most common options for running the vLLM engine. This section lists the most common options for running the vLLM engine.
...@@ -59,6 +61,8 @@ model = LLM( ...@@ -59,6 +61,8 @@ model = LLM(
Our [list of supported models](#supported-models) shows the model architectures that are recognized by vLLM. Our [list of supported models](#supported-models) shows the model architectures that are recognized by vLLM.
(reducing-memory-usage)=
### Reducing memory usage ### Reducing memory usage
Large models might cause your machine to run out of memory (OOM). Here are some options that help alleviate this problem. Large models might cause your machine to run out of memory (OOM). Here are some options that help alleviate this problem.
...@@ -81,6 +85,12 @@ before initializing vLLM. Otherwise, you may run into an error like `RuntimeErro ...@@ -81,6 +85,12 @@ before initializing vLLM. Otherwise, you may run into an error like `RuntimeErro
To control which devices are used, please instead set the `CUDA_VISIBLE_DEVICES` environment variable. To control which devices are used, please instead set the `CUDA_VISIBLE_DEVICES` environment variable.
::: :::
:::{note}
With tensor parallelism enabled, each process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism).
You can convert the model checkpoint to a sharded checkpoint using <gh-file:examples/offline_inference/save_sharded_state.py>. The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism.
:::
#### Quantization #### Quantization
Quantized models take less memory at the cost of lower precision. Quantized models take less memory at the cost of lower precision.
...@@ -103,6 +113,39 @@ llm = LLM(model="adept/fuyu-8b", ...@@ -103,6 +113,39 @@ llm = LLM(model="adept/fuyu-8b",
max_num_seqs=2) max_num_seqs=2)
``` ```
#### Reduce CUDA Graphs
By default, we optimize model inference using CUDA graphs which take up extra memory in the GPU.
:::{important}
CUDA graph capture takes up more memory in V1 than in V0.
:::
You can adjust `compilation_config` to achieve a better balance between inference speed and memory usage:
```python
from vllm import LLM
from vllm.config import CompilationConfig, CompilationLevel
llm = LLM(
model="meta-llama/Llama-3.1-8B-Instruct",
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
# By default, it goes up to max_num_seqs
cudagraph_capture_sizes=[1, 2, 4, 8, 16],
),
)
```
You can disable graph capturing completely via the `enforce_eager` flag:
```python
from vllm import LLM
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
enforce_eager=True)
```
#### Adjust cache size #### Adjust cache size
If you run out of CPU RAM, try the following options: If you run out of CPU RAM, try the following options:
...@@ -110,16 +153,25 @@ If you run out of CPU RAM, try the following options: ...@@ -110,16 +153,25 @@ If you run out of CPU RAM, try the following options:
- (Multi-modal models only) you can set the size of multi-modal input cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB). - (Multi-modal models only) you can set the size of multi-modal input cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB).
- (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB).
#### Disable unused modalities #### Multi-modal input limits
You can disable unused modalities (except for text) by setting its limit to zero. You can allow a smaller number of multi-modal items per prompt to reduce the memory footprint of the model:
```python
from vllm import LLM
# Accept up to 3 images and 1 video per prompt
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
limit_mm_per_prompt={"image": 3, "video": 1})
```
You can go a step further and disable unused modalities completely by setting its limit to zero.
For example, if your application only accepts image input, there is no need to allocate any memory for videos. For example, if your application only accepts image input, there is no need to allocate any memory for videos.
```python ```python
from vllm import LLM from vllm import LLM
# Accept images but not videos # Accept any number of images but no videos
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
limit_mm_per_prompt={"video": 0}) limit_mm_per_prompt={"video": 0})
``` ```
...@@ -134,6 +186,29 @@ llm = LLM(model="google/gemma-3-27b-it", ...@@ -134,6 +186,29 @@ llm = LLM(model="google/gemma-3-27b-it",
limit_mm_per_prompt={"image": 0}) limit_mm_per_prompt={"image": 0})
``` ```
#### Multi-modal processor arguments
For certain models, you can adjust the multi-modal processor arguments to
reduce the size of the processed multi-modal inputs, which in turn saves memory.
Here are some examples:
```python
from vllm import LLM
# Available for Qwen2-VL series models
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
mm_processor_kwargs={
"max_pixels": 768 * 768, # Default is 1280 * 28 * 28
})
# Available for InternVL series models
llm = LLM(model="OpenGVLab/InternVL2-2B",
mm_processor_kwargs={
"max_dynamic_patch": 4, # Default is 12
})
```
### Performance optimization and tuning ### Performance optimization and tuning
You can potentially improve the performance of vLLM by finetuning various options. You can potentially improve the performance of vLLM by finetuning various options.
......
...@@ -33,11 +33,13 @@ print(completion.choices[0].message) ...@@ -33,11 +33,13 @@ print(completion.choices[0].message)
vLLM supports some parameters that are not supported by OpenAI, `top_k` for example. vLLM supports some parameters that are not supported by OpenAI, `top_k` for example.
You can pass these parameters to vLLM using the OpenAI client in the `extra_body` parameter of your requests, i.e. `extra_body={"top_k": 50}` for `top_k`. You can pass these parameters to vLLM using the OpenAI client in the `extra_body` parameter of your requests, i.e. `extra_body={"top_k": 50}` for `top_k`.
::: :::
:::{important} :::{important}
By default, the server applies `generation_config.json` from the Hugging Face model repository if it exists. This means the default values of certain sampling parameters can be overridden by those recommended by the model creator. By default, the server applies `generation_config.json` from the Hugging Face model repository if it exists. This means the default values of certain sampling parameters can be overridden by those recommended by the model creator.
To disable this behavior, please pass `--generation-config vllm` when launching the server. To disable this behavior, please pass `--generation-config vllm` when launching the server.
::: :::
## Supported APIs ## Supported APIs
We currently support the following OpenAI APIs: We currently support the following OpenAI APIs:
...@@ -172,6 +174,12 @@ print(completion._request_id) ...@@ -172,6 +174,12 @@ print(completion._request_id)
The `vllm serve` command is used to launch the OpenAI-compatible server. The `vllm serve` command is used to launch the OpenAI-compatible server.
:::{tip}
The vast majority of command-line arguments are based on those for offline inference.
See [here](configuration-options) for some common options.
:::
:::{argparse} :::{argparse}
:module: vllm.entrypoints.openai.cli_args :module: vllm.entrypoints.openai.cli_args
:func: create_parser_for_docs :func: create_parser_for_docs
...@@ -394,9 +402,26 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai ...@@ -394,9 +402,26 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai
To use the Transcriptions API, please install with extra audio dependencies using `pip install vllm[audio]`. To use the Transcriptions API, please install with extra audio dependencies using `pip install vllm[audio]`.
::: :::
Code example: <gh-file:examples/online_serving/openai_transcription_client.py>
<!-- TODO: api enforced limits + uploading audios --> <!-- TODO: api enforced limits + uploading audios -->
Code example: <gh-file:examples/online_serving/openai_transcription_client.py> #### Extra Parameters
The following [sampling parameters](#sampling-params) are supported.
:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-transcription-sampling-params
:end-before: end-transcription-sampling-params
:::
The following extra parameters are supported:
:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-transcription-extra-params
:end-before: end-transcription-extra-params
:::
(tokenizer-api)= (tokenizer-api)=
......
# LMCache Examples
This folder demonstrates how to use LMCache for disaggregated prefilling, CPU offloading and KV cache sharing.
## 1. Disaggregated Prefill in vLLM v1
This example demonstrates how to run LMCache with disaggregated prefill using NIXL on a single node.
### Prerequisites
- Install [LMCache](https://github.com/LMCache/LMCache). You can simply run `pip install lmcache`.
- Install [NIXL](https://github.com/ai-dynamo/nixl).
- At least 2 GPUs
- Valid Hugging Face token (HF_TOKEN) for Llama 3.1 8B Instruct.
### Usage
Run
`cd disagg_prefill_lmcache_v1`
to get into `disagg_prefill_lmcache_v1` folder, and then run
```bash
bash disagg_example_nixl.sh
```
to run disaggregated prefill and benchmark the performance.
### Components
#### Server Scripts
- `disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh` - Launches individual vLLM servers for prefill/decode, and also launches the proxy server.
- `disagg_prefill_lmcache_v1/disagg_proxy_server.py` - FastAPI proxy server that coordinates between prefiller and decoder
- `disagg_prefill_lmcache_v1/disagg_example_nixl.sh` - Main script to run the example
#### Configuration
- `disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml` - Configuration for prefiller server
- `disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml` - Configuration for decoder server
#### Log Files
The main script generates several log files:
- `prefiller.log` - Logs from the prefill server
- `decoder.log` - Logs from the decode server
- `proxy.log` - Logs from the proxy server
## 2. CPU Offload Examples
- `cpu_offload_lmcache_v0.py` - CPU offloading implementation for vLLM v0
- `cpu_offload_lmcache_v1.py` - CPU offloading implementation for vLLM v1
## 3. KV Cache Sharing
The `kv_cache_sharing_lmcache_v1.py` example demonstrates how to share KV caches between vLLM v1 instances.
## 4. Disaggregated Prefill in vLLM v0
The `disaggregated_prefill_lmcache_v0.py` provides an example of how to run disaggregated prefill in vLLM v0.
# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates the example usage of cpu offloading
with LMCache.
Note that `lmcache` is needed to run this example.
Requirements: Linux, Python: 3.10 or higher, CUDA: 12.1
Learn more about LMCache environment setup, please refer to:
https://docs.lmcache.ai/getting_started/installation.html
"""
import contextlib
import os
import time
from lmcache.experimental.cache_engine import LMCacheEngineBuilder
from lmcache.integration.vllm.utils import ENGINE_NAME
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
def setup_environment_variables():
# LMCache-related environment variables
# Use experimental features in LMCache
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
# LMCache is set to use 256 tokens per chunk
os.environ["LMCACHE_CHUNK_SIZE"] = "256"
# Enable local CPU backend in LMCache
os.environ["LMCACHE_LOCAL_CPU"] = "True"
# Set local CPU memory limit to 5.0 GB
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"
@contextlib.contextmanager
def build_llm_with_lmcache():
ktc = KVTransferConfig.from_cli(
'{"kv_connector":"LMCacheConnector", "kv_role":"kv_both"}')
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# memory. Reduce the value if your GPU has less memory.
# Note: LMCache supports chunked prefill (see vLLM#14505, LMCache#392).
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
enable_chunked_prefill=True,
gpu_memory_utilization=0.8)
try:
yield llm
finally:
# Clean up lmcache backend
LMCacheEngineBuilder.destroy(ENGINE_NAME)
def print_output(
llm: LLM,
prompt: list[str],
sampling_params: SamplingParams,
req_str: str,
):
start = time.time()
outputs = llm.generate(prompt, sampling_params)
print("-" * 50)
for output in outputs:
generated_text = output.outputs[0].text
print(f"Generated text: {generated_text!r}")
print(f"Generation took {time.time() - start:.2f} seconds, "
f"{req_str} request done.")
print("-" * 50)
def main():
setup_environment_variables()
with build_llm_with_lmcache() as llm:
# This example script runs two requests with a shared prefix.
# Define the shared prompt and specific prompts
shared_prompt = "Hello, how are you?" * 1000
first_prompt = [
shared_prompt + "Hello, my name is",
]
second_prompt = [
shared_prompt + "Tell me a very long story",
]
sampling_params = SamplingParams(temperature=0,
top_p=0.95,
max_tokens=10)
# Print the first output
print_output(llm, first_prompt, sampling_params, "first")
time.sleep(1)
# print the second output
print_output(llm, second_prompt, sampling_params, "second")
if __name__ == "__main__":
main()
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
""" """
This file demonstrates the example usage of cpu offloading This file demonstrates the example usage of cpu offloading
with LMCache. with LMCache in vLLM v1.
Note that `pip install lmcache` is needed to run this example. Note that lmcache needs to be installed to run this example.
Learn more about LMCache in https://github.com/LMCache/LMCache. Learn more about LMCache in https://github.com/LMCache/LMCache.
""" """
import os import os
import time
from lmcache.experimental.cache_engine import LMCacheEngineBuilder from lmcache.experimental.cache_engine import LMCacheEngineBuilder
from lmcache.integration.vllm.utils import ENGINE_NAME from lmcache.integration.vllm.utils import ENGINE_NAME
...@@ -37,29 +36,22 @@ second_prompt = [ ...@@ -37,29 +36,22 @@ second_prompt = [
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
ktc = KVTransferConfig.from_cli( ktc = KVTransferConfig.from_cli(
'{"kv_connector":"LMCacheConnector", "kv_role":"kv_both"}') '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}')
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# memory. Reduce the value if your GPU has less memory. # memory. Reduce the value if your GPU has less memory.
# Note that LMCache is not compatible with chunked prefill for now. # Note that LMCache is not compatible with chunked prefill for now.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct",
kv_transfer_config=ktc, kv_transfer_config=ktc,
max_model_len=8000, max_model_len=8000,
enable_chunked_prefill=False,
gpu_memory_utilization=0.8) gpu_memory_utilization=0.8)
# Should be able to see logs like the following:
# `LMCache INFO: Storing KV cache for 6006 out of 6006 tokens for request 0`
# This indicates that the KV cache has been stored in LMCache.
outputs = llm.generate(first_prompt, sampling_params) outputs = llm.generate(first_prompt, sampling_params)
for output in outputs: for output in outputs:
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"Generated text: {generated_text!r}") print(f"Generated text: {generated_text!r}")
print("First request done.")
time.sleep(1)
outputs = llm.generate(second_prompt, sampling_params)
for output in outputs:
generated_text = output.outputs[0].text
print(f"Generated text: {generated_text!r}")
print("Second request done.")
# Clean up lmcache backend # Clean up lmcache backend
LMCacheEngineBuilder.destroy(ENGINE_NAME) LMCacheEngineBuilder.destroy(ENGINE_NAME)
...@@ -38,6 +38,10 @@ os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}" ...@@ -38,6 +38,10 @@ os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}"
# `naive` indicates using raw bytes of the tensor without any compression # `naive` indicates using raw bytes of the tensor without any compression
os.environ["LMCACHE_REMOTE_SERDE"] = "naive" os.environ["LMCACHE_REMOTE_SERDE"] = "naive"
prompts = [
"Hello, how are you?" * 1000,
]
def run_prefill(prefill_done, prompts): def run_prefill(prefill_done, prompts):
# We use GPU 0 for prefill node. # We use GPU 0 for prefill node.
...@@ -106,12 +110,7 @@ def run_lmcache_server(port): ...@@ -106,12 +110,7 @@ def run_lmcache_server(port):
return server_proc return server_proc
if __name__ == "__main__": def main():
prompts = [
"Hello, how are you?" * 1000,
]
prefill_done = Event() prefill_done = Event()
prefill_process = Process(target=run_prefill, args=(prefill_done, prompts)) prefill_process = Process(target=run_prefill, args=(prefill_done, prompts))
decode_process = Process(target=run_decode, args=(prefill_done, prompts)) decode_process = Process(target=run_decode, args=(prefill_done, prompts))
...@@ -128,3 +127,7 @@ if __name__ == "__main__": ...@@ -128,3 +127,7 @@ if __name__ == "__main__":
prefill_process.terminate() prefill_process.terminate()
lmcache_server_process.terminate() lmcache_server_process.terminate()
lmcache_server_process.wait() lmcache_server_process.wait()
if __name__ == "__main__":
main()
local_cpu: False
max_local_cpu_size: 0
#local_disk:
max_local_disk_size: 0
remote_serde: NULL
enable_nixl: True
nixl_role: "receiver"
nixl_peer_host: "localhost"
nixl_peer_port: 55555
nixl_buffer_size: 1073741824 # 1GB
nixl_buffer_device: "cuda"
nixl_enable_gc: True
local_cpu: False
max_local_cpu_size: 0
#local_disk:
max_local_disk_size: 0
remote_serde: NULL
enable_nixl: True
nixl_role: "sender"
nixl_peer_host: "localhost"
nixl_peer_port: 55555
nixl_buffer_size: 1073741824 # 1GB
nixl_buffer_device: "cuda"
nixl_enable_gc: True
#!/bin/bash
echo "Warning: LMCache disaggregated prefill support for vLLM v1 is experimental and subject to change."
PIDS=()
# Switch to the directory of the current script
cd "$(dirname "${BASH_SOURCE[0]}")"
check_hf_token() {
if [ -z "$HF_TOKEN" ]; then
echo "HF_TOKEN is not set. Please set it to your Hugging Face token."
exit 1
fi
if [[ "$HF_TOKEN" != hf_* ]]; then
echo "HF_TOKEN is not a valid Hugging Face token. Please set it to your Hugging Face token."
exit 1
fi
echo "HF_TOKEN is set and valid."
}
check_num_gpus() {
# can you check if the number of GPUs are >=2 via nvidia-smi?
num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
if [ "$num_gpus" -lt 2 ]; then
echo "You need at least 2 GPUs to run disaggregated prefill."
exit 1
else
echo "Found $num_gpus GPUs."
fi
}
ensure_python_library_installed() {
echo "Checking if $1 is installed..."
python -c "import $1" > /dev/null 2>&1
if [ $? -ne 0 ]; then
if [ "$1" == "nixl" ]; then
echo "$1 is not installed. Please refer to https://github.com/ai-dynamo/nixl for installation."
else
echo "$1 is not installed. Please install it via pip install $1."
fi
exit 1
else
echo "$1 is installed."
fi
}
cleanup() {
echo "Stopping everything…"
trap - INT TERM # prevent re-entrancy
kill -- -$$ # negative PID == “this whole process-group”
wait # reap children so we don't leave zombies
exit 0
}
wait_for_server() {
local port=$1
local timeout_seconds=1200
local start_time=$(date +%s)
echo "Waiting for server on port $port..."
while true; do
if curl -s "localhost:${port}/v1/completions" > /dev/null; then
return 0
fi
local now=$(date +%s)
if (( now - start_time >= timeout_seconds )); then
echo "Timeout waiting for server"
return 1
fi
sleep 1
done
}
main() {
check_hf_token
check_num_gpus
ensure_python_library_installed lmcache
ensure_python_library_installed nixl
ensure_python_library_installed pandas
ensure_python_library_installed datasets
ensure_python_library_installed vllm
trap cleanup INT
trap cleanup USR1
trap cleanup TERM
echo "Launching prefiller, decoder and proxy..."
echo "Please check prefiller.log, decoder.log and proxy.log for logs."
bash disagg_vllm_launcher.sh prefiller \
> >(tee prefiller.log) 2>&1 &
prefiller_pid=$!
PIDS+=($prefiller_pid)
bash disagg_vllm_launcher.sh decoder \
> >(tee decoder.log) 2>&1 &
decoder_pid=$!
PIDS+=($decoder_pid)
python3 disagg_proxy_server.py \
--host localhost \
--port 9000 \
--prefiller-host localhost \
--prefiller-port 8100 \
--decoder-host localhost \
--decoder-port 8200 \
> >(tee proxy.log) 2>&1 &
proxy_pid=$!
PIDS+=($proxy_pid)
wait_for_server 8100
wait_for_server 8200
wait_for_server 9000
echo "All servers are up. Starting benchmark..."
# begin benchmark
cd ../../../benchmarks/
python benchmark_serving.py --port 9000 --seed $(date +%s) \
--model meta-llama/Llama-3.1-8B-Instruct \
--dataset-name random --random-input-len 7500 --random-output-len 200 \
--num-prompts 200 --burstiness 100 --request-rate 3.6 | tee benchmark.log
echo "Benchmarking done. Cleaning up..."
cleanup
}
main
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
import time
from contextlib import asynccontextmanager
import httpx
import numpy as np
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
Lifespan context manager to handle startup and shutdown events.
"""
# Startup: Initialize clients
prefiller_base_url = f'http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1'
decoder_base_url = f'http://{global_args.decoder_host}:{global_args.decoder_port}/v1'
app.state.prefill_client = httpx.AsyncClient(timeout=None,
base_url=prefiller_base_url)
app.state.decode_client = httpx.AsyncClient(timeout=None,
base_url=decoder_base_url)
yield
# Shutdown: Close clients
await app.state.prefill_client.aclose()
await app.state.decode_client.aclose()
# Update FastAPI app initialization to use lifespan
app = FastAPI(lifespan=lifespan)
class StatsCalculator:
def __init__(self):
self._stats = []
self._last_log_time = time.time()
def add(self, value):
self._stats.append(value)
if time.time() - self._last_log_time > 5:
self._log_stats()
self._last_log_time = time.time()
def _log_stats(self):
# Print average, median, and 99th percentile
np_arr = np.array(self._stats)
output_str = f"\nNum requests: {len(self._stats)}" + \
"\nPrefill node TTFT stats:" + \
f"\n - Average (ms): {np.mean(np_arr)}" + \
f"\n - Median (ms): {np.median(np_arr)}" + \
f"\n - 99th Percentile (ms): {np.percentile(np_arr, 99)}\n"
print("===============================", output_str,
"===============================")
stats_calculator = StatsCalculator()
counter = 0
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--prefiller-host", type=str, default="localhost")
parser.add_argument("--prefiller-port", type=int, default=8100)
parser.add_argument("--decoder-host", type=str, default="localhost")
parser.add_argument("--decoder-port", type=int, default=8200)
args = parser.parse_args()
return args
# Initialize variables to hold the persistent clients
app.state.prefill_client = None
app.state.decode_client = None
async def send_request_to_service(client: httpx.AsyncClient, endpoint: str,
req_data: dict):
"""
Send a request to a service using a persistent client.
"""
req_data = req_data.copy()
req_data['max_tokens'] = 1
if 'max_completion_tokens' in req_data:
req_data['max_completion_tokens'] = 1
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
response = await client.post(endpoint, json=req_data, headers=headers)
response.raise_for_status()
return response
async def stream_service_response(client: httpx.AsyncClient, endpoint: str,
req_data: dict):
"""
Asynchronously stream the response from a service using a persistent client.
"""
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
async with client.stream("POST", endpoint, json=req_data,
headers=headers) as response:
response.raise_for_status()
async for chunk in response.aiter_bytes():
yield chunk
@app.post("/v1/completions")
async def handle_completions(request: Request):
global counter, stats_calculator
counter += 1
st = time.time()
try:
req_data = await request.json()
# Send request to prefill service, ignore the response
await send_request_to_service(app.state.prefill_client, "/completions",
req_data)
et = time.time()
stats_calculator.add(et - st)
# Stream response from decode service
async def generate_stream():
async for chunk in stream_service_response(app.state.decode_client,
"/completions",
req_data):
yield chunk
return StreamingResponse(generate_stream(),
media_type="application/json")
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server"
" - completions endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
global counter, stats_calculator
counter += 1
st = time.time()
try:
req_data = await request.json()
# Send request to prefill service, ignore the response
await send_request_to_service(app.state.prefill_client,
"/chat/completions", req_data)
et = time.time()
stats_calculator.add(et - st)
# Stream response from decode service
async def generate_stream():
async for chunk in stream_service_response(app.state.decode_client,
"/chat/completions",
req_data):
yield chunk
return StreamingResponse(generate_stream(),
media_type="application/json")
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server "
" - chat completions endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
if __name__ == '__main__':
global global_args
global_args = parse_args()
import uvicorn
uvicorn.run(app, host=global_args.host, port=global_args.port)
#!/bin/bash
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
if [[ $# -lt 1 ]]; then
echo "Usage: $0 <prefiller | decoder> [model]"
exit 1
fi
if [[ $# -eq 1 ]]; then
echo "Using default model: meta-llama/Llama-3.1-8B-Instruct"
MODEL="meta-llama/Llama-3.1-8B-Instruct"
else
echo "Using model: $2"
MODEL=$2
fi
if [[ $1 == "prefiller" ]]; then
# Prefiller listens on port 8100
prefill_config_file=$SCRIPT_DIR/configs/lmcache-prefiller-config.yaml
UCX_TLS=cuda_ipc,cuda_copy,tcp \
LMCACHE_CONFIG_FILE=$prefill_config_file \
LMCACHE_USE_EXPERIMENTAL=True \
VLLM_ENABLE_V1_MULTIPROCESSING=1 \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
CUDA_VISIBLE_DEVICES=0 \
vllm serve $MODEL \
--port 8100 \
--disable-log-requests \
--enforce-eager \
--kv-transfer-config \
'{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}'
elif [[ $1 == "decoder" ]]; then
# Decoder listens on port 8200
decode_config_file=$SCRIPT_DIR/configs/lmcache-decoder-config.yaml
UCX_TLS=cuda_ipc,cuda_copy,tcp \
LMCACHE_CONFIG_FILE=$decode_config_file \
LMCACHE_USE_EXPERIMENTAL=True \
VLLM_ENABLE_V1_MULTIPROCESSING=1 \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
CUDA_VISIBLE_DEVICES=1 \
vllm serve $MODEL \
--port 8200 \
--disable-log-requests \
--enforce-eager \
--kv-transfer-config \
'{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}'
else
echo "Invalid role: $1"
echo "Should be either prefill, decode"
exit 1
fi
# SPDX-License-Identifier: Apache-2.0
"""
This file demonstrates the example usage of remote KV cache sharing
with LMCache.
We will launch 2 vllm instances, and launch an additional LMCache server.
KV cache is transferred in the following manner:
(1) vLLM instance 1 -> LMCache server (KV cache store).
(2) LMCache server -> vLLM instance 2 (KV cache reuse/retrieve).
Note that lmcache needs to be installed to run this example.
Learn more about LMCache in https://github.com/LMCache/LMCache.
"""
import os
import subprocess
import time
from multiprocessing import Event, Process
from lmcache.experimental.cache_engine import LMCacheEngineBuilder
from lmcache.integration.vllm.utils import ENGINE_NAME
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
# LMCache-related environment variables
# The port to start LMCache server
port = 8100
# Use experimental features in LMCache
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
# LMCache is set to use 256 tokens per chunk
os.environ["LMCACHE_CHUNK_SIZE"] = "256"
# Disable local CPU backend in LMCache
os.environ["LMCACHE_LOCAL_CPU"] = "False"
# Set local CPU memory buffer limit to 5.0 GB
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"
# Set the remote URL for LMCache server
os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}"
# Set the serializer/deserializer between vllm and LMCache server
# `naive` indicates using raw bytes of the tensor without any compression
os.environ["LMCACHE_REMOTE_SERDE"] = "naive"
prompts = [
"Hello, how are you?" * 1000,
]
def run_store(store_done, prompts):
# We use GPU 0 for KV cache store process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
ktc = KVTransferConfig.from_cli(
'{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}')
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# memory. Reduce the value if your GPU has less memory.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
enforce_eager=True)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
generated_text = output.outputs[0].text
print(f"Generated text: {generated_text!r}")
print("KV cache store is finished.")
store_done.set()
# Clean up lmcache backend
LMCacheEngineBuilder.destroy(ENGINE_NAME)
def run_retrieve(store_done, prompts, timeout=1):
# We use GPU 1 for KV cache retrieve process.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
ktc = KVTransferConfig.from_cli(
'{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}')
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# of memory. Reduce the value if your GPU has less memory.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
kv_transfer_config=ktc,
max_model_len=8000,
gpu_memory_utilization=0.8,
enforce_eager=True)
print("Waiting for KV cache store to finish...")
store_done.wait()
time.sleep(timeout)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
generated_text = output.outputs[0].text
print(f"Generated text: {generated_text!r}")
# Clean up lmcache backend
LMCacheEngineBuilder.destroy(ENGINE_NAME)
def run_lmcache_server(port):
server_proc = subprocess.Popen([
"python", "-m", "lmcache.experimental.server", "localhost",
str(port)
])
return server_proc
def main():
store_done = Event()
store_process = Process(target=run_store, args=(store_done, prompts))
retrieve_process = Process(target=run_retrieve, args=(store_done, prompts))
lmcache_server_process = run_lmcache_server(port)
# Start KV cache store process
store_process.start()
# Start KV cache retrieve process
retrieve_process.start()
# Clean up the processes
store_process.join()
retrieve_process.terminate()
lmcache_server_process.terminate()
lmcache_server_process.wait()
if __name__ == "__main__":
main()
...@@ -38,6 +38,37 @@ class ModelRequestData(NamedTuple): ...@@ -38,6 +38,37 @@ class ModelRequestData(NamedTuple):
# Unless specified, these settings have been tested to work on a single L4. # Unless specified, these settings have been tested to work on a single L4.
# Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
# NOTE - the setting in this example are somehat different than what is
# optimal for granite speech, and it is generally recommended to use beam
# search. Check the model README for suggested settings.
# https://huggingface.co/ibm-granite/granite-speech-3.3-8b
model_name = "ibm-granite/granite-speech-3.3-8b"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=2048,
max_num_seqs=2,
enable_lora=True,
max_lora_rank=64,
limit_mm_per_prompt={"audio": audio_count},
)
# The model has an audio-specific lora directly in its model dir;
# it should be enabled whenever you pass audio inputs to the model.
speech_lora_path = model_name
audio_placeholder = "<|audio|>" * audio_count
prompts = f"<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>{audio_placeholder}{question}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>" # noqa: E501
return ModelRequestData(
engine_args=engine_args,
prompt=prompts,
lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
)
# MiniCPM-O # MiniCPM-O
def run_minicpmo(question: str, audio_count: int) -> ModelRequestData: def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
model_name = "openbmb/MiniCPM-o-2_6" model_name = "openbmb/MiniCPM-o-2_6"
...@@ -89,7 +120,7 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData: ...@@ -89,7 +120,7 @@ def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_path, model=model_path,
trust_remote_code=True, trust_remote_code=True,
max_model_len=4096, max_model_len=12800,
max_num_seqs=2, max_num_seqs=2,
enable_lora=True, enable_lora=True,
max_lora_rank=320, max_lora_rank=320,
...@@ -130,6 +161,36 @@ def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData: ...@@ -130,6 +161,36 @@ def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
) )
# Qwen2.5-Omni
def run_qwen2_5_omni(question: str, audio_count: int):
model_name = "Qwen/Qwen2.5-Omni-7B"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count},
)
audio_in_prompt = "".join([
"<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)
])
default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech.")
prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n"
f"{audio_in_prompt}{question}<|im_end|>\n"
"<|im_start|>assistant\n")
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# Ultravox 0.5-1B # Ultravox 0.5-1B
def run_ultravox(question: str, audio_count: int) -> ModelRequestData: def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b" model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
...@@ -179,14 +240,43 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData: ...@@ -179,14 +240,43 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
model_example_map = { model_example_map = {
"granite_speech": run_granite_speech,
"minicpmo": run_minicpmo, "minicpmo": run_minicpmo,
"phi4_mm": run_phi4mm, "phi4_mm": run_phi4mm,
"qwen2_audio": run_qwen2_audio, "qwen2_audio": run_qwen2_audio,
"qwen2_5_omni": run_qwen2_5_omni,
"ultravox": run_ultravox, "ultravox": run_ultravox,
"whisper": run_whisper, "whisper": run_whisper,
} }
def parse_args():
parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with '
'audio language models')
parser.add_argument('--model-type',
'-m',
type=str,
default="ultravox",
choices=model_example_map.keys(),
help='Huggingface "model_type".')
parser.add_argument('--num-prompts',
type=int,
default=1,
help='Number of prompts to run.')
parser.add_argument("--num-audios",
type=int,
default=1,
choices=[0, 1, 2],
help="Number of audio items per prompt.")
parser.add_argument("--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.")
return parser.parse_args()
def main(args): def main(args):
model = args.model_type model = args.model_type
if model not in model_example_map: if model not in model_example_map:
...@@ -240,28 +330,5 @@ def main(args): ...@@ -240,28 +330,5 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( args = parse_args()
description='Demo on using vLLM for offline inference with '
'audio language models')
parser.add_argument('--model-type',
'-m',
type=str,
default="ultravox",
choices=model_example_map.keys(),
help='Huggingface "model_type".')
parser.add_argument('--num-prompts',
type=int,
default=1,
help='Number of prompts to run.')
parser.add_argument("--num-audios",
type=int,
default=1,
choices=[0, 1, 2],
help="Number of audio items per prompt.")
parser.add_argument("--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.")
args = parser.parse_args()
main(args) main(args)
...@@ -2,20 +2,22 @@ ...@@ -2,20 +2,22 @@
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
if __name__ == '__main__': # Sample prompts.
# Sample prompts. prompts = [
prompts = [ "Hello, my name is",
"Hello, my name is", "The president of the United States is",
"The president of the United States is", "The capital of France is",
"The capital of France is", "The future of AI is",
"The future of AI is", ]
] # Create a sampling params object.
# Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16)
def main():
# Create an LLM. # Create an LLM.
llm = LLM(model="facebook/opt-125m",tensor_parallel_size=1, dtype="float16",trust_remote_code=True, enforce_eager=True) llm = LLM(model="facebook/opt-125m",tensor_parallel_size=1, dtype="float16",trust_remote_code=True, enforce_eager=True)
# Generate texts from the prompts. The output is a list of RequestOutput objects # Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information. # that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
# Print the outputs. # Print the outputs.
...@@ -26,3 +28,8 @@ if __name__ == '__main__': ...@@ -26,3 +28,8 @@ if __name__ == '__main__':
print(f"Prompt: {prompt!r}") print(f"Prompt: {prompt!r}")
print(f"Output: {generated_text!r}") print(f"Output: {generated_text!r}")
print("-" * 60) print("-" * 60)
if __name__ == "__main__":
main()
...@@ -4,6 +4,24 @@ from vllm import LLM, EngineArgs ...@@ -4,6 +4,24 @@ from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
def create_parser():
parser = FlexibleArgumentParser()
# Add engine args
engine_group = parser.add_argument_group("Engine arguments")
EngineArgs.add_cli_args(engine_group)
engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
# Add sampling params
sampling_group = parser.add_argument_group("Sampling parameters")
sampling_group.add_argument("--max-tokens", type=int)
sampling_group.add_argument("--temperature", type=float)
sampling_group.add_argument("--top-p", type=float)
sampling_group.add_argument("--top-k", type=int)
# Add example params
parser.add_argument("--chat-template-path", type=str)
return parser
def main(args: dict): def main(args: dict):
# Pop arguments not used by LLM # Pop arguments not used by LLM
max_tokens = args.pop("max_tokens") max_tokens = args.pop("max_tokens")
...@@ -82,18 +100,6 @@ def main(args: dict): ...@@ -82,18 +100,6 @@ def main(args: dict):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser() parser = create_parser()
# Add engine args
engine_group = parser.add_argument_group("Engine arguments")
EngineArgs.add_cli_args(engine_group)
engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
# Add sampling params
sampling_group = parser.add_argument_group("Sampling parameters")
sampling_group.add_argument("--max-tokens", type=int)
sampling_group.add_argument("--temperature", type=float)
sampling_group.add_argument("--top-p", type=float)
sampling_group.add_argument("--top-k", type=int)
# Add example params
parser.add_argument("--chat-template-path", type=str)
args: dict = vars(parser.parse_args()) args: dict = vars(parser.parse_args())
main(args) main(args)
...@@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs ...@@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(model="jason9693/Qwen2.5-1.5B-apeach",
task="classify",
enforce_eager=True)
return parser.parse_args()
def main(args: Namespace): def main(args: Namespace):
# Sample prompts. # Sample prompts.
prompts = [ prompts = [
...@@ -34,11 +44,5 @@ def main(args: Namespace): ...@@ -34,11 +44,5 @@ def main(args: Namespace):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser() args = parse_args()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(model="jason9693/Qwen2.5-1.5B-apeach",
task="classify",
enforce_eager=True)
args = parser.parse_args()
main(args) main(args)
...@@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs ...@@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(model="intfloat/e5-mistral-7b-instruct",
task="embed",
enforce_eager=True)
return parser.parse_args()
def main(args: Namespace): def main(args: Namespace):
# Sample prompts. # Sample prompts.
prompts = [ prompts = [
...@@ -34,11 +44,5 @@ def main(args: Namespace): ...@@ -34,11 +44,5 @@ def main(args: Namespace):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser() args = parse_args()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(model="intfloat/e5-mistral-7b-instruct",
task="embed",
enforce_eager=True)
args = parser.parse_args()
main(args) main(args)
...@@ -4,6 +4,22 @@ from vllm import LLM, EngineArgs ...@@ -4,6 +4,22 @@ from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
def create_parser():
parser = FlexibleArgumentParser()
# Add engine args
engine_group = parser.add_argument_group("Engine arguments")
EngineArgs.add_cli_args(engine_group)
engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
# Add sampling params
sampling_group = parser.add_argument_group("Sampling parameters")
sampling_group.add_argument("--max-tokens", type=int)
sampling_group.add_argument("--temperature", type=float)
sampling_group.add_argument("--top-p", type=float)
sampling_group.add_argument("--top-k", type=int)
return parser
def main(args: dict): def main(args: dict):
# Pop arguments not used by LLM # Pop arguments not used by LLM
max_tokens = args.pop("max_tokens") max_tokens = args.pop("max_tokens")
...@@ -35,23 +51,15 @@ def main(args: dict): ...@@ -35,23 +51,15 @@ def main(args: dict):
] ]
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
# Print the outputs. # Print the outputs.
print("-" * 50)
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser() parser = create_parser()
# Add engine args
engine_group = parser.add_argument_group("Engine arguments")
EngineArgs.add_cli_args(engine_group)
engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
# Add sampling params
sampling_group = parser.add_argument_group("Sampling parameters")
sampling_group.add_argument("--max-tokens", type=int)
sampling_group.add_argument("--temperature", type=float)
sampling_group.add_argument("--top-p", type=float)
sampling_group.add_argument("--top-k", type=int)
args: dict = vars(parser.parse_args()) args: dict = vars(parser.parse_args())
main(args) main(args)
...@@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs ...@@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
def parse_args():
parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(model="BAAI/bge-reranker-v2-m3",
task="score",
enforce_eager=True)
return parser.parse_args()
def main(args: Namespace): def main(args: Namespace):
# Sample prompts. # Sample prompts.
text_1 = "What is the capital of France?" text_1 = "What is the capital of France?"
...@@ -30,11 +40,5 @@ def main(args: Namespace): ...@@ -30,11 +40,5 @@ def main(args: Namespace):
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser() args = parse_args()
parser = EngineArgs.add_cli_args(parser)
# Set example specific arguments
parser.set_defaults(model="BAAI/bge-reranker-v2-m3",
task="score",
enforce_eager=True)
args = parser.parse_args()
main(args) main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment