Commit 7e5f909b authored by Baber

update types.

parent 9db56820
+from __future__ import annotations
+
 import abc
 import hashlib
 import json
 import logging
 import os
-from typing import TYPE_CHECKING, Any, Iterable, Optional, Type, TypeVar, Union
+from collections.abc import Iterable
+from typing import TYPE_CHECKING, Any, TypeVar

 from tqdm import tqdm
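The import changes above are what make the annotation rewrites in the rest of this diff legal on older interpreters: with `from __future__ import annotations`, annotations are evaluated lazily, so PEP 604 unions (`str | None`) and PEP 585 builtin generics (`list[int]`, `type[T]`) work even on Python 3.8/3.9. A minimal standalone sketch, not from this file:

```python
# Minimal sketch: the future import defers annotation evaluation, so the
# new-style unions/generics below are never executed as expressions.
from __future__ import annotations


def lookup(key: str, table: dict[str, int] | None = None) -> int | None:
    # Equivalent to Optional[Dict[str, int]] / Optional[int] in the old style.
    if table is None:
        return None
    return table.get(key)


print(lookup("a", {"a": 1}))  # 1
print(lookup("b"))            # None
```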
@@ -31,7 +34,7 @@ class LM(abc.ABC):
         # set rank and world size to a single process, by default.
         self._rank = 0
         self._world_size = 1
-        self.cache_hook: "CacheHook" = CacheHook(None)
+        self.cache_hook: CacheHook = CacheHook(None)

     @abc.abstractmethod
     def loglikelihood(self, requests) -> list[tuple[float, bool]]:
@@ -137,7 +140,7 @@ class LM(abc.ABC):
     @classmethod
     def create_from_arg_string(
-        cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
+        cls: type[T], arg_string: str, additional_config: dict | None = None
     ) -> T:
         """
         Creates an instance of the LM class using the given argument string and additional config.
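For context, this factory consumes a comma-separated `key=value` string and merges `additional_config` before instantiating the class. A hedged, standalone sketch of that argument-string convention (the parser below is illustrative, not the harness's own helper):

```python
# Illustrative sketch of the "k1=v1,k2=v2" format create_from_arg_string
# consumes; the real parsing lives elsewhere in the harness.
def parse_arg_string(arg_string: str) -> dict:
    if not arg_string.strip():
        return {}
    return dict(pair.split("=", 1) for pair in arg_string.split(","))


kwargs = parse_arg_string("pretrained=gpt2,batch_size=8")
kwargs.update({"device": "cuda"})  # additional_config is merged on top
print(kwargs)  # {'pretrained': 'gpt2', 'batch_size': '8', 'device': 'cuda'}
```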
@@ -156,7 +159,7 @@ class LM(abc.ABC):
     @classmethod
     def create_from_arg_obj(
-        cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
+        cls: type[T], arg_dict: dict, additional_config: dict | None = None
     ) -> T:
         """
         Creates an instance of the LM class using the given arg_obj
@@ -199,7 +202,7 @@ class LM(abc.ABC):
             "To use this model with chat templates, please implement the 'tokenizer_name' property."
         )

-    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
+    def chat_template(self, chat_template: bool | str = False) -> str | None:
         """Returns the chat template structure for user/assistant messages if a template is provided.

         This method is intended to be overridden in a subclass to define a specific chat template format.
         For models that do not support chat templates, this method returns None by default.
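The `bool | str` union reflects three call patterns: `False` (no template), `True` (the model's default template), or a template name. The dispatch below is an assumption used to illustrate that contract, not the harness's actual logic:

```python
from __future__ import annotations


# Hypothetical dispatch illustrating the bool | str -> str | None contract;
# the template names and strings here are invented.
def resolve_chat_template(
    chat_template: bool | str, templates: dict[str, str]
) -> str | None:
    if chat_template is False:
        return None
    if chat_template is True:
        return templates.get("default")
    return templates.get(chat_template)


templates = {"default": "<|user|>{msg}<|assistant|>"}
print(resolve_chat_template(True, templates))   # the default template
print(resolve_chat_template(False, templates))  # None
```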
@@ -207,7 +210,7 @@ class LM(abc.ABC):
         return ""

-    def set_cache_hook(self, cache_hook: "CacheHook") -> None:
+    def set_cache_hook(self, cache_hook: CacheHook) -> None:
         self.cache_hook = cache_hook
@@ -218,9 +221,9 @@ def hash_args(attr: str, args: Iterable[Any]) -> str:
 class CacheHook:
-    def __init__(self, cachinglm: Optional["CachingLM"]) -> None:
+    def __init__(self, cachinglm: CachingLM | None) -> None:
         if cachinglm is None:
-            self.dbdict: Optional["SqliteDict"] = None
+            self.dbdict: SqliteDict | None = None
             return

         self.dbdict = cachinglm.dbdict
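The hunk header above shows the signature `hash_args(attr: str, args: Iterable[Any]) -> str`, which produces the keys the cache is written under. Given the file's `hashlib` and `json` imports, a plausible sketch of such a helper (an assumption, not this diff's code):

```python
from __future__ import annotations

import hashlib
import json
from collections.abc import Iterable
from typing import Any


def hash_args(attr: str, args: Iterable[Any]) -> str:
    # Serialize the request type plus its arguments into a stable string,
    # then hash it into a fixed-length cache key.
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()


print(hash_args("loglikelihood", ("ctx", "cont")))  # 64-char hex digest
```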
@@ -258,7 +261,7 @@ class CachingLM:
             eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
             return lm_attr

-        def _fn(requests: list["Instance"]) -> list["Instance"]:
+        def _fn(requests: list[Instance]) -> list[Instance]:
             res = []
             remaining_reqs = []
             warned = False
@@ -313,7 +316,7 @@ class CachingLM:
         return _fn

-    def get_cache_hook(self) -> "CacheHook":
+    def get_cache_hook(self) -> CacheHook:
         return CacheHook(self)
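`get_cache_hook` is the other half of `LM.set_cache_hook` above: the wrapper hands the wrapped model a hook whose `dbdict` aliases the wrapper's own store, so results computed by the model land in the cache. A minimal sketch of that wiring, with a plain `dict` standing in for `SqliteDict` (class bodies are illustrative):

```python
# Minimal wiring sketch; a plain dict stands in for the SqliteDict store.
class TinyCacheHook:
    def __init__(self, cachinglm):
        self.dbdict = None if cachinglm is None else cachinglm.dbdict


class TinyCachingLM:
    def __init__(self):
        self.dbdict = {}

    def get_cache_hook(self):
        return TinyCacheHook(self)


wrapper = TinyCachingLM()
hook = wrapper.get_cache_hook()
hook.dbdict["loglikelihood-<key>"] = (-1.23, True)  # write through the hook
print(wrapper.dbdict)  # ...lands in the wrapper's cache
```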
@@ -337,7 +340,9 @@ class TemplateLM(LM):
         return self.eot_token_id

     @abc.abstractmethod
-    def tok_encode(self, string: str, add_special_tokens=False, **kwargs) -> list[int]:
+    def tok_encode(
+        self, string: str, add_special_tokens: bool | None = None, **kwargs
+    ) -> list[int]:
         """
         Tokenize a string using the model's tokenizer and return a list of token IDs.
         """
@@ -345,7 +350,7 @@ class TemplateLM(LM):
     @abc.abstractmethod
     def _loglikelihood_tokens(
-        self, requests: list["Instance"], **kwargs
+        self, requests: list[Instance], **kwargs
     ) -> list[tuple[float, bool]]:
         pass
@@ -399,7 +404,7 @@ class TemplateLM(LM):
         return context_enc, continuation_enc

     def loglikelihood(
-        self, requests: list["Instance"], disable_tqdm: bool = False
+        self, requests: list[Instance], disable_tqdm: bool = False
     ) -> list[tuple[float, bool]]:
         """
         Compute log-likelihood of generating continuations from contexts.
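Per the docstring and the `list[tuple[float, bool]]` return type, each request pairs a context with a continuation, and each result reports the continuation's log-probability plus whether greedy decoding would produce it. A shape-only sketch with invented values:

```python
# Shape-only sketch of the loglikelihood contract; the numbers are invented.
requests = [
    ("The capital of France is", " Paris"),
    ("The capital of France is", " London"),
]
results: list[tuple[float, bool]] = [(-0.4, True), (-7.9, False)]

for (ctx, cont), (logprob, greedy) in zip(requests, results):
    print(f"{ctx!r} + {cont!r}: logprob={logprob:.2f}, is_greedy={greedy}")
```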
@@ -456,7 +461,7 @@ class TemplateLM(LM):
     def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
         pass

-    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
+    def chat_template(self, chat_template: bool | str = False) -> str | None:
         """
         Set and get the appropriate chat template for the model.

         This method sets the tokenizer's chat_template and returns the template string for reproducibility.