Commit 1e9c8b59 authored by baberabb

add typehints

parent b22f3440
@@ -16,13 +16,14 @@ from pathlib import Path
 import torch.nn.functional as F
 from lm_eval import utils
+from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
 from accelerate import Accelerator, find_executable_batch_size, DistributedType
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Tuple

 eval_logger = utils.eval_logger
@@ -420,7 +421,9 @@ class HFLM(LM):
         utils.clear_torch_cache()
         return batch_size

-    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None):
+    def tok_encode(
+        self, string: str, left_truncate_len=None, add_special_tokens=None
+    ) -> List[int]:
         """ """
         if add_special_tokens is None:
             if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
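For reference, the contract that the new `-> List[int]` annotation documents can be sketched standalone. This is not the class method itself; the gpt2 tokenizer and the helper name are assumptions for illustration only:

from typing import List, Optional
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

def tok_encode_sketch(string: str, left_truncate_len: Optional[int] = None) -> List[int]:
    # Encode to a plain list of token ids, matching the annotated return type.
    encoding: List[int] = tokenizer.encode(string, add_special_tokens=False)
    # Keep only the rightmost tokens when a left-truncation length is given.
    if left_truncate_len:
        encoding = encoding[-left_truncate_len:]
    return encoding

print(tok_encode_sketch("hello world", left_truncate_len=1))  # a single token id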
@@ -442,7 +445,7 @@ class HFLM(LM):
         padding_side: str = "left",
         left_truncate_len: int = None,
         truncation: bool = False,
-    ):
+    ) -> Tuple[List[int], List[int]]:
         # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
         old_padding_side = self.tokenizer.padding_side
         self.tokenizer.padding_side = padding_side
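The method this hunk annotates swaps the tokenizer's padding side before batch-encoding and restores it afterwards. A minimal sketch of that save/switch/restore pattern with a plain transformers tokenizer; the model choice and pad-token fallback are assumptions:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

old_padding_side = tokenizer.padding_side
tokenizer.padding_side = "left"  # causal LMs want prompts right-aligned
batch = tokenizer(
    ["a short prompt", "a somewhat longer prompt"],
    padding="longest",
    return_tensors="pt",
)
tokenizer.padding_side = old_padding_side  # restore, as the method does

print(batch["input_ids"].shape)  # both rows padded to the longest prompt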
@@ -536,7 +539,9 @@ class HFLM(LM):
         return logits

-    def _encode_pair(self, context, continuation):
+    def _encode_pair(
+        self, context: str, continuation: str
+    ) -> Tuple[List[int], List[int]]:
         n_spaces = len(context) - len(context.rstrip())
         if n_spaces > 0:
             continuation = context[-n_spaces:] + continuation
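The n_spaces logic being annotated here moves any trailing whitespace from the context onto the continuation, so the pair tokenizes the same way as the joined string. A standalone sketch of that step; the helper name is mine:

from typing import Tuple

def shift_trailing_spaces(context: str, continuation: str) -> Tuple[str, str]:
    # Trailing spaces tokenize differently when left on the context, so hand
    # them to the continuation before encoding the two pieces separately.
    n_spaces = len(context) - len(context.rstrip())
    if n_spaces > 0:
        continuation = context[-n_spaces:] + continuation
        context = context[:-n_spaces]
    return context, continuation

print(shift_trailing_spaces("Question: ", "Answer"))  # ('Question:', ' Answer')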
@@ -551,7 +556,7 @@ class HFLM(LM):
         continuation_enc = whole_enc[context_enc_len:]
         return context_enc, continuation_enc

-    def loglikelihood(self, requests):
+    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
         new_reqs = []
         for context, continuation in [req.args for req in requests]:
             if context == "":
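The new annotations pin down the shape of the data flowing through loglikelihood: each Instance carries a (context, continuation) pair in .args, and one (log-probability, is-greedy) pair comes back per request. A type-level sketch; the values are illustrative, not computed:

from typing import List, Tuple

requests_args: List[Tuple[str, str]] = [
    ("The capital of France is", " Paris"),
    ("2 + 2 =", " 4"),
]
# One result per request: summed logprob of the continuation, plus whether
# greedy decoding from the context would have produced it exactly.
results: List[Tuple[float, bool]] = [(-1.7, True), (-9.4, False)]  # illustrative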
@@ -566,7 +571,7 @@ class HFLM(LM):
         return self._loglikelihood_tokens(new_reqs)

-    def loglikelihood_rolling(self, requests):
+    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
         loglikelihoods = []
         adaptive_batch_size = None
@@ -640,8 +645,11 @@ class HFLM(LM):
         return self.batch_sizes[sched]

     def _loglikelihood_tokens(
-        self, requests, disable_tqdm: bool = False, override_bs=None
-    ):
+        self,
+        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
+        disable_tqdm: bool = False,
+        override_bs: int = None,
+    ) -> List[Tuple[float, bool]]:
         # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []
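The tuple type added for requests is dense, so here it is unpacked with assumed aliases; the token ids are made up for illustration:

from typing import List, Tuple

CacheKey = Tuple[str, str]  # the raw (context, continuation) strings
LogLikRequest = Tuple[CacheKey, List[int], List[int]]  # plus their encodings

req: LogLikRequest = (
    ("The capital of France is", " Paris"),
    [464, 3139, 286, 4881, 318],  # context token ids (illustrative)
    [6342],                       # continuation token ids (illustrative)
)
cache_key, context_enc, continuation_enc = req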
@@ -820,7 +828,7 @@ class HFLM(LM):
         return re_ord.get_original(res)

-    def generate_until(self, requests):
+    def generate_until(self, requests: List[Instance]) -> List[str]:
         res = defaultdict(list)
         re_ords = {}
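The res = defaultdict(list) and re_ords lines hint at how generate_until buckets requests so that ones sharing generation kwargs can be batched together. A toy version of that grouping, with invented request data:

from collections import defaultdict

requests = [
    ("Tell me a joke.", {"until": ["\n"]}),
    ("Count to three.", {"until": ["\n"]}),
    ("Write a haiku.", {"until": ["\n\n"], "max_gen_toks": 64}),
]
grouped = defaultdict(list)
for context, gen_kwargs in requests:
    # Stringify the sorted kwargs so requests with identical generation
    # settings land in the same bucket.
    grouped[str(sorted(gen_kwargs.items()))].append(context)
print({k: len(v) for k, v in grouped.items()})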
...