Commit c2b4a1bc (unverified), authored by Michael Feil, committed by GitHub
Browse files

[Doc] Add typing hints / mypy types cleanup (#3816)


Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
parent e46a60aa
...@@ -27,8 +27,8 @@ class RequestFuncInput: ...@@ -27,8 +27,8 @@ class RequestFuncInput:
class RequestFuncOutput: class RequestFuncOutput:
generated_text: str = "" generated_text: str = ""
success: bool = False success: bool = False
latency: float = 0 latency: float = 0.0
ttft: float = 0 # Time to first token ttft: float = 0.0 # Time to first token
itl: List[float] = field( itl: List[float] = field(
default_factory=list) # List of inter-token latencies default_factory=list) # List of inter-token latencies
prompt_len: int = 0 prompt_len: int = 0
...@@ -58,23 +58,24 @@ async def async_request_tgi( ...@@ -58,23 +58,24 @@ async def async_request_tgi(
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
ttft = 0 ttft = 0.0
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload) as response: async with session.post(url=api_url, json=payload) as response:
if response.status == 200: if response.status == 200:
async for chunk in response.content: async for chunk_bytes in response.content:
chunk = chunk.strip() chunk_bytes = chunk_bytes.strip()
if not chunk: if not chunk_bytes:
continue continue
chunk = remove_prefix(chunk.decode("utf-8"), "data:") chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data:")
data = json.loads(chunk) data = json.loads(chunk)
timestamp = time.perf_counter() timestamp = time.perf_counter()
# First token # First token
if ttft == 0: if ttft == 0.0:
ttft = time.perf_counter() - st ttft = time.perf_counter() - st
output.ttft = ttft output.ttft = ttft
...@@ -119,23 +120,24 @@ async def async_request_trt_llm( ...@@ -119,23 +120,24 @@ async def async_request_trt_llm(
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
ttft = 0 ttft = 0.0
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload) as response: async with session.post(url=api_url, json=payload) as response:
if response.status == 200: if response.status == 200:
async for chunk in response.content: async for chunk_bytes in response.content:
chunk = chunk.strip() chunk_bytes = chunk_bytes.strip()
if not chunk: if not chunk_bytes:
continue continue
chunk = remove_prefix(chunk.decode("utf-8"), "data:") chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data:")
data = json.loads(chunk) data = json.loads(chunk)
timestamp = time.perf_counter() timestamp = time.perf_counter()
# First token # First token
if ttft == 0: if ttft == 0.0:
ttft = time.perf_counter() - st ttft = time.perf_counter() - st
output.ttft = ttft output.ttft = ttft
...@@ -151,7 +153,7 @@ async def async_request_trt_llm( ...@@ -151,7 +153,7 @@ async def async_request_trt_llm(
output.success = True output.success = True
else: else:
output.error = response.reason output.error = response.reason or ""
output.success = False output.success = False
except Exception: except Exception:
output.success = False output.success = False
...@@ -195,7 +197,7 @@ async def async_request_deepspeed_mii( ...@@ -195,7 +197,7 @@ async def async_request_deepspeed_mii(
output.generated_text = parsed_resp["text"][0] output.generated_text = parsed_resp["text"][0]
output.success = True output.success = True
else: else:
output.error = response.reason output.error = response.reason or ""
output.success = False output.success = False
except Exception: except Exception:
output.success = False output.success = False
...@@ -234,19 +236,20 @@ async def async_request_openai_completions( ...@@ -234,19 +236,20 @@ async def async_request_openai_completions(
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
generated_text = "" generated_text = ""
ttft = 0 ttft = 0.0
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload, async with session.post(url=api_url, json=payload,
headers=headers) as response: headers=headers) as response:
if response.status == 200: if response.status == 200:
async for chunk in response.content: async for chunk_bytes in response.content:
chunk = chunk.strip() chunk_bytes = chunk_bytes.strip()
if not chunk: if not chunk_bytes:
continue continue
chunk = remove_prefix(chunk.decode("utf-8"), "data: ") chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data: ")
if chunk == "[DONE]": if chunk == "[DONE]":
latency = time.perf_counter() - st latency = time.perf_counter() - st
else: else:
...@@ -255,7 +258,7 @@ async def async_request_openai_completions( ...@@ -255,7 +258,7 @@ async def async_request_openai_completions(
if data["choices"][0]["text"]: if data["choices"][0]["text"]:
timestamp = time.perf_counter() timestamp = time.perf_counter()
# First token # First token
if ttft == 0: if ttft == 0.0:
ttft = time.perf_counter() - st ttft = time.perf_counter() - st
output.ttft = ttft output.ttft = ttft
...@@ -315,19 +318,20 @@ async def async_request_openai_chat_completions( ...@@ -315,19 +318,20 @@ async def async_request_openai_chat_completions(
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
generated_text = "" generated_text = ""
ttft = 0 ttft = 0.0
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload, async with session.post(url=api_url, json=payload,
headers=headers) as response: headers=headers) as response:
if response.status == 200: if response.status == 200:
async for chunk in response.content: async for chunk_bytes in response.content:
chunk = chunk.strip() chunk_bytes = chunk_bytes.strip()
if not chunk: if not chunk_bytes:
continue continue
chunk = remove_prefix(chunk.decode("utf-8"), "data: ") chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data: ")
if chunk == "[DONE]": if chunk == "[DONE]":
latency = time.perf_counter() - st latency = time.perf_counter() - st
else: else:
...@@ -337,7 +341,7 @@ async def async_request_openai_chat_completions( ...@@ -337,7 +341,7 @@ async def async_request_openai_chat_completions(
delta = data["choices"][0]["delta"] delta = data["choices"][0]["delta"]
if delta.get("content", None): if delta.get("content", None):
# First token # First token
if ttft == 0: if ttft == 0.0:
ttft = time.perf_counter() - st ttft = time.perf_counter() - st
output.ttft = ttft output.ttft = ttft
...@@ -354,7 +358,7 @@ async def async_request_openai_chat_completions( ...@@ -354,7 +358,7 @@ async def async_request_openai_chat_completions(
output.success = True output.success = True
output.latency = latency output.latency = latency
else: else:
output.error = response.reason output.error = response.reason or ""
output.success = False output.success = False
except Exception: except Exception:
output.success = False output.success = False
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
import logging import logging
import sys import sys
from typing import List
from sphinx.ext import autodoc from sphinx.ext import autodoc
...@@ -45,7 +46,7 @@ templates_path = ['_templates'] ...@@ -45,7 +46,7 @@ templates_path = ['_templates']
# List of patterns, relative to source directory, that match files and # List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files. # directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path. # This pattern also affects html_static_path and html_extra_path.
exclude_patterns = [] exclude_patterns: List[str] = []
# Exclude the prompt "$" when copying code # Exclude the prompt "$" when copying code
copybutton_prompt_text = r"\$ " copybutton_prompt_text = r"\$ "
......
...@@ -5,7 +5,7 @@ import re ...@@ -5,7 +5,7 @@ import re
import subprocess import subprocess
import sys import sys
from shutil import which from shutil import which
from typing import List from typing import Dict, List
import torch import torch
from packaging.version import Version, parse from packaging.version import Version, parse
...@@ -52,7 +52,7 @@ class CMakeExtension(Extension): ...@@ -52,7 +52,7 @@ class CMakeExtension(Extension):
class cmake_build_ext(build_ext): class cmake_build_ext(build_ext):
# A dict of extension directories that have been configured. # A dict of extension directories that have been configured.
did_config = {} did_config: Dict[str, bool] = {}
# #
# Determine number of compilation jobs and optionally nvcc compile threads. # Determine number of compilation jobs and optionally nvcc compile threads.
...@@ -261,6 +261,7 @@ def get_nvcc_cuda_version() -> Version: ...@@ -261,6 +261,7 @@ def get_nvcc_cuda_version() -> Version:
Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
""" """
assert CUDA_HOME is not None, "CUDA_HOME is not set"
nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"],
universal_newlines=True) universal_newlines=True)
output = nvcc_output.split() output = nvcc_output.split()
......
from abc import ABC, abstractmethod, abstractproperty from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Protocol from typing import Dict, FrozenSet, List, Optional, Protocol
from vllm.utils import Device from vllm.utils import Device
...@@ -10,23 +10,28 @@ class Block(ABC): ...@@ -10,23 +10,28 @@ class Block(ABC):
def append_token_ids(self, token_ids: List[int]) -> None: def append_token_ids(self, token_ids: List[int]) -> None:
pass pass
@abstractproperty @property
@abstractmethod
def block_id(self) -> Optional[int]: def block_id(self) -> Optional[int]:
pass pass
@abstractproperty @property
@abstractmethod
def token_ids(self) -> List[int]: def token_ids(self) -> List[int]:
pass pass
@abstractproperty @property
@abstractmethod
def num_empty_slots(self) -> int: def num_empty_slots(self) -> int:
pass pass
@abstractproperty @property
@abstractmethod
def is_full(self) -> bool: def is_full(self) -> bool:
pass pass
@abstractproperty @property
@abstractmethod
def prev_block(self) -> Optional["Block"]: def prev_block(self) -> Optional["Block"]:
pass pass
...@@ -47,12 +52,13 @@ class Block(ABC): ...@@ -47,12 +52,13 @@ class Block(ABC):
class BlockAllocator(ABC): class BlockAllocator(ABC):
@abstractmethod @abstractmethod
def allocate_mutable(self, prev_block: Optional[Block]) -> Block: def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
pass pass
@abstractmethod @abstractmethod
def allocate_immutable(self, prev_block: Optional[Block], def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int]) -> Block: token_ids: List[int], device: Device) -> Block:
pass pass
@abstractmethod @abstractmethod
...@@ -64,11 +70,12 @@ class BlockAllocator(ABC): ...@@ -64,11 +70,12 @@ class BlockAllocator(ABC):
pass pass
@abstractmethod @abstractmethod
def get_num_free_blocks(self) -> int: def get_num_free_blocks(self, device: Device) -> int:
pass pass
@abstractproperty @property
def all_block_ids(self) -> frozenset[int]: @abstractmethod
def all_block_ids(self) -> FrozenSet[int]:
pass pass
@abstractmethod @abstractmethod
......
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List from typing import Dict, List, Protocol
import numpy as np import numpy as np
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info, from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
...@@ -119,6 +119,12 @@ class Stats: ...@@ -119,6 +119,12 @@ class Stats:
time_e2e_requests: List[float] time_e2e_requests: List[float]
class SupportsMetricsInfo(Protocol):
def metrics_info(self) -> Dict[str, str]:
...
class StatLogger: class StatLogger:
"""StatLogger is used LLMEngine to log to Promethus and Stdout.""" """StatLogger is used LLMEngine to log to Promethus and Stdout."""
...@@ -135,7 +141,7 @@ class StatLogger: ...@@ -135,7 +141,7 @@ class StatLogger:
self.labels = labels self.labels = labels
self.metrics = Metrics(labelnames=list(labels.keys())) self.metrics = Metrics(labelnames=list(labels.keys()))
def info(self, type: str, obj: object) -> None: def info(self, type: str, obj: SupportsMetricsInfo) -> None:
if type == "cache_config": if type == "cache_config":
self.metrics.info_cache_config.info(obj.metrics_info()) self.metrics.info_cache_config.info(obj.metrics_info())
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import logging import logging
import os import os
import sys import sys
from typing import Optional
VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")) VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))
...@@ -26,7 +27,7 @@ class NewLineFormatter(logging.Formatter): ...@@ -26,7 +27,7 @@ class NewLineFormatter(logging.Formatter):
_root_logger = logging.getLogger("vllm") _root_logger = logging.getLogger("vllm")
_default_handler = None _default_handler: Optional[logging.Handler] = None
def _setup_logger(): def _setup_logger():
...@@ -55,7 +56,12 @@ def init_logger(name: str): ...@@ -55,7 +56,12 @@ def init_logger(name: str):
# Use the same settings as above for root logger # Use the same settings as above for root logger
logger = logging.getLogger(name) logger = logging.getLogger(name)
logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG")) logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG"))
if VLLM_CONFIGURE_LOGGING: if VLLM_CONFIGURE_LOGGING:
if _default_handler is None:
raise ValueError(
"_default_handler is not set up. This should never happen!"
" Please open an issue on Github.")
logger.addHandler(_default_handler) logger.addHandler(_default_handler)
logger.propagate = False logger.propagate = False
return logger return logger
...@@ -247,11 +247,12 @@ def _yarn_find_correction_dim(num_rotations: int, ...@@ -247,11 +247,12 @@ def _yarn_find_correction_dim(num_rotations: int,
# Find dim range bounds based on rotations # Find dim range bounds based on rotations
def _yarn_find_correction_range(low_rot: int, def _yarn_find_correction_range(
high_rot: int, low_rot: int,
dim: int, high_rot: int,
base: float = 10000, dim: int,
max_position_embeddings: int = 2048) -> int: base: float = 10000,
max_position_embeddings: int = 2048) -> Tuple[int, int]:
low = math.floor( low = math.floor(
_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil( high = math.ceil(
...@@ -293,8 +294,8 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): ...@@ -293,8 +294,8 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
*, *,
extrapolation_factor: float = 1, extrapolation_factor: float = 1,
attn_factor: float = 1, attn_factor: float = 1,
beta_fast: float = 32, beta_fast: int = 32,
beta_slow: float = 1, beta_slow: int = 1,
) -> None: ) -> None:
self.scaling_factor = scaling_factor self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor self.extrapolation_factor = extrapolation_factor
......
from typing import Optional from typing import Dict, Optional
from transformers import AutoConfig, PretrainedConfig from transformers import AutoConfig, PretrainedConfig
from vllm.transformers_utils.configs import * from vllm.transformers_utils.configs import *
_CONFIG_REGISTRY = { _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
"chatglm": ChatGLMConfig, "chatglm": ChatGLMConfig,
"dbrx": DbrxConfig, "dbrx": DbrxConfig,
"mpt": MPTConfig, "mpt": MPTConfig,
......
...@@ -12,7 +12,7 @@ from transformers.utils import logging ...@@ -12,7 +12,7 @@ from transformers.utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} # type: ignore
class DbrxAttentionConfig(PretrainedConfig): class DbrxAttentionConfig(PretrainedConfig):
......
...@@ -16,11 +16,11 @@ logger = logging.get_logger(__name__) ...@@ -16,11 +16,11 @@ logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = { # type: ignore
"vocab_file": {}, "vocab_file": {},
"tokenizer_file": {}, "tokenizer_file": {},
} }
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} # type: ignore
class BaichuanTokenizer(PreTrainedTokenizer): class BaichuanTokenizer(PreTrainedTokenizer):
...@@ -148,9 +148,9 @@ class BaichuanTokenizer(PreTrainedTokenizer): ...@@ -148,9 +148,9 @@ class BaichuanTokenizer(PreTrainedTokenizer):
`Tuple(str)`: Paths to the files saved. `Tuple(str)`: Paths to the files saved.
""" """
if not os.path.isdir(save_directory): if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) " raise ValueError(f"Vocabulary path ({save_directory}) "
"should be a directory") "should be a directory")
return
out_vocab_file = os.path.join( out_vocab_file = os.path.join(
save_directory, save_directory,
(filename_prefix + "-" if filename_prefix else "") + (filename_prefix + "-" if filename_prefix else "") +
......
...@@ -294,7 +294,7 @@ def create_kv_caches_with_random( ...@@ -294,7 +294,7 @@ def create_kv_caches_with_random(
head_size: int, head_size: int,
cache_dtype: Optional[Union[str, torch.dtype]], cache_dtype: Optional[Union[str, torch.dtype]],
model_dtype: Optional[Union[str, torch.dtype]] = None, model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: Optional[int] = 0, seed: int = 0,
device: Optional[str] = "cuda", device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
...@@ -400,7 +400,7 @@ class CudaMemoryProfiler: ...@@ -400,7 +400,7 @@ class CudaMemoryProfiler:
gc.collect() gc.collect()
def str_to_int_tuple(s: str) -> Tuple[int]: def str_to_int_tuple(s: str) -> Tuple[int, ...]:
"""Convert a string to a tuple of integers.""" """Convert a string to a tuple of integers."""
try: try:
return tuple(map(int, s.split(","))) return tuple(map(int, s.split(",")))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment