"docs/vscode:/vscode.git/clone" did not exist on "262b76a09fafe15cff7642f3eee433fb903cf1d8"
Unverified commit 09473ee4, authored by SangBin Cho and committed by GitHub
Browse files

[mypy] Add mypy type annotation part 1 (#4006)

parent d4ec9ffb
......@@ -2,7 +2,8 @@ from typing import Dict, Optional
from transformers import AutoConfig, PretrainedConfig
from vllm.transformers_utils.configs import *
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
JAISConfig, MPTConfig, RWConfig)
_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
"chatglm": ChatGLMConfig,
......
......@@ -168,8 +168,8 @@ def _convert_tokens_to_string_with_added_encoders(
# NOTE(woosuk): The following code is slow because it runs a for loop over
# the output_tokens. In Python, running a for loop over a list can be slow
# even when the loop body is very simple.
sub_texts = []
current_sub_text = []
sub_texts: List[str] = []
current_sub_text: List[str] = []
all_special_tokens = set(tokenizer.all_special_tokens)
for token in output_tokens:
if skip_special_tokens and token in all_special_tokens:
......@@ -263,6 +263,7 @@ def detokenize_incrementally(
tokenizer,
all_input_ids[:-1],
skip_special_tokens=skip_special_tokens)
assert prev_tokens is not None
# If the new token id is out of bounds, return an empty string.
if new_token_id >= len(tokenizer):
......@@ -271,6 +272,8 @@ def detokenize_incrementally(
# Put new_token_id in a list so skip_special_tokens is respected
new_tokens = tokenizer.convert_ids_to_tokens(
[new_token_id], skip_special_tokens=skip_special_tokens)
if isinstance(new_tokens, str):
new_tokens = [new_tokens]
output_tokens = prev_tokens + new_tokens
# If this is the first iteration, return all tokens.
......
......@@ -5,7 +5,7 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizers import *
from vllm.transformers_utils.tokenizers import BaichuanTokenizer
from vllm.utils import make_async
logger = init_logger(__name__)
......@@ -28,7 +28,7 @@ def get_cached_tokenizer(
tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
tokenizer_len = len(tokenizer)
class CachedTokenizer(tokenizer.__class__):
class CachedTokenizer(tokenizer.__class__): # type: ignore
@property
def all_special_ids(self):
......
......@@ -7,7 +7,7 @@ import time
from enum import Enum
from pathlib import Path
from threading import Thread
from typing import Dict, Optional
from typing import Any, Dict, Optional
from uuid import uuid4
import cpuinfo
......@@ -124,7 +124,7 @@ class UsageMessage:
def report_usage(self,
model_architecture: str,
usage_context: UsageContext,
extra_kvs: Dict[str, any] = None) -> None:
extra_kvs: Optional[Dict[str, Any]] = None) -> None:
t = Thread(target=self._report_usage_worker,
args=(model_architecture, usage_context, extra_kvs or {}),
daemon=True)
......@@ -132,13 +132,13 @@ class UsageMessage:
def _report_usage_worker(self, model_architecture: str,
usage_context: UsageContext,
extra_kvs: Dict[str, any]) -> None:
extra_kvs: Dict[str, Any]) -> None:
self._report_usage_once(model_architecture, usage_context, extra_kvs)
self._report_continous_usage()
def _report_usage_once(self, model_architecture: str,
usage_context: UsageContext,
extra_kvs: Dict[str, any]) -> None:
extra_kvs: Dict[str, Any]) -> None:
# Platform information
if torch.cuda.is_available():
device_property = torch.cuda.get_device_properties(0)
......
......@@ -60,7 +60,7 @@ class LRUCache(Generic[T]):
def __len__(self) -> int:
return len(self.cache)
def __getitem__(self, key: Hashable) -> T:
def __getitem__(self, key: Hashable) -> Optional[T]:
return self.get(key)
def __setitem__(self, key: Hashable, value: T) -> None:
......@@ -76,7 +76,7 @@ class LRUCache(Generic[T]):
key: Hashable,
default_value: Optional[T] = None) -> Optional[T]:
if key in self.cache:
value = self.cache[key]
value: Optional[T] = self.cache[key]
self.cache.move_to_end(key)
else:
value = default_value
......@@ -87,7 +87,7 @@ class LRUCache(Generic[T]):
self.cache.move_to_end(key)
self._remove_old_if_needed()
def _on_remove(self, key: Hashable, value: T):
def _on_remove(self, key: Hashable, value: Optional[T]):
pass
def remove_oldest(self):
......@@ -100,9 +100,11 @@ class LRUCache(Generic[T]):
while len(self.cache) > self.capacity:
self.remove_oldest()
def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T:
def pop(self,
key: Hashable,
default_value: Optional[T] = None) -> Optional[T]:
run_on_remove = key in self.cache
value = self.cache.pop(key, default_value)
value: Optional[T] = self.cache.pop(key, default_value)
if run_on_remove:
self._on_remove(key, value)
return value
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment