Commit 500b93c8 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.3.post1' into v0.5.3.post1-dtk24.04.1

parents 99426767 38c4b7e8
"""Example Python client for vllm.entrypoints.api_server """Example Python client for `vllm.entrypoints.api_server`
NOTE: The API server is used only for demonstration and simple performance NOTE: The API server is used only for demonstration and simple performance
benchmarks. It is not intended for production use. benchmarks. It is not intended for production use.
For production use, we recommend vllm.entrypoints.openai.api_server For production use, we recommend `vllm serve` and the OpenAI client API.
and the OpenAI client API
""" """
import argparse import argparse
......
from vllm import LLM, SamplingParams
# Prompts to complete.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Sampling settings shared by every prompt.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Build the engine.
# NOTE(review): `cpu_offload_gb=10` presumably reserves 10 GiB of CPU memory
# for offloaded weights -- confirm against the vLLM engine arguments docs.
llm = LLM(model="meta-llama/Llama-2-13b-chat-hf", cpu_offload_gb=10)

# Run generation. The result is a list of RequestOutput objects that contain
# the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)

# Show each prompt next to its first completion.
for request_output in outputs:
    completion = request_output.outputs[0].text
    print(f"Prompt: {request_output.prompt!r}, Generated text: {completion!r}")
import os
import subprocess
from PIL import Image
from vllm import LLM from vllm import LLM
from vllm.assets.image import ImageAsset
# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
# You can use `.buildkite/download-images.sh` to download them
def run_llava(): def run_llava():
...@@ -14,7 +7,7 @@ def run_llava(): ...@@ -14,7 +7,7 @@ def run_llava():
prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:" prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
image = Image.open("images/stop_sign.jpg") image = ImageAsset("stop_sign").pil_image
outputs = llm.generate({ outputs = llm.generate({
"prompt": prompt, "prompt": prompt,
...@@ -28,25 +21,5 @@ def run_llava(): ...@@ -28,25 +21,5 @@ def run_llava():
print(generated_text) print(generated_text)
def main():
run_llava()
if __name__ == "__main__": if __name__ == "__main__":
# Download from s3 run_llava()
s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
local_directory = "images"
# Make sure the local directory exists or create it
os.makedirs(local_directory, exist_ok=True)
# Use AWS CLI to sync the directory, assume anonymous access
subprocess.check_call([
"aws",
"s3",
"sync",
s3_bucket_path,
local_directory,
"--no-sign-request",
])
main()
...@@ -95,9 +95,7 @@ to the path of the custom logging configuration JSON file: ...@@ -95,9 +95,7 @@ to the path of the custom logging configuration JSON file:
```bash ```bash
VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \ VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \
python3 -m vllm.entrypoints.openai.api_server \ vllm serve mistralai/Mistral-7B-v0.1 --max-model-len 2048
--max-model-len 2048 \
--model mistralai/Mistral-7B-v0.1
``` ```
...@@ -152,9 +150,7 @@ to the path of the custom logging configuration JSON file: ...@@ -152,9 +150,7 @@ to the path of the custom logging configuration JSON file:
```bash ```bash
VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \ VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \
python3 -m vllm.entrypoints.openai.api_server \ vllm serve mistralai/Mistral-7B-v0.1 --max-model-len 2048
--max-model-len 2048 \
--model mistralai/Mistral-7B-v0.1
``` ```
...@@ -167,9 +163,7 @@ loggers. ...@@ -167,9 +163,7 @@ loggers.
```bash ```bash
VLLM_CONFIGURE_LOGGING=0 \ VLLM_CONFIGURE_LOGGING=0 \
python3 -m vllm.entrypoints.openai.api_server \ vllm serve mistralai/Mistral-7B-v0.1 --max-model-len 2048
--max-model-len 2048 \
--model mistralai/Mistral-7B-v0.1
``` ```
......
from vllm import LLM, SamplingParams
# Each prompt is paired (by position) with the text its completion must
# start with; the pairing is checked at the bottom of the script.
prompts = [
    "A robot may not injure a human being",
    "It is only with the heart that one can see rightly;",
    "The greatest glory in living lies not in never falling,",
]
answers = [
    " or, through inaction, allow a human being to come to harm.",
    " what is essential is invisible to the eye.",
    " but in rising every time we fall.",
]
N = 1

# Currently, top-p sampling is disabled. `top_p` should be 1.0.
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=1.0,
    n=N,
    max_tokens=16,
)

# Set `enforce_eager=True` to avoid ahead-of-time compilation.
# In real workloads, `enforce_eager` should be `False`.
llm = LLM(model="google/gemma-2b", enforce_eager=True)
outputs = llm.generate(prompts, sampling_params)

# Print every completion and verify it begins with the expected continuation.
for request_output, expected in zip(outputs, answers):
    completion = request_output.outputs[0].text
    print(f"Prompt: {request_output.prompt!r}, Generated text: {completion!r}")
    assert completion.startswith(expected)
"""An example showing how to use vLLM to serve VLMs. """An example showing how to use vLLM to serve VLMs.
Launch the vLLM server with the following command: Launch the vLLM server with the following command:
python -m vllm.entrypoints.openai.api_server \ vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja
--model llava-hf/llava-1.5-7b-hf \
--chat-template template_llava.jinja
""" """
import base64 import base64
......
import os
import subprocess
from PIL import Image
from vllm import LLM from vllm import LLM
from vllm.assets.image import ImageAsset
# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
# You can use `.buildkite/download-images.sh` to download them
def run_paligemma(): def run_paligemma():
...@@ -14,7 +7,7 @@ def run_paligemma(): ...@@ -14,7 +7,7 @@ def run_paligemma():
prompt = "caption es" prompt = "caption es"
image = Image.open("images/stop_sign.jpg") image = ImageAsset("stop_sign").pil_image
outputs = llm.generate({ outputs = llm.generate({
"prompt": prompt, "prompt": prompt,
...@@ -28,25 +21,5 @@ def run_paligemma(): ...@@ -28,25 +21,5 @@ def run_paligemma():
print(generated_text) print(generated_text)
def main():
run_paligemma()
if __name__ == "__main__": if __name__ == "__main__":
# Download from s3 run_paligemma()
s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
local_directory = "images"
# Make sure the local directory exists or create it
os.makedirs(local_directory, exist_ok=True)
# Use AWS CLI to sync the directory, assume anonymous access
subprocess.check_call([
"aws",
"s3",
"sync",
s3_bucket_path,
local_directory,
"--no-sign-request",
])
main()
import os
import subprocess
from PIL import Image
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
# You can use `.buildkite/download-images.sh` to download them
def run_phi3v(): def run_phi3v():
...@@ -24,7 +17,7 @@ def run_phi3v(): ...@@ -24,7 +17,7 @@ def run_phi3v():
max_num_seqs=5, max_num_seqs=5,
) )
image = Image.open("images/cherry_blossom.jpg") image = ImageAsset("cherry_blossom").pil_image
# single-image prompt # single-image prompt
prompt = "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n" # noqa: E501 prompt = "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n" # noqa: E501
...@@ -44,19 +37,4 @@ def run_phi3v(): ...@@ -44,19 +37,4 @@ def run_phi3v():
if __name__ == "__main__": if __name__ == "__main__":
s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
local_directory = "images"
# Make sure the local directory exists or create it
os.makedirs(local_directory, exist_ok=True)
# Use AWS CLI to sync the directory, assume anonymous access
subprocess.check_call([
"aws",
"s3",
"sync",
s3_bucket_path,
local_directory,
"--no-sign-request",
])
run_phi3v() run_phi3v()
...@@ -36,7 +36,7 @@ ...@@ -36,7 +36,7 @@
``` ```
export OTEL_SERVICE_NAME="vllm-server" export OTEL_SERVICE_NAME="vllm-server"
export OTEL_EXPORTER_OTLP_TRACES_INSECURE=true export OTEL_EXPORTER_OTLP_TRACES_INSECURE=true
python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m" --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT" vllm serve facebook/opt-125m --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"
``` ```
1. In a new shell, send requests with trace context from a dummy client 1. In a new shell, send requests with trace context from a dummy client
...@@ -62,7 +62,7 @@ By default, `grpc` is used. To set `http/protobuf` as the protocol, configure th ...@@ -62,7 +62,7 @@ By default, `grpc` is used. To set `http/protobuf` as the protocol, configure th
``` ```
export OTEL_EXPORTER_OTLP_TRACES_PROTOCOL=http/protobuf export OTEL_EXPORTER_OTLP_TRACES_PROTOCOL=http/protobuf
export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=http://$JAEGER_IP:4318/v1/traces export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=http://$JAEGER_IP:4318/v1/traces
python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m" --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT" vllm serve facebook/opt-125m --otlp-traces-endpoint="$OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"
``` ```
## Instrumentation of FastAPI ## Instrumentation of FastAPI
...@@ -74,7 +74,7 @@ OpenTelemetry allows automatic instrumentation of FastAPI. ...@@ -74,7 +74,7 @@ OpenTelemetry allows automatic instrumentation of FastAPI.
1. Run vLLM with `opentelemetry-instrument` 1. Run vLLM with `opentelemetry-instrument`
``` ```
opentelemetry-instrument python -m vllm.entrypoints.openai.api_server --model="facebook/opt-125m" opentelemetry-instrument vllm serve facebook/opt-125m
``` ```
1. Send a request to vLLM and find its trace in Jaeger. It should contain spans from FastAPI. 1. Send a request to vLLM and find its trace in Jaeger. It should contain spans from FastAPI.
......
...@@ -10,8 +10,7 @@ Install: ...@@ -10,8 +10,7 @@ Install:
Prometheus metric logging is enabled by default in the OpenAI-compatible server. Launch via the entrypoint: Prometheus metric logging is enabled by default in the OpenAI-compatible server. Launch via the entrypoint:
```bash ```bash
python3 -m vllm.entrypoints.openai.api_server \ vllm serve mistralai/Mistral-7B-v0.1 \
--model mistralai/Mistral-7B-v0.1 \
--max-model-len 2048 \ --max-model-len 2048 \
--disable-log-requests --disable-log-requests
``` ```
......
#!/bin/bash
#
# Start a Ray head or worker node inside a Docker container named "node".
#
# Usage:
#   run_cluster.sh docker_image head_node_address --head|--worker \
#       path_to_hf_home [additional_args...]
#
# Any additional arguments are forwarded verbatim to `docker run`.
# The container is stopped and removed when this script exits.

# Check for minimum number of required arguments
if [ $# -lt 4 ]; then
    echo "Usage: $0 docker_image head_node_address --head|--worker path_to_hf_home [additional_args...]"
    exit 1
fi

# Assign the first four arguments and shift them away
DOCKER_IMAGE="$1"
HEAD_NODE_ADDRESS="$2"
NODE_TYPE="$3"  # Should be --head or --worker
PATH_TO_HF_HOME="$4"
shift 4

# Additional arguments are passed directly to the Docker command.
# They are kept in an array so that each argument stays a single word:
# flattening "$@" into a string and expanding it unquoted would re-split
# arguments containing spaces and expand any glob characters.
ADDITIONAL_ARGS=("$@")

# Validate node type
if [ "${NODE_TYPE}" != "--head" ] && [ "${NODE_TYPE}" != "--worker" ]; then
    echo "Error: Node type must be --head or --worker"
    exit 1
fi

# Define a function to cleanup on EXIT signal
cleanup() {
    docker stop node
    docker rm node
}
trap cleanup EXIT

# Command setup for head or worker node
RAY_START_CMD="ray start --block"
if [ "${NODE_TYPE}" == "--head" ]; then
    RAY_START_CMD+=" --head --port=6379"
else
    RAY_START_CMD+=" --address=${HEAD_NODE_ADDRESS}:6379"
fi

# Run the docker command with the user specified parameters and additional arguments
docker run \
    --entrypoint /bin/bash \
    --network host \
    --name node \
    --shm-size 10.24g \
    --gpus all \
    -v "${PATH_TO_HF_HOME}:/root/.cache/huggingface" \
    "${ADDITIONAL_ARGS[@]}" \
    "${DOCKER_IMAGE}" -c "${RAY_START_CMD}"
...@@ -2,5 +2,9 @@ ...@@ -2,5 +2,9 @@
-r requirements-common.txt -r requirements-common.txt
# Dependencies for AMD GPUs # Dependencies for AMD GPUs
awscli
boto3
botocore
ray >= 2.10.0 ray >= 2.10.0
peft
pytest-asyncio pytest-asyncio
--- amd_hip_bf16.h 2024-02-06 18:28:58.268699142 +0000
+++ amd_hip_bf16.h.new 2024-02-06 18:28:31.988647133 +0000
@@ -90,10 +90,10 @@
#include "math_fwd.h" // ocml device functions
#if defined(__HIPCC_RTC__)
-#define __HOST_DEVICE__ __device__
+#define __HOST_DEVICE__ __device__ static
#else
#include <climits>
-#define __HOST_DEVICE__ __host__ __device__
+#define __HOST_DEVICE__ __host__ __device__ static inline
#endif
// Since we are using unsigned short to represent data in bfloat16, it can be of different sizes on
import os import os
import pathlib import pathlib
from dataclasses import dataclass
import pytest import pytest
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath( chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
...@@ -50,24 +49,9 @@ TEST_MESSAGES = [ ...@@ -50,24 +49,9 @@ TEST_MESSAGES = [
] ]
@dataclass
class MockTokenizer:
chat_template = None
@dataclass
class MockServingChat:
tokenizer: MockTokenizer
def test_load_chat_template(): def test_load_chat_template():
# Testing chatml template # Testing chatml template
tokenizer = MockTokenizer() template_content = load_chat_template(chat_template=chatml_jinja_path)
mock_serving_chat = MockServingChat(tokenizer)
OpenAIServingChat._load_chat_template(mock_serving_chat,
chat_template=chatml_jinja_path)
template_content = tokenizer.chat_template
# Test assertions # Test assertions
assert template_content is not None assert template_content is not None
...@@ -79,24 +63,16 @@ def test_load_chat_template(): ...@@ -79,24 +63,16 @@ def test_load_chat_template():
def test_no_load_chat_template_filelike(): def test_no_load_chat_template_filelike():
# Testing chatml template # Testing chatml template
template = "../../examples/does_not_exist" template = "../../examples/does_not_exist"
tokenizer = MockTokenizer()
mock_serving_chat = MockServingChat(tokenizer)
with pytest.raises(ValueError, match="looks like a file path"): with pytest.raises(ValueError, match="looks like a file path"):
OpenAIServingChat._load_chat_template(mock_serving_chat, load_chat_template(chat_template=template)
chat_template=template)
def test_no_load_chat_template_literallike(): def test_no_load_chat_template_literallike():
# Testing chatml template # Testing chatml template
template = "{{ messages }}" template = "{{ messages }}"
tokenizer = MockTokenizer()
mock_serving_chat = MockServingChat(tokenizer) template_content = load_chat_template(chat_template=template)
OpenAIServingChat._load_chat_template(mock_serving_chat,
chat_template=template)
template_content = tokenizer.chat_template
assert template_content == template assert template_content == template
...@@ -108,9 +84,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt, ...@@ -108,9 +84,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
expected_output): expected_output):
# Initialize the tokenizer # Initialize the tokenizer
tokenizer = get_tokenizer(tokenizer_name=model) tokenizer = get_tokenizer(tokenizer_name=model)
mock_serving_chat = MockServingChat(tokenizer) template_content = load_chat_template(chat_template=template)
OpenAIServingChat._load_chat_template(mock_serving_chat,
chat_template=template)
# Create a mock request object using keyword arguments # Create a mock request object using keyword arguments
mock_request = ChatCompletionRequest( mock_request = ChatCompletionRequest(
...@@ -122,7 +96,8 @@ def test_get_gen_prompt(model, template, add_generation_prompt, ...@@ -122,7 +96,8 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
result = tokenizer.apply_chat_template( result = tokenizer.apply_chat_template(
conversation=mock_request.messages, conversation=mock_request.messages,
tokenize=False, tokenize=False,
add_generation_prompt=mock_request.add_generation_prompt) add_generation_prompt=mock_request.add_generation_prompt,
chat_template=mock_request.chat_template or template_content)
# Test assertion # Test assertion
assert result == expected_output, ( assert result == expected_output, (
......
...@@ -9,9 +9,7 @@ MODEL_NAME = "facebook/opt-125m" ...@@ -9,9 +9,7 @@ MODEL_NAME = "facebook/opt-125m"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): def server():
with RemoteOpenAIServer([ args = [
"--model",
MODEL_NAME,
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
"float16", "float16",
...@@ -19,7 +17,9 @@ def server(): ...@@ -19,7 +17,9 @@ def server():
"2048", "2048",
"--enforce-eager", "--enforce-eager",
"--engine-use-ray" "--engine-use-ray"
]) as remote_server: ]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server yield remote_server
......
from vllm.utils import is_hip
from ..utils import compare_two_settings
def test_cpu_offload():
    """Compare runs with and without the ``--cpu-offload-gb`` flag.

    The unquantized Llama-2-7b model is always compared. The quantized
    checkpoint is compared only when not running on ROCm.
    """
    no_extra_args = []
    compare_two_settings("meta-llama/Llama-2-7b-hf", no_extra_args,
                         ["--cpu-offload-gb", "4"])

    if is_hip():
        # compressed-tensors quantization is currently not supported in ROCm.
        return

    compare_two_settings("nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t",
                         [], ["--cpu-offload-gb", "1"])
...@@ -3,11 +3,7 @@ import gc ...@@ -3,11 +3,7 @@ import gc
import os import os
import sys import sys
from collections import UserList from collections import UserList
from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar
from functools import cached_property
from pathlib import Path
from typing import (Any, Dict, List, Literal, Optional, Tuple, TypedDict,
TypeVar)
import pytest import pytest
import torch import torch
...@@ -18,14 +14,16 @@ from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq, ...@@ -18,14 +14,16 @@ from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
AutoTokenizer, BatchEncoding) AutoTokenizer, BatchEncoding)
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import TokenizerPoolConfig from vllm.config import TokenizerPoolConfig
from vllm.connections import global_http_connection
from vllm.distributed import (destroy_distributed_environment, from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel) destroy_model_parallel)
from vllm.inputs import TextPrompt from vllm.inputs import TextPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal.utils import fetch_image
from vllm.sequence import SampleLogprobs from vllm.sequence import SampleLogprobs
from vllm.utils import cuda_device_count_stateless, is_cpu from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
is_cpu)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -33,9 +31,6 @@ _TEST_DIR = os.path.dirname(__file__) ...@@ -33,9 +31,6 @@ _TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")] _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")] _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
_IMAGE_DIR = Path(_TEST_DIR) / "images"
"""You can use `.buildkite/download-images.sh` to download the assets."""
def _read_prompts(filename: str) -> List[str]: def _read_prompts(filename: str) -> List[str]:
with open(filename, "r") as f: with open(filename, "r") as f:
...@@ -43,24 +38,9 @@ def _read_prompts(filename: str) -> List[str]: ...@@ -43,24 +38,9 @@ def _read_prompts(filename: str) -> List[str]:
return prompts return prompts
@dataclass(frozen=True)
class ImageAsset:
name: Literal["stop_sign", "cherry_blossom", "boardwalk"]
@cached_property
def pil_image(self) -> Image.Image:
if self.name == "boardwalk":
return fetch_image(
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
)
return Image.open(_IMAGE_DIR / f"{self.name}.jpg")
class _ImageAssetPrompts(TypedDict): class _ImageAssetPrompts(TypedDict):
stop_sign: str stop_sign: str
cherry_blossom: str cherry_blossom: str
boardwalk: str
if sys.version_info < (3, 9): if sys.version_info < (3, 9):
...@@ -79,7 +59,6 @@ class _ImageAssets(_ImageAssetsBase): ...@@ -79,7 +59,6 @@ class _ImageAssets(_ImageAssetsBase):
super().__init__([ super().__init__([
ImageAsset("stop_sign"), ImageAsset("stop_sign"),
ImageAsset("cherry_blossom"), ImageAsset("cherry_blossom"),
ImageAsset("boardwalk")
]) ])
def prompts(self, prompts: _ImageAssetPrompts) -> List[str]: def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
...@@ -89,16 +68,20 @@ class _ImageAssets(_ImageAssetsBase): ...@@ -89,16 +68,20 @@ class _ImageAssets(_ImageAssetsBase):
The order of the returned prompts matches the order of the The order of the returned prompts matches the order of the
assets when iterating through this object. assets when iterating through this object.
""" """
return [ return [prompts["stop_sign"], prompts["cherry_blossom"]]
prompts["stop_sign"], prompts["cherry_blossom"],
prompts["boardwalk"]
]
IMAGE_ASSETS = _ImageAssets() IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`.""" """Singleton instance of :class:`_ImageAssets`."""
@pytest.fixture(autouse=True)
def init_test_http_connection():
# pytest_asyncio may use a different event loop per test
# so we need to make sure the async client is created anew
global_http_connection.reuse_client = False
def cleanup(): def cleanup():
destroy_model_parallel() destroy_model_parallel()
destroy_distributed_environment() destroy_distributed_environment()
...@@ -150,12 +133,6 @@ def image_assets() -> _ImageAssets: ...@@ -150,12 +133,6 @@ def image_assets() -> _ImageAssets:
return IMAGE_ASSETS return IMAGE_ASSETS
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,
}
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding) _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
...@@ -177,8 +154,7 @@ class HfRunner: ...@@ -177,8 +154,7 @@ class HfRunner:
is_vision_model: bool = False, is_vision_model: bool = False,
is_sparseml_model: bool = False, is_sparseml_model: bool = False,
) -> None: ) -> None:
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
self.model_name = model_name self.model_name = model_name
...@@ -590,6 +566,10 @@ def get_tokenizer_pool_config(tokenizer_group_type): ...@@ -590,6 +566,10 @@ def get_tokenizer_pool_config(tokenizer_group_type):
return TokenizerPoolConfig(pool_size=1, return TokenizerPoolConfig(pool_size=1,
pool_type="ray", pool_type="ray",
extra_config={}) extra_config={})
if isinstance(tokenizer_group_type, type):
return TokenizerPoolConfig(pool_size=1,
pool_type=tokenizer_group_type,
extra_config={})
raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}") raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")
......
...@@ -249,10 +249,13 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append, ...@@ -249,10 +249,13 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append,
# Expect consumed blocks to be new blocks required to support the new slots. # Expect consumed blocks to be new blocks required to support the new slots.
expected_consumed_blocks = len( expected_consumed_blocks = len(
list(
chunk_list( chunk_list(
list( list(
range(prompt_len + num_slots_to_append + num_lookahead_slots)), range(prompt_len + num_slots_to_append +
block_size)) - len(chunk_list(list(range(prompt_len)), block_size)) num_lookahead_slots)),
block_size))) - len(
list(chunk_list(list(range(prompt_len)), block_size)))
assert num_consumed_blocks == expected_consumed_blocks assert num_consumed_blocks == expected_consumed_blocks
......
...@@ -58,10 +58,10 @@ def test_allocate_immutable_block(num_cpu_blocks: int, num_gpu_blocks: int, ...@@ -58,10 +58,10 @@ def test_allocate_immutable_block(num_cpu_blocks: int, num_gpu_blocks: int,
unique_token_ids = list( unique_token_ids = list(
range((num_cpu_blocks + num_gpu_blocks) * block_size)) range((num_cpu_blocks + num_gpu_blocks) * block_size))
gpu_token_ids = chunk_list(unique_token_ids[:num_gpu_blocks * block_size], gpu_token_ids = list(
block_size) chunk_list(unique_token_ids[:num_gpu_blocks * block_size], block_size))
cpu_token_ids = chunk_list(unique_token_ids[num_gpu_blocks * block_size:], cpu_token_ids = list(
block_size) chunk_list(unique_token_ids[num_gpu_blocks * block_size:], block_size))
assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
......
...@@ -462,7 +462,7 @@ def test_prefill_schedule_max_lora(): ...@@ -462,7 +462,7 @@ def test_prefill_schedule_max_lora():
lora_request=LoRARequest( lora_request=LoRARequest(
lora_name=str(i), lora_name=str(i),
lora_int_id=i + 1, lora_int_id=i + 1,
lora_local_path="abc")) lora_path="abc"))
waiting.append(seq_group) waiting.append(seq_group)
# Add two more requests to verify lora is prioritized. # Add two more requests to verify lora is prioritized.
# 0: Lora, 1: Lora, 2: regular, 3: regular # 0: Lora, 1: Lora, 2: regular, 3: regular
...@@ -760,7 +760,7 @@ def test_schedule_swapped_max_loras(): ...@@ -760,7 +760,7 @@ def test_schedule_swapped_max_loras():
lora_request=LoRARequest( lora_request=LoRARequest(
lora_name=str(i), lora_name=str(i),
lora_int_id=i + 1, lora_int_id=i + 1,
lora_local_path="abc")) lora_path="abc"))
scheduler._allocate_and_set_running(seq_group) scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1) append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out) scheduler._swap_out(seq_group, blocks_to_swap_out)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment