Commit 7e5f909b authored by Baber

update types.

parent 9db56820
+from __future__ import annotations
 import abc
 import hashlib
 import json
 import logging
 import os
-from typing import TYPE_CHECKING, Any, Iterable, Optional, Type, TypeVar, Union
+from collections.abc import Iterable
+from typing import TYPE_CHECKING, Any, TypeVar
 from tqdm import tqdm
@@ -31,7 +34,7 @@ class LM(abc.ABC):
         # set rank and world size to a single process, by default.
         self._rank = 0
         self._world_size = 1
-        self.cache_hook: "CacheHook" = CacheHook(None)
+        self.cache_hook: CacheHook = CacheHook(None)

     @abc.abstractmethod
     def loglikelihood(self, requests) -> list[tuple[float, bool]]:
@@ -137,7 +140,7 @@ class LM(abc.ABC):
     @classmethod
     def create_from_arg_string(
-        cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
+        cls: type[T], arg_string: str, additional_config: dict | None = None
     ) -> T:
         """
         Creates an instance of the LM class using the given argument string and additional config.
@@ -156,7 +159,7 @@ class LM(abc.ABC):
     @classmethod
     def create_from_arg_obj(
-        cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
+        cls: type[T], arg_dict: dict, additional_config: dict | None = None
     ) -> T:
         """
         Creates an instance of the LM class using the given arg_obj
@@ -199,7 +202,7 @@ class LM(abc.ABC):
             "To use this model with chat templates, please implement the 'tokenizer_name' property."
         )

-    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
+    def chat_template(self, chat_template: bool | str = False) -> str | None:
         """Returns the chat template structure for user/assistant messages if a template is provided.
         This method is intended to be overridden in a subclass to define a specific chat template format.
         For models that do not support chat templates, this method returns None by default.
@@ -207,7 +210,7 @@ class LM(abc.ABC):
         return ""

-    def set_cache_hook(self, cache_hook: "CacheHook") -> None:
+    def set_cache_hook(self, cache_hook: CacheHook) -> None:
         self.cache_hook = cache_hook
@@ -218,9 +221,9 @@ def hash_args(attr: str, args: Iterable[Any]) -> str:

 class CacheHook:
-    def __init__(self, cachinglm: Optional["CachingLM"]) -> None:
+    def __init__(self, cachinglm: CachingLM | None) -> None:
         if cachinglm is None:
-            self.dbdict: Optional["SqliteDict"] = None
+            self.dbdict: SqliteDict | None = None
             return

         self.dbdict = cachinglm.dbdict
@@ -258,7 +261,7 @@ class CachingLM:
            eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
            return lm_attr

-        def _fn(requests: list["Instance"]) -> list["Instance"]:
+        def _fn(requests: list[Instance]) -> list[Instance]:
            res = []
            remaining_reqs = []
            warned = False
@@ -313,7 +316,7 @@ class CachingLM:
        return _fn

-    def get_cache_hook(self) -> "CacheHook":
+    def get_cache_hook(self) -> CacheHook:
        return CacheHook(self)
@@ -337,7 +340,9 @@ class TemplateLM(LM):
        return self.eot_token_id

    @abc.abstractmethod
-    def tok_encode(self, string: str, add_special_tokens=False, **kwargs) -> list[int]:
+    def tok_encode(
+        self, string: str, add_special_tokens: bool | None = None, **kwargs
+    ) -> list[int]:
        """
        Tokenize a string using the model's tokenizer and return a list of token IDs.
        """
@@ -345,7 +350,7 @@ class TemplateLM(LM):
    @abc.abstractmethod
    def _loglikelihood_tokens(
-        self, requests: list["Instance"], **kwargs
+        self, requests: list[Instance], **kwargs
    ) -> list[tuple[float, bool]]:
        pass
@@ -399,7 +404,7 @@ class TemplateLM(LM):
        return context_enc, continuation_enc

    def loglikelihood(
-        self, requests: list["Instance"], disable_tqdm: bool = False
+        self, requests: list[Instance], disable_tqdm: bool = False
    ) -> list[tuple[float, bool]]:
        """
        Compute log-likelihood of generating continuations from contexts.
@@ -456,7 +461,7 @@ class TemplateLM(LM):
    def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
        pass

-    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
+    def chat_template(self, chat_template: bool | str = False) -> str | None:
        """
        Set and get the appropriate chat template for the model.
        This method sets the tokenizer's chat_template and returns the template string for reproducibility.
...
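The substance of the commit is a migration to PEP 585/604 syntax. Because `from __future__ import annotations` (PEP 563) now sits at the top of the module, annotations are stored as strings instead of being evaluated at import time, so quoted forward references such as `"CacheHook"` and `Optional["SqliteDict"]` can become plain `CacheHook` and `SqliteDict | None`, even on Python versions that predate the `|` union operator. A minimal sketch of the pattern, using a hypothetical `CacheHandle` class that is not part of the commit (the `sqlitedict` import is typing-only and never executed at runtime):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to type checkers only; no runtime dependency on sqlitedict.
    from sqlitedict import SqliteDict


class CacheHandle:
    """Hypothetical stand-in for CacheHook, for illustration only."""

    def __init__(self, dbdict: SqliteDict | None = None) -> None:
        # With postponed evaluation, "SqliteDict | None" stays a string at
        # runtime, so neither quotes nor a runtime import are needed.
        self.dbdict = dbdict


print(CacheHandle().dbdict)  # -> None
```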
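The commit also lowercases `Type[T]` to the builtin generic `type[T]` in the two factory classmethods. A hedged sketch of that classmethod pattern, with a hypothetical `Configurable` base class and a deliberately simplified `key=value` parser standing in for the library's real argument parsing:

```python
from __future__ import annotations

from typing import TypeVar

T = TypeVar("T", bound="Configurable")


class Configurable:
    """Hypothetical base class; only the typing pattern mirrors the commit."""

    def __init__(self, **kwargs) -> None:
        self.kwargs = kwargs

    @classmethod
    def create_from_arg_string(
        cls: type[T], arg_string: str, additional_config: dict | None = None
    ) -> T:
        # Split "key=value,key2=value2" into kwargs, then merge overrides.
        args = dict(kv.split("=", 1) for kv in arg_string.split(",") if kv)
        args.update(additional_config or {})
        return cls(**args)


obj = Configurable.create_from_arg_string("pretrained=gpt2,dtype=float16")
print(obj.kwargs)  # {'pretrained': 'gpt2', 'dtype': 'float16'}
```

Because `cls` is typed `type[T]` with `T` bound to the base class, a call like `MySubclass.create_from_arg_string(...)` type-checks as returning `MySubclass` rather than the base class.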