Unverified commit da1f7cc1, authored by Cyrus Leung and committed by GitHub
Browse files

[mypy] Enable following imports for some directories (#6681)

parent c32ab8be
......@@ -32,22 +32,17 @@ jobs:
pip install types-setuptools
- name: Mypy
run: |
mypy tests --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml
mypy vllm/attention --config-file pyproject.toml
mypy vllm/core --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
mypy vllm/inputs --config-file pyproject.toml
mypy vllm/logging --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/multimodal --config-file pyproject.toml
mypy vllm/platforms --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy tests --follow-imports skip
mypy vllm/attention --follow-imports skip
mypy vllm/core --follow-imports skip
mypy vllm/distributed --follow-imports skip
mypy vllm/engine --follow-imports skip
mypy vllm/entrypoints --follow-imports skip
mypy vllm/executor --follow-imports skip
mypy vllm/lora --follow-imports skip
mypy vllm/model_executor --follow-imports skip
mypy vllm/prompt_adapter --follow-imports skip
mypy vllm/spec_decode --follow-imports skip
mypy vllm/worker --follow-imports skip
mypy
......@@ -96,23 +96,19 @@ echo 'vLLM yapf: Done'
# Run mypy
echo 'vLLM mypy:'
mypy tests --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml
mypy vllm/attention --config-file pyproject.toml
mypy vllm/core --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
mypy vllm/logging --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/multimodal --config-file pyproject.toml
mypy vllm/prompt_adapter --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy tests --follow-imports skip
mypy vllm/attention --follow-imports skip
mypy vllm/core --follow-imports skip
mypy vllm/distributed --follow-imports skip
mypy vllm/engine --follow-imports skip
mypy vllm/entrypoints --follow-imports skip
mypy vllm/executor --follow-imports skip
mypy vllm/lora --follow-imports skip
mypy vllm/model_executor --follow-imports skip
mypy vllm/prompt_adapter --follow-imports skip
mypy vllm/spec_decode --follow-imports skip
mypy vllm/worker --follow-imports skip
mypy
# If git diff returns a file that is in the skip list, the file may be checked anyway:
......
......@@ -48,9 +48,23 @@ python_version = "3.8"
ignore_missing_imports = true
check_untyped_defs = true
follow_imports = "skip"
follow_imports = "silent"
files = "vllm"
# Once the type errors caused by switching a directory from
# follow_imports = "skip" to "silent" are fixed, add that directory to the
# list below and remove it from format.sh and mypy.yaml.
files = [
"vllm/*.py",
"vllm/adapter_commons",
"vllm/assets",
"vllm/inputs",
"vllm/logging",
"vllm/multimodal",
"vllm/platforms",
"vllm/server",
"vllm/transformers_utils",
"vllm/triton_utils",
"vllm/usage",
]
# TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = [
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
......
......@@ -239,7 +239,7 @@ def cutlass_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: Type[torch.dtype],
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
......
......@@ -25,27 +25,33 @@ class ipex_ops:
x2 = x2.reshape(num, d)
return x1, x2
@staticmethod
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
x1, x2 = ipex_ops._reshape_activation_tensor(x)
ipex.llm.functional.silu_mul(x1, x2, out)
@staticmethod
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
x1, x2 = ipex_ops._reshape_activation_tensor(x)
ipex.llm.functional.gelu_mul(x1, x2, out, "none")
@staticmethod
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
x1, x2 = ipex_ops._reshape_activation_tensor(x)
ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")
@staticmethod
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
out.copy_(torch.nn.functional.gelu(x))
@staticmethod
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
out.copy_(torch.nn.functional.gelu(x))
# TODO add implementation of gelu_quick here
# def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
@staticmethod
def paged_attention_v1(
out: torch.Tensor,
query: torch.Tensor,
......@@ -78,12 +84,21 @@ class ipex_ops:
).view(num_kv_heads,
1).repeat_interleave(num_queries_per_tokens).flatten()
# todo: ipex will refactor namespace
torch.xpu.paged_attention_v1(out, query.contiguous(),
key_cache.view_as(value_cache),
value_cache, head_mapping, scale,
block_tables, context_lens, block_size,
max_context_len, alibi_slopes)
torch.xpu.paged_attention_v1( # type: ignore
out,
query.contiguous(),
key_cache.view_as(value_cache),
value_cache,
head_mapping,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
)
@staticmethod
def paged_attention_v2(
out: torch.Tensor,
exp_sum: torch.Tensor,
......@@ -119,13 +134,24 @@ class ipex_ops:
).view(num_kv_heads,
1).repeat_interleave(num_queries_per_tokens).flatten()
# todo: ipex will refactor namespace
torch.xpu.paged_attention_v2(out, exp_sum, max_logits, tmp_out,
query.contiguous(),
key_cache.view_as(value_cache),
value_cache, head_mapping, block_tables,
context_lens, scale, block_size,
max_context_len, alibi_slopes)
torch.xpu.paged_attention_v2( # type: ignore
out,
exp_sum,
max_logits,
tmp_out,
query.contiguous(),
key_cache.view_as(value_cache),
value_cache,
head_mapping,
block_tables,
context_lens,
scale,
block_size,
max_context_len,
alibi_slopes,
)
@staticmethod
def rotary_embedding(
positions: torch.Tensor, # [batch_size, seq_len]
query: torch.Tensor, # [batch_size, seq_len, num_heads*head_size]
......@@ -158,6 +184,7 @@ class ipex_ops:
ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
rotary_dim, is_neox, positions)
@staticmethod
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, head_size: int,
cos_sin_cache: torch.Tensor, is_neox: bool,
......@@ -189,17 +216,20 @@ class ipex_ops:
ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
rotary_dim, is_neox, positions)
@staticmethod
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None:
tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
out.copy_(tmp)
@staticmethod
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None:
tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
epsilon, True)
input.copy_(tmp)
@staticmethod
def varlen_attention(
query: torch.Tensor,
key: torch.Tensor,
......@@ -222,6 +252,7 @@ class ipex_ops:
softmax_scale, zero_tensors,
is_causal, return_softmax, gen_)
@staticmethod
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
......@@ -240,8 +271,13 @@ class ipex_ops:
def copy_blocks(key_caches: List[torch.Tensor],
value_caches: List[torch.Tensor],
block_mapping: torch.Tensor) -> None:
torch.xpu.copy_blocks(key_caches, value_caches, block_mapping)
torch.xpu.copy_blocks( # type: ignore
key_caches,
value_caches,
block_mapping,
)
@staticmethod
def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
block_mapping: torch.Tensor) -> None:
torch.xpu.swap_blocks(src, dst, block_mapping)
torch.xpu.swap_blocks(src, dst, block_mapping) # type: ignore
......@@ -31,7 +31,7 @@ class AdapterLRUCache(LRUCache[T]):
super().__init__(capacity)
self.deactivate_fn = deactivate_fn
def _on_remove(self, key: Hashable, value: T):
def _on_remove(self, key: Hashable, value: Optional[T]):
logger.debug("Removing adapter int id: %d", key)
self.deactivate_fn(key)
return super()._on_remove(key, value)
......@@ -59,46 +59,46 @@ class AdapterModelManager(ABC):
@property
@abstractmethod
def adapter_slots(self):
...
def adapter_slots(self) -> int:
raise NotImplementedError
@property
@abstractmethod
def capacity(self):
...
def capacity(self) -> int:
raise NotImplementedError
@abstractmethod
def activate_adapter(self, adapter_id: int) -> bool:
...
raise NotImplementedError
@abstractmethod
def deactivate_adapter(self, adapter_id: int) -> bool:
...
raise NotImplementedError
@abstractmethod
def add_adapter(self, adapter: Any) -> bool:
...
raise NotImplementedError
@abstractmethod
def set_adapter_mapping(self, mapping: Any) -> None:
...
raise NotImplementedError
@abstractmethod
def remove_adapter(self, adapter_id: int) -> bool:
...
raise NotImplementedError
@abstractmethod
def remove_all_adapters(self):
...
def remove_all_adapters(self) -> None:
raise NotImplementedError
@abstractmethod
def get_adapter(self, adapter_id: int) -> Optional[Any]:
...
raise NotImplementedError
@abstractmethod
def list_adapters(self) -> Dict[int, Any]:
...
raise NotImplementedError
@abstractmethod
def pin_adapter(self, adapter_id: int) -> bool:
...
raise NotImplementedError
from abc import abstractmethod
from abc import ABC, abstractmethod
from dataclasses import dataclass
@dataclass
class AdapterRequest:
class AdapterRequest(ABC):
"""
Base class for adapter requests.
"""
@property
@abstractmethod
def adapter_id(self):
...
def adapter_id(self) -> int:
raise NotImplementedError
def __post_init__(self):
def __post_init__(self) -> None:
if self.adapter_id < 1:
raise ValueError(f"id must be > 0, got {self.adapter_id}")
......
......@@ -12,25 +12,25 @@ class AbstractWorkerManager(ABC):
@property
@abstractmethod
def is_enabled(self) -> bool:
...
raise NotImplementedError
@abstractmethod
def set_active_adapters(self, requests: Set[Any],
mapping: Optional[Any]) -> None:
...
raise NotImplementedError
@abstractmethod
def add_adapter(self, adapter_request: Any) -> bool:
...
raise NotImplementedError
@abstractmethod
def remove_adapter(self, adapter_id: int) -> bool:
...
raise NotImplementedError
@abstractmethod
def remove_all_adapters(self):
...
def remove_all_adapters(self) -> None:
raise NotImplementedError
@abstractmethod
def list_adapters(self) -> Set[int]:
...
raise NotImplementedError
......@@ -724,7 +724,7 @@ class ParallelConfig:
backend)
self._verify_args()
self.rank = 0
self.rank: int = 0
@property
def use_ray(self) -> bool:
......@@ -850,6 +850,7 @@ class SchedulerConfig:
class DeviceConfig:
device: Optional[torch.device]
def __init__(self, device: str = "auto") -> None:
if device == "auto":
......
......@@ -5,8 +5,6 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
from typing import Sequence as GenericSequence
from typing import Set, Type, TypeVar, Union
from transformers import PreTrainedTokenizer
import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
......@@ -40,7 +38,8 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
from vllm.transformers_utils.tokenizer_group import (AnyTokenizer,
BaseTokenizerGroup,
get_tokenizer_group)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
......@@ -477,13 +476,12 @@ class LLMEngine:
return self.tokenizer
def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer":
def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(
sequence.lora_request)
......
......@@ -5,7 +5,6 @@ from http import HTTPStatus
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
from pydantic import Field
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Annotated
from vllm.config import ModelConfig
......@@ -30,6 +29,7 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer_group import AnyTokenizer
logger = init_logger(__name__)
......@@ -49,8 +49,6 @@ class LoRAModulePath:
AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
EmbeddingRequest, TokenizeRequest]
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
class TextTokensPrompt(TypedDict):
prompt: str
......
......@@ -4,9 +4,10 @@ import asyncio
import os
import signal
import sys
from typing import Optional
from typing import List, Optional
from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
......@@ -63,15 +64,14 @@ def complete(model_name: str, client: OpenAI) -> None:
def chat(system_prompt: Optional[str], model_name: str,
client: OpenAI) -> None:
conversation = []
conversation: List[ChatCompletionMessageParam] = []
if system_prompt is not None:
conversation.append({"role": "system", "content": system_prompt})
print("Please enter a message for the chat model:")
while True:
input_message = input("> ")
message = {"role": "user", "content": input_message}
conversation.append(message)
conversation.append({"role": "user", "content": input_message})
chat_completion = client.chat.completions.create(model=model_name,
messages=conversation)
......@@ -79,7 +79,7 @@ def chat(system_prompt: Optional[str], model_name: str,
response_message = chat_completion.choices[0].message
output = response_message.content
conversation.append(response_message)
conversation.append(response_message) # type: ignore
print(output)
......
......@@ -37,6 +37,8 @@ class Detokenizer:
The prompt logprobs with the decoded tokens.
"""
prms = seq_group.sampling_params
assert prms is not None
# We can pick any sequence for the prompt.
seq = next(iter(seq_group.seqs_dict.values()))
# Only prompt, without the generated token.
......
......@@ -2,10 +2,9 @@ from typing import Optional, Type
from vllm.config import TokenizerPoolConfig
from vllm.executor.ray_utils import ray
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
TokenizerGroup)
from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
from .tokenizer_group import TokenizerGroup
if ray:
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
......@@ -34,4 +33,4 @@ def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs)
__all__ = ["get_tokenizer_group", "BaseTokenizerGroup"]
__all__ = ["AnyTokenizer", "get_tokenizer_group", "BaseTokenizerGroup"]
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import List, Optional, Union
from transformers import PreTrainedTokenizer
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.config import TokenizerPoolConfig
from vllm.lora.request import LoRARequest
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
class BaseTokenizerGroup(ABC):
"""A group of tokenizers that can be used for LoRA adapters."""
......@@ -47,17 +49,17 @@ class BaseTokenizerGroup(ABC):
@abstractmethod
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
"""Get a tokenizer for a LoRA request."""
pass
@abstractmethod
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
"""Get a tokenizer for a LoRA request."""
pass
......
......@@ -6,18 +6,16 @@ try:
from ray.exceptions import ActorDiedError
except ImportError:
# For older versions of Ray
from ray.exceptions import RayActorError as ActorDiedError
from ray.exceptions import RayActorError as ActorDiedError # type: ignore
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from transformers import PreTrainedTokenizer
from vllm.config import TokenizerPoolConfig
from vllm.executor.ray_utils import ray
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
TokenizerGroup)
from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
from .tokenizer_group import TokenizerGroup
logger = init_logger(__name__)
......@@ -67,7 +65,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
**self._tokenizer_config, )
self._ray_tokenizer_group_cls = ray.remote(
self._worker_cls).options(**ray_actor_options)
self._worker_cls).options(**ray_actor_options) # type: ignore
self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)]
self._idle_actors: Optional[asyncio.Queue] = None
......@@ -83,8 +81,10 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
return len(self.tokenizer_actors)
def ping(self):
return ray.get(
[actor.ping.remote() for actor in self.tokenizer_actors])
return ray.get([
actor.ping.remote() # type: ignore
for actor in self.tokenizer_actors
])
def _ensure_queue_initialized(self):
if self._idle_actors is None:
......@@ -208,15 +208,15 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
return self._local_tokenizer_group.get_max_input_len(lora_request)
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self._local_tokenizer_group.get_lora_tokenizer(lora_request)
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return await self._local_tokenizer_group.get_lora_tokenizer_async(
lora_request)
......
from typing import List, Optional
from transformers import PreTrainedTokenizer
from vllm.config import TokenizerPoolConfig
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import (get_lora_tokenizer,
get_lora_tokenizer_async,
get_tokenizer)
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.utils import LRUCache
from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup
class TokenizerGroup(BaseTokenizerGroup):
"""A group of tokenizers that can be used for LoRA adapters."""
......@@ -22,8 +20,8 @@ class TokenizerGroup(BaseTokenizerGroup):
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
capacity=max_num_seqs) if enable_lora else None
self.lora_tokenizers = LRUCache[AnyTokenizer](
capacity=max_num_seqs if enable_lora else 0)
@classmethod
def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig],
......@@ -41,7 +39,7 @@ class TokenizerGroup(BaseTokenizerGroup):
return self.max_input_length
def _raise_if_input_too_long(self,
encoded_tokens: List[str],
encoded_tokens: List[int],
lora_request: Optional[LoRARequest] = None):
input_length = len(encoded_tokens)
if lora_request:
......@@ -72,9 +70,9 @@ class TokenizerGroup(BaseTokenizerGroup):
return ret
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
......@@ -83,12 +81,12 @@ class TokenizerGroup(BaseTokenizerGroup):
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
return self.lora_tokenizers[lora_request.lora_int_id]
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
......@@ -97,4 +95,4 @@ class TokenizerGroup(BaseTokenizerGroup):
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
return self.lora_tokenizers[lora_request.lora_int_id]
......@@ -94,8 +94,10 @@ class LRUCache(Generic[T]):
def __len__(self) -> int:
return len(self.cache)
def __getitem__(self, key: Hashable) -> Optional[T]:
return self.get(key)
def __getitem__(self, key: Hashable) -> T:
value = self.cache[key] # Raise KeyError if not exists
self.cache.move_to_end(key)
return value
def __setitem__(self, key: Hashable, value: T) -> None:
self.put(key, value)
......@@ -109,8 +111,9 @@ class LRUCache(Generic[T]):
def get(self,
key: Hashable,
default_value: Optional[T] = None) -> Optional[T]:
value: Optional[T]
if key in self.cache:
value: Optional[T] = self.cache[key]
value = self.cache[key]
self.cache.move_to_end(key)
else:
value = default_value
......@@ -590,8 +593,8 @@ class CudaMemoryProfiler:
torch.cuda.reset_peak_memory_stats(self.device)
mem = torch.cuda.max_memory_allocated(self.device)
elif is_xpu():
torch.xpu.reset_peak_memory_stats(self.device)
mem = torch.xpu.max_memory_allocated(self.device)
torch.xpu.reset_peak_memory_stats(self.device) # type: ignore
mem = torch.xpu.max_memory_allocated(self.device) # type: ignore
return mem
def __enter__(self):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment