Unverified Commit 09473ee4 authored by SangBin Cho's avatar SangBin Cho Committed by GitHub
Browse files

[mypy] Add mypy type annotation part 1 (#4006)

parent d4ec9ffb
...@@ -2,7 +2,8 @@ from typing import Dict, Optional ...@@ -2,7 +2,8 @@ from typing import Dict, Optional
from transformers import AutoConfig, PretrainedConfig from transformers import AutoConfig, PretrainedConfig
from vllm.transformers_utils.configs import * from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
JAISConfig, MPTConfig, RWConfig)
_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = { _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
"chatglm": ChatGLMConfig, "chatglm": ChatGLMConfig,
......
...@@ -168,8 +168,8 @@ def _convert_tokens_to_string_with_added_encoders( ...@@ -168,8 +168,8 @@ def _convert_tokens_to_string_with_added_encoders(
# NOTE(woosuk): The following code is slow because it runs a for loop over # NOTE(woosuk): The following code is slow because it runs a for loop over
# the output_tokens. In Python, running a for loop over a list can be slow # the output_tokens. In Python, running a for loop over a list can be slow
# even when the loop body is very simple. # even when the loop body is very simple.
sub_texts = [] sub_texts: List[str] = []
current_sub_text = [] current_sub_text: List[str] = []
all_special_tokens = set(tokenizer.all_special_tokens) all_special_tokens = set(tokenizer.all_special_tokens)
for token in output_tokens: for token in output_tokens:
if skip_special_tokens and token in all_special_tokens: if skip_special_tokens and token in all_special_tokens:
...@@ -263,6 +263,7 @@ def detokenize_incrementally( ...@@ -263,6 +263,7 @@ def detokenize_incrementally(
tokenizer, tokenizer,
all_input_ids[:-1], all_input_ids[:-1],
skip_special_tokens=skip_special_tokens) skip_special_tokens=skip_special_tokens)
assert prev_tokens is not None
# If the new token id is out of bounds, return an empty string. # If the new token id is out of bounds, return an empty string.
if new_token_id >= len(tokenizer): if new_token_id >= len(tokenizer):
...@@ -271,6 +272,8 @@ def detokenize_incrementally( ...@@ -271,6 +272,8 @@ def detokenize_incrementally(
# Put new_token_id in a list so skip_special_tokens is respected # Put new_token_id in a list so skip_special_tokens is respected
new_tokens = tokenizer.convert_ids_to_tokens( new_tokens = tokenizer.convert_ids_to_tokens(
[new_token_id], skip_special_tokens=skip_special_tokens) [new_token_id], skip_special_tokens=skip_special_tokens)
if isinstance(new_tokens, str):
new_tokens = [new_tokens]
output_tokens = prev_tokens + new_tokens output_tokens = prev_tokens + new_tokens
# If this is the first iteration, return all tokens. # If this is the first iteration, return all tokens.
......
...@@ -5,7 +5,7 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer, ...@@ -5,7 +5,7 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizers import * from vllm.transformers_utils.tokenizers import BaichuanTokenizer
from vllm.utils import make_async from vllm.utils import make_async
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -28,7 +28,7 @@ def get_cached_tokenizer( ...@@ -28,7 +28,7 @@ def get_cached_tokenizer(
tokenizer_all_special_tokens = set(tokenizer.all_special_tokens) tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
tokenizer_len = len(tokenizer) tokenizer_len = len(tokenizer)
class CachedTokenizer(tokenizer.__class__): class CachedTokenizer(tokenizer.__class__): # type: ignore
@property @property
def all_special_ids(self): def all_special_ids(self):
......
...@@ -7,7 +7,7 @@ import time ...@@ -7,7 +7,7 @@ import time
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from threading import Thread from threading import Thread
from typing import Dict, Optional from typing import Any, Dict, Optional
from uuid import uuid4 from uuid import uuid4
import cpuinfo import cpuinfo
...@@ -124,7 +124,7 @@ class UsageMessage: ...@@ -124,7 +124,7 @@ class UsageMessage:
def report_usage(self, def report_usage(self,
model_architecture: str, model_architecture: str,
usage_context: UsageContext, usage_context: UsageContext,
extra_kvs: Dict[str, any] = None) -> None: extra_kvs: Optional[Dict[str, Any]] = None) -> None:
t = Thread(target=self._report_usage_worker, t = Thread(target=self._report_usage_worker,
args=(model_architecture, usage_context, extra_kvs or {}), args=(model_architecture, usage_context, extra_kvs or {}),
daemon=True) daemon=True)
...@@ -132,13 +132,13 @@ class UsageMessage: ...@@ -132,13 +132,13 @@ class UsageMessage:
def _report_usage_worker(self, model_architecture: str, def _report_usage_worker(self, model_architecture: str,
usage_context: UsageContext, usage_context: UsageContext,
extra_kvs: Dict[str, any]) -> None: extra_kvs: Dict[str, Any]) -> None:
self._report_usage_once(model_architecture, usage_context, extra_kvs) self._report_usage_once(model_architecture, usage_context, extra_kvs)
self._report_continous_usage() self._report_continous_usage()
def _report_usage_once(self, model_architecture: str, def _report_usage_once(self, model_architecture: str,
usage_context: UsageContext, usage_context: UsageContext,
extra_kvs: Dict[str, any]) -> None: extra_kvs: Dict[str, Any]) -> None:
# Platform information # Platform information
if torch.cuda.is_available(): if torch.cuda.is_available():
device_property = torch.cuda.get_device_properties(0) device_property = torch.cuda.get_device_properties(0)
......
...@@ -60,7 +60,7 @@ class LRUCache(Generic[T]): ...@@ -60,7 +60,7 @@ class LRUCache(Generic[T]):
def __len__(self) -> int: def __len__(self) -> int:
return len(self.cache) return len(self.cache)
def __getitem__(self, key: Hashable) -> T: def __getitem__(self, key: Hashable) -> Optional[T]:
return self.get(key) return self.get(key)
def __setitem__(self, key: Hashable, value: T) -> None: def __setitem__(self, key: Hashable, value: T) -> None:
...@@ -76,7 +76,7 @@ class LRUCache(Generic[T]): ...@@ -76,7 +76,7 @@ class LRUCache(Generic[T]):
key: Hashable, key: Hashable,
default_value: Optional[T] = None) -> Optional[T]: default_value: Optional[T] = None) -> Optional[T]:
if key in self.cache: if key in self.cache:
value = self.cache[key] value: Optional[T] = self.cache[key]
self.cache.move_to_end(key) self.cache.move_to_end(key)
else: else:
value = default_value value = default_value
...@@ -87,7 +87,7 @@ class LRUCache(Generic[T]): ...@@ -87,7 +87,7 @@ class LRUCache(Generic[T]):
self.cache.move_to_end(key) self.cache.move_to_end(key)
self._remove_old_if_needed() self._remove_old_if_needed()
def _on_remove(self, key: Hashable, value: T): def _on_remove(self, key: Hashable, value: Optional[T]):
pass pass
def remove_oldest(self): def remove_oldest(self):
...@@ -100,9 +100,11 @@ class LRUCache(Generic[T]): ...@@ -100,9 +100,11 @@ class LRUCache(Generic[T]):
while len(self.cache) > self.capacity: while len(self.cache) > self.capacity:
self.remove_oldest() self.remove_oldest()
def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T: def pop(self,
key: Hashable,
default_value: Optional[T] = None) -> Optional[T]:
run_on_remove = key in self.cache run_on_remove = key in self.cache
value = self.cache.pop(key, default_value) value: Optional[T] = self.cache.pop(key, default_value)
if run_on_remove: if run_on_remove:
self._on_remove(key, value) self._on_remove(key, value)
return value return value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment