Commit 0bad3ace authored by Baber

nit

parent 43388406
+from __future__ import annotations
+
 import abc
 import hashlib
 import json
 import logging
 import os
-from typing import TYPE_CHECKING, Any, Iterable, Optional, Type, TypeVar, Union
+from collections.abc import Iterable
+from typing import TYPE_CHECKING, Any, TypeVar
 
 from tqdm import tqdm
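Context for the import hunk: `from __future__ import annotations` (PEP 563) stores every annotation as a string that is never evaluated at runtime, which is what lets the rest of this diff drop quoted forward references and use PEP 604 unions (`X | None`) and built-in generics (`list[str]`, `type[T]`) even on Python 3.8/3.9. A standalone sketch of the effect (illustrative names, not from this file):

from __future__ import annotations


class Node:
    # Stored as the string "Node | None", never evaluated: no quotes needed
    # for the forward reference, and `|` works even before Python 3.10.
    def __init__(self, parent: Node | None = None) -> None:
        self.parent = parent


print(Node.__init__.__annotations__["parent"])  # prints: Node | None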
@@ -31,7 +34,7 @@ class LM(abc.ABC):
         # set rank and world size to a single process, by default.
         self._rank = 0
         self._world_size = 1
-        self.cache_hook: "CacheHook" = CacheHook(None)
+        self.cache_hook: CacheHook = CacheHook(None)
 
     @abc.abstractmethod
     def loglikelihood(self, requests) -> list[tuple[float, bool]]:
@@ -101,7 +104,7 @@ class LM(abc.ABC):
     # TODO: Add an optional max length
     @abc.abstractmethod
-    def generate_until(self, requests: list["Instance"]) -> list[str]:
+    def generate_until(self, requests: list[Instance]) -> list[str]:
         """Generate greedily until a stopping sequence
 
         :param requests: list[Instance]
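The `generate_until` contract is greedy generation truncated at the caller's stop sequences. A generic, self-contained illustration of the truncation step implementations perform (not this repo's code; names are assumptions):

def truncate_at_stop(generated: str, until: list[str]) -> str:
    # Cut the raw continuation at the earliest occurrence of any stop string.
    for stop in until:
        idx = generated.find(stop)
        if idx != -1:
            generated = generated[:idx]
    return generated


print(truncate_at_stop("Paris.\n\nQ: Next question", ["\n\n"]))  # prints: Paris.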
@@ -137,7 +140,7 @@ class LM(abc.ABC):
     @classmethod
     def create_from_arg_string(
-        cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
+        cls: type[T], arg_string: str, additional_config: dict | None = None
     ) -> T:
         """
         Creates an instance of the LM class using the given argument string and additional config.
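`create_from_arg_string` turns a CLI-style string such as `pretrained=gpt2,dtype=float16` into constructor kwargs merged with `additional_config`. The body is outside this diff; a hedged sketch of the parsing it implies (hypothetical helper, not the repo's code):

from __future__ import annotations


def parse_arg_string(arg_string: str, additional_config: dict | None = None) -> dict:
    # Split "k1=v1,k2=v2" into kwargs; explicit config overrides string values.
    kwargs = dict(pair.split("=", 1) for pair in arg_string.split(",") if pair)
    kwargs.update(additional_config or {})
    return kwargs


print(parse_arg_string("pretrained=gpt2,dtype=float16", {"batch_size": 8}))
# {'pretrained': 'gpt2', 'dtype': 'float16', 'batch_size': 8}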
@@ -156,7 +159,7 @@ class LM(abc.ABC):
     @classmethod
     def create_from_arg_obj(
-        cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
+        cls: type[T], arg_dict: dict, additional_config: dict | None = None
     ) -> T:
         """
         Creates an instance of the LM class using the given arg_obj
@@ -201,7 +204,7 @@ class LM(abc.ABC):
             "To use this model with chat templates, please implement the 'tokenizer_name' property."
         )
 
-    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
+    def chat_template(self, chat_template: bool | str = False) -> str | None:
         """Returns the chat template structure for user/assistant messages if a template is provided.
         This method is intended to be overridden in a subclass to define a specific chat template format.
         For models that do not support chat templates, this method returns None by default.
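Per the docstring, the base implementation advertises no template. A hypothetical subclass override, assuming the convention that `False` means no template, `True` the default, and a string a named template (the convention is inferred, not stated in this hunk):

from __future__ import annotations


class DummyLM:
    templates = {"default": "<|user|>{msg}<|assistant|>"}

    def chat_template(self, chat_template: bool | str = False) -> str | None:
        if chat_template is False:
            return None  # caller opted out of chat formatting
        if chat_template is True:
            return self.templates["default"]
        return self.templates.get(chat_template)  # named template, may be None


print(DummyLM().chat_template(True))  # prints the default template string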
@@ -209,7 +212,7 @@ class LM(abc.ABC):
         return ""
 
-    def set_cache_hook(self, cache_hook: "CacheHook") -> None:
+    def set_cache_hook(self, cache_hook: CacheHook) -> None:
         """Sets the cache hook for the LM, which is used to cache responses from the LM."""
         self.cache_hook = cache_hook
@@ -221,10 +224,10 @@ def hash_args(attr: str, args: Iterable[Any]) -> str:
 class CacheHook:
-    def __init__(self, cachinglm: Optional["CachingLM"]) -> None:
+    def __init__(self, cachinglm: CachingLM | None) -> None:
         """CacheHook is used to cache responses from the LM."""
         if cachinglm is None:
-            self.dbdict: Optional["SqliteDict"] = None
+            self.dbdict: SqliteDict | None = None
             return
 
         self.dbdict = cachinglm.dbdict
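The hunk header above exposes the signature of `hash_args`, the function that builds cache keys. Its body is not part of this diff; given the module's `hashlib` and `json` imports, a plausible reconstruction (an assumption, not the repo's verbatim code):

import hashlib
import json
from collections.abc import Iterable
from typing import Any


def hash_args(attr: str, args: Iterable[Any]) -> str:
    # Deterministically serialize the method name plus its arguments, then
    # hash, so identical requests always map to the same cache-db key.
    dat = json.dumps([attr] + list(args))
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()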
@@ -238,7 +241,7 @@ class CacheHook:
 class CachingLM:
-    def __init__(self, lm: "LM", cache_db: str) -> None:
+    def __init__(self, lm: LM, cache_db: str) -> None:
         """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
 
         :param lm: LM
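The type hints above show the cache store is a `SqliteDict`. A minimal sketch of that persistence pattern, assuming the third-party `sqlitedict` package (file name illustrative):

from sqlitedict import SqliteDict

# One sqlite file acts as a persistent str -> object mapping; autocommit
# makes each cached response durable as soon as it is written.
with SqliteDict("lm_cache.db", autocommit=True) as db:
    db["somekey"] = ["a cached response"]
    print(db.get("somekey"))  # ['a cached response']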
@@ -263,7 +266,7 @@ class CachingLM:
             eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
             return lm_attr
 
-        def _fn(requests: list["Instance"]) -> list["Instance"]:
+        def _fn(requests: list[Instance]) -> list[Instance]:
             res = []
             remaining_reqs = []
             warned = False
@@ -295,11 +298,8 @@ class CachingLM:
             eval_logger.info(
                 f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
             )
-            if remaining_reqs:
-                # actually run the LM on the requests that do not have cached results
-                rem_res = getattr(self.lm, attr)(remaining_reqs)
-            else:
-                rem_res = []
+            # actually run the LM on the requests that do not have cached results
+            rem_res = getattr(self.lm, attr)(remaining_reqs) if remaining_reqs else []
 
             # stick the new ones back into the list and also cache any of the new ones
             resptr = 0
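The refactor above collapses the cache-miss branch into one conditional expression. For readers outside the repo, a self-contained miniature of the read-through pattern `_fn` implements (a plain dict stands in for the sqlite store, `compute` for the underlying LM call):

cache: dict[str, str] = {}


def cached_batch(keys: list[str], compute) -> list:
    res = [cache.get(k) for k in keys]
    remaining = [k for k, r in zip(keys, res) if r is None]
    # same shape as the new line above: skip the model entirely on a full hit
    rem_res = compute(remaining) if remaining else []
    it = iter(rem_res)
    for i, r in enumerate(res):
        if r is None:
            res[i] = cache[keys[i]] = next(it)
    return res


print(cached_batch(["a", "b"], lambda ks: [k.upper() for k in ks]))  # ['A', 'B']
print(cached_batch(["a", "c"], lambda ks: [k.upper() for k in ks]))  # only 'c' recomputed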
@@ -318,7 +318,7 @@ class CachingLM:
         return _fn
 
-    def get_cache_hook(self) -> "CacheHook":
+    def get_cache_hook(self) -> CacheHook:
         return CacheHook(self)
@@ -395,7 +395,7 @@ class TemplateLM(LM):
         return context_enc, continuation_enc
 
     def loglikelihood(
-        self, requests: list["Instance"], disable_tqdm: bool = False
+        self, requests: list[Instance], disable_tqdm: bool = False
     ) -> list[tuple[float, bool]]:
         """Compute log-likelihood of generating a continuation from a context.
@@ -428,7 +428,7 @@ class TemplateLM(LM):
     @abc.abstractmethod
     def generate_until(
-        self, requests: list["Instance"], disable_tqdm: bool = False
+        self, requests: list[Instance], disable_tqdm: bool = False
     ) -> list[str]:
         """Generate until a stopping sequence.
@@ -449,7 +449,7 @@ class TemplateLM(LM):
         """
         pass
 
-    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
+    def chat_template(self, chat_template: bool | str = False) -> str | None:
         """
         Assumes tokenizer has a chat_template attribute (self.tokenizer.chat_template: dict | str)
         Set and get the appropriate chat template for the model.