Unverified Commit c2b4a1bc authored by Michael Feil's avatar Michael Feil Committed by GitHub
Browse files

[Doc] Add typing hints / mypy types cleanup (#3816)


Co-authored-by: default avatarRoger Wang <136131678+ywang96@users.noreply.github.com>
parent e46a60aa
......@@ -27,8 +27,8 @@ class RequestFuncInput:
class RequestFuncOutput:
generated_text: str = ""
success: bool = False
latency: float = 0
ttft: float = 0 # Time to first token
latency: float = 0.0
ttft: float = 0.0 # Time to first token
itl: List[float] = field(
default_factory=list) # List of inter-token latencies
prompt_len: int = 0
......@@ -58,23 +58,24 @@ async def async_request_tgi(
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
ttft = 0
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload) as response:
if response.status == 200:
async for chunk in response.content:
chunk = chunk.strip()
if not chunk:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = remove_prefix(chunk.decode("utf-8"), "data:")
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data:")
data = json.loads(chunk)
timestamp = time.perf_counter()
# First token
if ttft == 0:
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
......@@ -119,23 +120,24 @@ async def async_request_trt_llm(
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
ttft = 0
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload) as response:
if response.status == 200:
async for chunk in response.content:
chunk = chunk.strip()
if not chunk:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = remove_prefix(chunk.decode("utf-8"), "data:")
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data:")
data = json.loads(chunk)
timestamp = time.perf_counter()
# First token
if ttft == 0:
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
......@@ -151,7 +153,7 @@ async def async_request_trt_llm(
output.success = True
else:
output.error = response.reason
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
......@@ -195,7 +197,7 @@ async def async_request_deepspeed_mii(
output.generated_text = parsed_resp["text"][0]
output.success = True
else:
output.error = response.reason
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
......@@ -234,19 +236,20 @@ async def async_request_openai_completions(
output.prompt_len = request_func_input.prompt_len
generated_text = ""
ttft = 0
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
headers=headers) as response:
if response.status == 200:
async for chunk in response.content:
chunk = chunk.strip()
if not chunk:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data: ")
if chunk == "[DONE]":
latency = time.perf_counter() - st
else:
......@@ -255,7 +258,7 @@ async def async_request_openai_completions(
if data["choices"][0]["text"]:
timestamp = time.perf_counter()
# First token
if ttft == 0:
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
......@@ -315,19 +318,20 @@ async def async_request_openai_chat_completions(
output.prompt_len = request_func_input.prompt_len
generated_text = ""
ttft = 0
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload,
headers=headers) as response:
if response.status == 200:
async for chunk in response.content:
chunk = chunk.strip()
if not chunk:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data: ")
if chunk == "[DONE]":
latency = time.perf_counter() - st
else:
......@@ -337,7 +341,7 @@ async def async_request_openai_chat_completions(
delta = data["choices"][0]["delta"]
if delta.get("content", None):
# First token
if ttft == 0:
if ttft == 0.0:
ttft = time.perf_counter() - st
output.ttft = ttft
......@@ -354,7 +358,7 @@ async def async_request_openai_chat_completions(
output.success = True
output.latency = latency
else:
output.error = response.reason
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
......
......@@ -12,6 +12,7 @@
import logging
import sys
from typing import List
from sphinx.ext import autodoc
......@@ -45,7 +46,7 @@ templates_path = ['_templates']
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
exclude_patterns: List[str] = []
# Exclude the prompt "$" when copying code
copybutton_prompt_text = r"\$ "
......
......@@ -5,7 +5,7 @@ import re
import subprocess
import sys
from shutil import which
from typing import List
from typing import Dict, List
import torch
from packaging.version import Version, parse
......@@ -52,7 +52,7 @@ class CMakeExtension(Extension):
class cmake_build_ext(build_ext):
# A dict of extension directories that have been configured.
did_config = {}
did_config: Dict[str, bool] = {}
#
# Determine number of compilation jobs and optionally nvcc compile threads.
......@@ -261,6 +261,7 @@ def get_nvcc_cuda_version() -> Version:
Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
"""
assert CUDA_HOME is not None, "CUDA_HOME is not set"
nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"],
universal_newlines=True)
output = nvcc_output.split()
......
from abc import ABC, abstractmethod, abstractproperty
from typing import Dict, List, Optional, Protocol
from abc import ABC, abstractmethod
from typing import Dict, FrozenSet, List, Optional, Protocol
from vllm.utils import Device
......@@ -10,23 +10,28 @@ class Block(ABC):
def append_token_ids(self, token_ids: List[int]) -> None:
pass
@abstractproperty
@property
@abstractmethod
def block_id(self) -> Optional[int]:
pass
@abstractproperty
@property
@abstractmethod
def token_ids(self) -> List[int]:
pass
@abstractproperty
@property
@abstractmethod
def num_empty_slots(self) -> int:
pass
@abstractproperty
@property
@abstractmethod
def is_full(self) -> bool:
pass
@abstractproperty
@property
@abstractmethod
def prev_block(self) -> Optional["Block"]:
pass
......@@ -47,12 +52,13 @@ class Block(ABC):
class BlockAllocator(ABC):
@abstractmethod
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
pass
@abstractmethod
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int]) -> Block:
token_ids: List[int], device: Device) -> Block:
pass
@abstractmethod
......@@ -64,11 +70,12 @@ class BlockAllocator(ABC):
pass
@abstractmethod
def get_num_free_blocks(self) -> int:
def get_num_free_blocks(self, device: Device) -> int:
pass
@abstractproperty
def all_block_ids(self) -> frozenset[int]:
@property
@abstractmethod
def all_block_ids(self) -> FrozenSet[int]:
pass
@abstractmethod
......
import time
from dataclasses import dataclass
from typing import Dict, List
from typing import Dict, List, Protocol
import numpy as np
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
......@@ -119,6 +119,12 @@ class Stats:
time_e2e_requests: List[float]
class SupportsMetricsInfo(Protocol):
def metrics_info(self) -> Dict[str, str]:
...
class StatLogger:
"""StatLogger is used LLMEngine to log to Promethus and Stdout."""
......@@ -135,7 +141,7 @@ class StatLogger:
self.labels = labels
self.metrics = Metrics(labelnames=list(labels.keys()))
def info(self, type: str, obj: object) -> None:
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
if type == "cache_config":
self.metrics.info_cache_config.info(obj.metrics_info())
......
......@@ -4,6 +4,7 @@
import logging
import os
import sys
from typing import Optional
VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))
......@@ -26,7 +27,7 @@ class NewLineFormatter(logging.Formatter):
_root_logger = logging.getLogger("vllm")
_default_handler = None
_default_handler: Optional[logging.Handler] = None
def _setup_logger():
......@@ -55,7 +56,12 @@ def init_logger(name: str):
# Use the same settings as above for root logger
logger = logging.getLogger(name)
logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG"))
if VLLM_CONFIGURE_LOGGING:
if _default_handler is None:
raise ValueError(
"_default_handler is not set up. This should never happen!"
" Please open an issue on Github.")
logger.addHandler(_default_handler)
logger.propagate = False
return logger
......@@ -247,11 +247,12 @@ def _yarn_find_correction_dim(num_rotations: int,
# Find dim range bounds based on rotations
def _yarn_find_correction_range(low_rot: int,
def _yarn_find_correction_range(
low_rot: int,
high_rot: int,
dim: int,
base: float = 10000,
max_position_embeddings: int = 2048) -> int:
max_position_embeddings: int = 2048) -> Tuple[int, int]:
low = math.floor(
_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(
......@@ -293,8 +294,8 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: float = 32,
beta_slow: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
......
from typing import Optional
from typing import Dict, Optional
from transformers import AutoConfig, PretrainedConfig
from vllm.transformers_utils.configs import *
_CONFIG_REGISTRY = {
_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
"chatglm": ChatGLMConfig,
"dbrx": DbrxConfig,
"mpt": MPTConfig,
......
......@@ -12,7 +12,7 @@ from transformers.utils import logging
logger = logging.get_logger(__name__)
DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} # type: ignore
class DbrxAttentionConfig(PretrainedConfig):
......
......@@ -16,11 +16,11 @@ logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
PRETRAINED_VOCAB_FILES_MAP = {
PRETRAINED_VOCAB_FILES_MAP = { # type: ignore
"vocab_file": {},
"tokenizer_file": {},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} # type: ignore
class BaichuanTokenizer(PreTrainedTokenizer):
......@@ -148,9 +148,9 @@ class BaichuanTokenizer(PreTrainedTokenizer):
`Tuple(str)`: Paths to the files saved.
"""
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) "
raise ValueError(f"Vocabulary path ({save_directory}) "
"should be a directory")
return
out_vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") +
......
......@@ -294,7 +294,7 @@ def create_kv_caches_with_random(
head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: Optional[int] = 0,
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed)
......@@ -400,7 +400,7 @@ class CudaMemoryProfiler:
gc.collect()
def str_to_int_tuple(s: str) -> Tuple[int]:
def str_to_int_tuple(s: str) -> Tuple[int, ...]:
"""Convert a string to a tuple of integers."""
try:
return tuple(map(int, s.split(",")))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment