Commit 1e9c8b59 authored by baberabb

add type hints

parent b22f3440
@@ -16,13 +16,14 @@ from pathlib import Path
 import torch.nn.functional as F

 from lm_eval import utils
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria

 from accelerate import Accelerator, find_executable_batch_size, DistributedType

-from typing import List, Optional, Union
+from typing import List, Optional, Union, Tuple

 eval_logger = utils.eval_logger
@@ -420,7 +421,9 @@ class HFLM(LM):
         utils.clear_torch_cache()
         return batch_size

-    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None):
+    def tok_encode(
+        self, string: str, left_truncate_len=None, add_special_tokens=None
+    ) -> List[int]:
         """ """
         if add_special_tokens is None:
             if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
@@ -442,7 +445,7 @@ class HFLM(LM):
         padding_side: str = "left",
         left_truncate_len: int = None,
         truncation: bool = False,
-    ):
+    ) -> Tuple[List[int], List[int]]:
         # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
         old_padding_side = self.tokenizer.padding_side
         self.tokenizer.padding_side = padding_side
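
The comment in the hunk above notes that `tok_batch_encode`, unlike `tok_encode`, pads a whole batch automatically (via the underlying Hugging Face tokenizer). As a reference point, here is a toy sketch of what left- versus right-padding produces; `pad_batch` and everything in it are illustrative, not part of the harness:

from typing import List, Tuple

def pad_batch(
    batch: List[List[int]], pad_id: int = 0, padding_side: str = "left"
) -> Tuple[List[List[int]], List[List[int]]]:
    # Pad every sequence to the batch's max length and build the matching
    # attention mask (1 = real token, 0 = padding), mirroring what a HF
    # tokenizer returns as input_ids / attention_mask.
    width = max(len(seq) for seq in batch)
    ids, mask = [], []
    for seq in batch:
        pad = [pad_id] * (width - len(seq))
        if padding_side == "left":
            ids.append(pad + seq)
            mask.append([0] * len(pad) + [1] * len(seq))
        else:
            ids.append(seq + pad)
            mask.append([1] * len(seq) + [0] * len(pad))
    return ids, mask

ids, mask = pad_batch([[5, 6, 7], [8]], padding_side="left")
assert ids == [[5, 6, 7], [0, 0, 8]] and mask == [[1, 1, 1], [0, 0, 1]]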
@@ -536,7 +539,9 @@ class HFLM(LM):
         return logits

-    def _encode_pair(self, context, continuation):
+    def _encode_pair(
+        self, context: str, continuation: str
+    ) -> Tuple[List[int], List[int]]:
         n_spaces = len(context) - len(context.rstrip())
         if n_spaces > 0:
             continuation = context[-n_spaces:] + continuation
@@ -551,7 +556,7 @@ class HFLM(LM):
         continuation_enc = whole_enc[context_enc_len:]
         return context_enc, continuation_enc

-    def loglikelihood(self, requests):
+    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
         new_reqs = []
         for context, continuation in [req.args for req in requests]:
             if context == "":
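
The `Tuple[List[int], List[int]]` return added to `_encode_pair` above is the (context tokens, continuation tokens) split. Here is a self-contained sketch of the same trailing-whitespace trick, assuming a toy character-level `encode` in place of the model tokenizer; the lines elided between the two hunks are reconstructed by assumption:

from typing import List, Tuple

def encode(text: str) -> List[int]:
    # Toy stand-in tokenizer: one "token id" per character.
    return [ord(c) for c in text]

def encode_pair(context: str, continuation: str) -> Tuple[List[int], List[int]]:
    # Move trailing whitespace from the context onto the continuation, then
    # tokenize the joined string once and slice it at the length of the
    # encoded context.
    n_spaces = len(context) - len(context.rstrip())
    if n_spaces > 0:
        continuation = context[-n_spaces:] + continuation
        context = context[:-n_spaces]
    whole_enc = encode(context + continuation)
    context_enc = encode(context)
    context_enc_len = len(context_enc)
    continuation_enc = whole_enc[context_enc_len:]
    return context_enc, continuation_enc

assert encode_pair("2 + 2 = ", "4") == (encode("2 + 2 ="), encode(" 4"))

Shifting the trailing space onto the continuation keeps the split point aligned with how the tokenizer would segment the joined string.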
@@ -566,7 +571,7 @@ class HFLM(LM):
         return self._loglikelihood_tokens(new_reqs)

-    def loglikelihood_rolling(self, requests):
+    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
         loglikelihoods = []
         adaptive_batch_size = None
@@ -640,8 +645,11 @@ class HFLM(LM):
         return self.batch_sizes[sched]

     def _loglikelihood_tokens(
-        self, requests, disable_tqdm: bool = False, override_bs=None
-    ):
+        self,
+        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
+        disable_tqdm: bool = False,
+        override_bs: int = None,
+    ) -> List[Tuple[float, bool]]:
         # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []
@@ -820,7 +828,7 @@ class HFLM(LM):
         return re_ord.get_original(res)

-    def generate_until(self, requests):
+    def generate_until(self, requests: List[Instance]) -> List[str]:
         res = defaultdict(list)
         re_ords = {}
......
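
Taken together, the annotations pin down a simple caller-side contract: `loglikelihood` yields one `(log-probability, is-greedy)` pair per request and `generate_until` one string per request. A stub sketch of that contract follows, with plain tuples standing in for the harness's `Instance` objects; the stub and its return values are illustrative only:

from typing import Dict, List, Tuple

class StubLM:
    def loglikelihood(
        self, requests: List[Tuple[str, str]]
    ) -> List[Tuple[float, bool]]:
        # One (log-probability, is-greedy) pair per (context, continuation).
        return [(-1.23, True) for _ in requests]

    def generate_until(self, requests: List[Tuple[str, Dict]]) -> List[str]:
        # One generated string per (context, generation-kwargs) request.
        return ["four" for _ in requests]

lm = StubLM()
scores: List[Tuple[float, bool]] = lm.loglikelihood([("2 + 2 = ", "4")])
texts: List[str] = lm.generate_until([("2 + 2 =", {"until": ["\n"]})])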