Commit 0bad3ace authored by Baber

nit

parent 43388406
+from __future__ import annotations
 import abc
 import hashlib
 import json
 import logging
 import os
-from typing import TYPE_CHECKING, Any, Iterable, Optional, Type, TypeVar, Union
+from collections.abc import Iterable
+from typing import TYPE_CHECKING, Any, TypeVar
 from tqdm import tqdm
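The whole refactor in this commit leans on the added `from __future__ import annotations` (PEP 563): annotations are then stored as strings and never evaluated at definition time, so PEP 604 unions (`dict | None`), built-in generics (`list[str]`, `type[T]`), and bare forward references (`CacheHook` instead of `"CacheHook"`) are safe even on interpreters older than 3.10. A minimal standalone sketch (the names here are illustrative, not from this file):

from __future__ import annotations


def build(config: dict | None = None) -> list[str]:
    # With the future import, `dict | None` is never evaluated at runtime,
    # so this parses and runs fine on Python versions before 3.10.
    return [] if config is None else sorted(config)


# Annotations are kept as plain strings:
print(build.__annotations__)  # {'config': 'dict | None', 'return': 'list[str]'}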
@@ -31,7 +34,7 @@ class LM(abc.ABC):
         # set rank and world size to a single process, by default.
         self._rank = 0
         self._world_size = 1
-        self.cache_hook: "CacheHook" = CacheHook(None)
+        self.cache_hook: CacheHook = CacheHook(None)

     @abc.abstractmethod
     def loglikelihood(self, requests) -> list[tuple[float, bool]]:
@@ -101,7 +104,7 @@ class LM(abc.ABC):
     # TODO: Add an optional max length
     @abc.abstractmethod
-    def generate_until(self, requests: list["Instance"]) -> list[str]:
+    def generate_until(self, requests: list[Instance]) -> list[str]:
         """Generate greedily until a stopping sequence

         :param requests: list[Instance]
@@ -137,7 +140,7 @@ class LM(abc.ABC):
     @classmethod
     def create_from_arg_string(
-        cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
+        cls: type[T], arg_string: str, additional_config: dict | None = None
     ) -> T:
         """
         Creates an instance of the LM class using the given argument string and additional config.
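The parsing logic behind `create_from_arg_string` is outside this hunk; the signature suggests a comma-separated `key=value` string merged with `additional_config`. A hedged sketch of that pattern (the `parse_arg_string` helper below is hypothetical, not the harness's actual parser):

def parse_arg_string(arg_string: str) -> dict:
    # hypothetical helper: "pretrained=gpt2,batch_size=8"
    # -> {"pretrained": "gpt2", "batch_size": "8"}
    return dict(kv.split("=", 1) for kv in arg_string.split(",") if kv)


args = parse_arg_string("pretrained=gpt2,batch_size=8")
args.update({"device": "cpu"})  # additional_config entries win on conflict
print(args)  # cls(**args) would then construct the LM subclass

`create_from_arg_obj` in the next hunk follows the same shape, starting from a dict instead of a string.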
@@ -156,7 +159,7 @@ class LM(abc.ABC):
     @classmethod
     def create_from_arg_obj(
-        cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
+        cls: type[T], arg_dict: dict, additional_config: dict | None = None
     ) -> T:
         """
         Creates an instance of the LM class using the given arg_obj
@@ -201,7 +204,7 @@ class LM(abc.ABC):
             "To use this model with chat templates, please implement the 'tokenizer_name' property."
         )

-    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
+    def chat_template(self, chat_template: bool | str = False) -> str | None:
         """Returns the chat template structure for user/assistant messages if a template is provided.

         This method is intended to be overridden in a subclass to define a specific chat template format.
         For models that do not support chat templates, this method returns None by default.
@@ -209,7 +212,7 @@ class LM(abc.ABC):
         return ""

-    def set_cache_hook(self, cache_hook: "CacheHook") -> None:
+    def set_cache_hook(self, cache_hook: CacheHook) -> None:
         """Sets the cache hook for the LM, which is used to cache responses from the LM."""
         self.cache_hook = cache_hook
@@ -221,10 +224,10 @@ def hash_args(attr: str, args: Iterable[Any]) -> str:
 class CacheHook:
-    def __init__(self, cachinglm: Optional["CachingLM"]) -> None:
+    def __init__(self, cachinglm: CachingLM | None) -> None:
         """CacheHook is used to cache responses from the LM."""
         if cachinglm is None:
-            self.dbdict: Optional["SqliteDict"] = None
+            self.dbdict: SqliteDict | None = None
             return

         self.dbdict = cachinglm.dbdict
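The body of `hash_args` itself is collapsed out of the diff; given the `hashlib` and `json` imports at the top of the file, a plausible sketch is a digest over the method name plus its JSON-serialized arguments (the exact serialization below is an assumption):

import hashlib
import json
from collections.abc import Iterable
from typing import Any


def hash_args(attr: str, args: Iterable[Any]) -> str:
    # Assumed scheme: serialize deterministically, then hash, so the same
    # (method, arguments) pair always maps to the same cache key.
    dat = json.dumps([attr] + list(args), sort_keys=True)
    return hashlib.sha256(dat.encode("utf-8")).hexdigest()


print(hash_args("generate_until", ["ctx", {"until": ["\n"]}]))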
@@ -238,7 +241,7 @@ class CacheHook:
 class CachingLM:
-    def __init__(self, lm: "LM", cache_db: str) -> None:
+    def __init__(self, lm: LM, cache_db: str) -> None:
         """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

         :param lm: LM
@@ -263,7 +266,7 @@ class CachingLM:
             eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
             return lm_attr

-        def _fn(requests: list["Instance"]) -> list["Instance"]:
+        def _fn(requests: list[Instance]) -> list[Instance]:
             res = []
             remaining_reqs = []
             warned = False
@@ -295,11 +298,8 @@ class CachingLM:
             eval_logger.info(
                 f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
             )
-            if remaining_reqs:
-                # actually run the LM on the requests that do not have cached results
-                rem_res = getattr(self.lm, attr)(remaining_reqs)
-            else:
-                rem_res = []
+            rem_res = getattr(self.lm, attr)(remaining_reqs) if remaining_reqs else []

             # stick the new ones back into the list and also cache any of the new ones
             resptr = 0
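The hunk above is a pure refactor: the four-line if/else collapses into one conditional expression with identical behavior. The surrounding merge logic, sketched with toy data below (hypothetical stand-ins for the real requests and the sqlite-backed dict), fills cache misses back into the result list in request order:

requests = ["a", "b", "c", "d"]
cached = {"b": "B!", "d": "D!"}            # hits found in the cache db
res = [cached.get(r) for r in requests]    # None marks a cache miss
remaining_reqs = [r for r in requests if r not in cached]

# run the "LM" only if something is actually missing (the refactored line)
rem_res = [r.upper() for r in remaining_reqs] if remaining_reqs else []

# stick the new ones back into the list, in order
it = iter(rem_res)
res = [next(it) if r is None else r for r in res]
print(res)  # ['A', 'B!', 'C', 'D!']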
@@ -318,7 +318,7 @@
         return _fn

-    def get_cache_hook(self) -> "CacheHook":
+    def get_cache_hook(self) -> CacheHook:
         return CacheHook(self)
@@ -395,7 +395,7 @@ class TemplateLM(LM):
         return context_enc, continuation_enc

     def loglikelihood(
-        self, requests: list["Instance"], disable_tqdm: bool = False
+        self, requests: list[Instance], disable_tqdm: bool = False
     ) -> list[tuple[float, bool]]:
         """Compute log-likelihood of generating a continuation from a context.
@@ -428,7 +428,7 @@
     @abc.abstractmethod
     def generate_until(
-        self, requests: list["Instance"], disable_tqdm: bool = False
+        self, requests: list[Instance], disable_tqdm: bool = False
     ) -> list[str]:
         """Generate until a stopping sequence.
@@ -449,7 +449,7 @@
         """
         pass

-    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
+    def chat_template(self, chat_template: bool | str = False) -> str | None:
         """
         Assumes tokenizer has a chat_template attribute (self.tokenizer.chat_template: dict | str)
         Set and get the appropriate chat template for the model.
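The docstring above describes dispatching on `self.tokenizer.chat_template`, which may be a dict of named templates or a single string. A hedged sketch of that dispatch (the function and its behavior are an illustration of the docstring, not the method's actual body):

from __future__ import annotations


def pick_template(chat_template: bool | str, available: dict | str | None) -> str | None:
    if chat_template is False or available is None:
        return None                              # no chat formatting requested/available
    if isinstance(available, str):
        return available                         # single template: use it directly
    if isinstance(chat_template, str):
        return available.get(chat_template)      # caller asked for a named template
    return next(iter(available.values()), None)  # True: fall back to a default


print(pick_template(True, {"default": "<|user|>{{ content }}"}))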