Commit fc69d84f authored by Ethan Smith

Add suggestions from autotyping

This adds a bunch of simple annotations suggested by https://github.com/JelleZijlstra/autotyping.
parent da85f290
......@@ -13,7 +13,7 @@ class Filter:
"""
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
......@@ -40,8 +40,7 @@ class FilterEnsemble:
name: str
filters: List[Filter]
def apply(self, instances: List[Instance]):
def apply(self, instances: List[Instance]) -> None:
resps = [
inst.resps for inst in instances
] # operate just on the model responses
......
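As a rough illustration of the interface in the hunks above, a hedged sketch of a stateful Filter subclass; the import path and the resps-in/resps-out contract are assumptions drawn from other hunks in this commit:

from lm_eval.api.filter import Filter  # import path assumed from this diff


class LowercaseFilter(Filter):
    """Hedged sketch: keeps per-instantiation state, per the docstring above."""

    def __init__(self, *args, **kwargs) -> None:
        self.seen = 0  # example of per-instantiation state

    def apply(self, resps):
        # resps: one list of model responses per instance, as extracted by
        # FilterEnsemble.apply above.
        out = []
        for inst in resps:
            self.seen += len(inst)
            out.append([r.lower() for r in inst])
        return out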
......@@ -19,7 +19,7 @@ class Instance:
doc_id: str = None
repeats: str = None
def __post_init__(self):
def __post_init__(self) -> None:
# unpack metadata field
self.task_name, self.doc_id, self.repeats = self.metadata
......
......@@ -302,7 +302,7 @@ def _sacreformat(refs, preds):
class _bootstrap_internal:
def __init__(self, f, n):
def __init__(self, f, n) -> None:
self.f = f
self.n = n
......
......@@ -13,7 +13,7 @@ from lm_eval.logger import eval_logger
class LM(abc.ABC):
def __init__(self):
def __init__(self) -> None:
"""Defines the interface that should be implemented by all LM subclasses.
LMs are assumed to take text (strings) as input and yield strings as output
(inputs/outputs should be tokenization-agnostic.)
......@@ -133,7 +133,7 @@ class LM(abc.ABC):
# not support multi-device parallelism nor expect it.
return self._world_size
def set_cache_hook(self, cache_hook):
def set_cache_hook(self, cache_hook) -> None:
self.cache_hook = cache_hook
......@@ -144,14 +144,14 @@ def hash_args(attr, args):
class CacheHook:
def __init__(self, cachinglm):
def __init__(self, cachinglm) -> None:
if cachinglm is None:
self.dbdict = None
return
self.dbdict = cachinglm.dbdict
def add_partial(self, attr, req, res):
def add_partial(self, attr, req, res) -> None:
if self.dbdict is None:
return
hsh = hash_args(attr, req)
......@@ -159,7 +159,7 @@ class CacheHook:
class CachingLM:
def __init__(self, lm, cache_db):
def __init__(self, lm, cache_db) -> None:
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
:param lm: LM
......
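For context, a hedged usage sketch of the wrapper described in this docstring; both module paths are assumptions based on files touched in this commit, and the cache path is purely illustrative:

from lm_eval.api.model import CachingLM  # path assumed
from lm_eval.models.dummy import DummyLM  # path assumed

lm = CachingLM(DummyLM(), "lm_cache/dummy.db")
# subsequent identical requests are served from the cache; only misses reach DummyLM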
class Sampler:
def __init__(self, docs, task, fewshot_indices=None, rnd=None):
def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
self.rnd = rnd
assert self.rnd, "must pass rnd to FewShotSampler!"
......@@ -19,7 +18,6 @@ class Sampler:
self.docs = self.docs.select(fewshot_indices)
def get_context(self, doc, num_fewshot):
# draw an extra fewshot sample if using same split as evaluating on
n_samples = (
num_fewshot + 1
......@@ -74,7 +72,7 @@ class Sampler:
class BalancedSampler(Sampler):
def sample(self, n):
def sample(self, n) -> None:
"""
TODO: this should return approximately class-balanced samples from our fewshot examples.
TODO: what order should they be in? maybe random?
......@@ -84,7 +82,7 @@ class BalancedSampler(Sampler):
class ManualSampler(Sampler):
def sample(self, n):
def sample(self, n) -> None:
""" """
pass
......
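The `num_fewshot + 1` logic in `get_context` above exists so that, when fewshot examples come from the same split being evaluated, the evaluated document never appears in its own context. A hedged, self-contained sketch of that dedup step (the function name is illustrative, not the harness API):

def pick_fewshot(sampled: list, doc, num_fewshot: int) -> list:
    # Sample num_fewshot + 1 upstream, then drop the evaluated doc itself
    # and keep the first num_fewshot of what remains.
    return [ex for ex in sampled if ex != doc][:num_fewshot]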
......@@ -88,8 +88,7 @@ class TaskConfig(dict):
metadata: str = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
def __post_init__(self):
def __post_init__(self) -> None:
if self.generation_kwargs is not None:
if self.output_type != "greedy_until":
eval_logger.warning(
......@@ -171,7 +170,7 @@ class Task(abc.ABC):
cache_dir=None,
download_mode=None,
config=None,
):
) -> None:
"""
:param data_dir: str
Stores the path to a local folder containing the `Task`'s data files.
......@@ -213,7 +212,7 @@ class Task(abc.ABC):
list(self.fewshot_docs()), self, rnd=random.Random(1234)
)
def download(self, data_dir=None, cache_dir=None, download_mode=None):
def download(self, data_dir=None, cache_dir=None, download_mode=None) -> None:
"""Downloads and returns the task dataset.
Override this method to download the dataset from a custom API.
......@@ -322,7 +321,7 @@ class Task(abc.ABC):
return rnd.sample(self._training_docs, k)
def doc_to_decontamination_query(self, doc):
def doc_to_decontamination_query(self, doc) -> None:
print(
"Override doc_to_decontamination_query with document specific decontamination query."
)
......@@ -336,7 +335,7 @@ class Task(abc.ABC):
def doc_to_target(self, doc):
pass
def build_all_requests(self, limit=None, rank=None, world_size=None):
def build_all_requests(self, limit=None, rank=None, world_size=None) -> None:
"""Build a set of Instances for a task, and store them in task.instances"""
if self.has_test_docs():
docs = self.test_docs()
......@@ -472,7 +471,6 @@ class Task(abc.ABC):
return labeled_examples + str(example)
def apply_filters(self):
if hasattr(self, "_filters"):
for f in self._filters:
f.apply(self._instances)
......@@ -498,7 +496,7 @@ class ConfigurableTask(Task):
def __init__(
self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None
): # TODO no super() call here
) -> None: # TODO no super() call here
# Get pre-configured attributes
self._config = self.CONFIG
......@@ -570,7 +568,6 @@ class ConfigurableTask(Task):
"aggregation"
]
else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
metric_agg = get_default_aggregation(metric_name)
eval_logger.warning(
......@@ -683,8 +680,7 @@ class ConfigurableTask(Task):
f'Both target_delimiter and target choice: "{choice}" does not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
)
def download(self, dataset_kwargs=None):
def download(self, dataset_kwargs=None) -> None:
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
......@@ -767,7 +763,6 @@ class ConfigurableTask(Task):
return doc
def doc_to_text(self, doc):
if self.prompt is not None:
doc_to_text = self.prompt
else:
......@@ -802,7 +797,6 @@ class ConfigurableTask(Task):
raise TypeError
def doc_to_target(self, doc: dict) -> Union[int, str, list]:
if self.prompt is not None:
doc_to_target = self.prompt
else:
......@@ -844,7 +838,6 @@ class ConfigurableTask(Task):
raise TypeError
def doc_to_choice(self, doc: Any) -> List[str]:
if self.prompt is not None:
doc_to_choice = self.prompt
elif self._config.doc_to_choice is None:
......@@ -888,13 +881,11 @@ class ConfigurableTask(Task):
def construct_requests(
self, doc: dict, ctx: str, **kwargs
) -> Union[List[Instance], Instance]:
if self.OUTPUT_TYPE == "loglikelihood":
arguments = (ctx, self.doc_to_target(doc))
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
arguments = (self.doc_to_target(doc),)
elif self.OUTPUT_TYPE == "multiple_choice":
choices = self.doc_to_choice(doc)
target_delimiter = self._config.target_delimiter
if self.multiple_input:
......@@ -945,7 +936,6 @@ class ConfigurableTask(Task):
)
def process_results(self, doc, results):
if callable(self._config.process_results):
return self._config.process_results(doc, results)
......@@ -980,7 +970,6 @@ class ConfigurableTask(Task):
),
}
elif self.OUTPUT_TYPE == "multiple_choice":
lls, is_greedy = zip(*results)
# retrieve choices in List[str] form, to compute choice lengths, etc.
......@@ -1034,7 +1023,6 @@ class ConfigurableTask(Task):
result_dict["acc_mutual_info"] = acc_mutual_info
elif self.OUTPUT_TYPE == "greedy_until":
gold = self.doc_to_target(doc)
if self._config.doc_to_choice is not None:
# If you set doc_to_choice,
......@@ -1164,7 +1152,7 @@ class PerplexityTask(Task):
def doc_to_decontamination_query(self, doc):
return doc
def doc_to_text(self, doc):
def doc_to_text(self, doc) -> str:
return ""
def doc_to_target(self, doc):
......
......@@ -11,8 +11,7 @@ from lm_eval.api.registry import (
)
def include_benchmarks(task_dir):
def include_benchmarks(task_dir: str) -> None:
for root, subdirs, file_list in os.walk(task_dir):
if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0):
for f in file_list:
......
import os
from typing import Any
import zstandard
import json
import jsonlines
......@@ -9,7 +10,7 @@ import tqdm
from pathlib import Path
def json_serial(obj):
def json_serial(obj: Any) -> str:
"""JSON serializer for objects not serializable by default json code"""
if isinstance(obj, (datetime.datetime,)):
......@@ -19,7 +20,7 @@ def json_serial(obj):
# Modified version of lm_dataformat Archive for single file.
class Archive:
def __init__(self, file_path, compression_level=3):
def __init__(self, file_path: str, compression_level: int = 3) -> None:
self.file_path = file_path
dir_name = os.path.dirname(file_path)
if dir_name:
......@@ -28,7 +29,7 @@ class Archive:
self.cctx = zstandard.ZstdCompressor(level=compression_level)
self.compressor = self.cctx.stream_writer(self.fh)
def add_data(self, data, meta={}):
def add_data(self, data, meta={}) -> None:
self.compressor.write(
json.dumps({"text": data, "meta": meta}, default=json_serial).encode(
"UTF-8"
......@@ -36,7 +37,7 @@ class Archive:
+ b"\n"
)
def commit(self):
def commit(self) -> None:
self.compressor.flush(zstandard.FLUSH_FRAME)
self.fh.flush()
self.fh.close()
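A hedged usage sketch for `Archive` above (import path assumed from this commit's file layout; the output file name is illustrative):

from lm_eval.decontamination.archiver import Archive  # path assumed

ar = Archive("data/example.jsonl.zst")
ar.add_data("some document text", meta={"source": "example"})
ar.commit()  # flushes the zstd frame and closes the file handle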
......@@ -44,10 +45,16 @@ class Archive:
# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class Reader:
def __init__(self):
def __init__(self) -> None:
pass
def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner="\n\n"):
def read(
self,
file,
get_meta: bool = False,
autojoin_paragraphs: bool = True,
para_joiner: str = "\n\n",
):
with open(file, "rb") as fh:
self.fh = fh
cctx = zstandard.ZstdDecompressor()
......@@ -72,7 +79,7 @@ class Reader:
class TextArchive:
def __init__(self, file_path, mode="rb+"):
def __init__(self, file_path, mode: str = "rb+") -> None:
self.file_path = file_path
dir_name = os.path.dirname(file_path)
if dir_name:
......@@ -83,21 +90,21 @@ class TextArchive:
self.fh = open(self.file_path, mode)
def add_data(self, data):
def add_data(self, data) -> None:
self.fh.write(data.encode("UTF-8") + b"\n")
def commit(self):
def commit(self) -> None:
self.fh.flush()
self.fh.close()
class TextReader:
def __init__(self, file_path):
def __init__(self, file_path) -> None:
self.file_path = file_path
# Optimized mmap read with infrequent tqdm updates to maintain speed
# Tested up to 250MB/s.
def read_tqdm(self, update_frequency=10000):
def read_tqdm(self, update_frequency: int = 10000):
current_file_position = 0
line_counter = 0
with open(self.file_path, "r") as fh, tqdm.tqdm(
......@@ -149,7 +156,7 @@ class TextReader:
# Optimized for speed. Decompresses the archive in shell before
# using the mmap'd TextReader.
class ZStdTextReader:
def __init__(self, file):
def __init__(self, file) -> None:
self.file = file
def read_tqdm(self):
......
......@@ -11,7 +11,7 @@ from .archiver import ZStdTextReader
# Was used for testing the evaluator decoupled from the full logic below
def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
def get_train_overlap_stub(docs: dict, ngrams_path: str, ngrams_n_size: str):
simulated_overlap = 0.1
contaminated = int(len(docs) * simulated_overlap)
return random.sample(range(len(docs)), contaminated)
......@@ -25,6 +25,7 @@ def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
# scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst"
# files. These should exist in the "ngrams_path" provided to this function.
# Algorithm:
# 1. Build lookups for each dataset {ngram: list(document_ids)}
# 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
......@@ -33,7 +34,7 @@ def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
# 4. Strip the task_set from the dictionary keys and return
#
# We cache the task+set lookups as well as the overlaps.
def get_train_overlap(docs_by_task_set, ngrams_path, limit):
def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> dict:
# return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)
info_dict_path = os.path.join(ngrams_path, "info.json")
......@@ -46,7 +47,7 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
print("Building Lookups...")
start = time.perf_counter()
def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit):
def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit) -> str:
return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps"
lookups = {}
......
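A hedged sketch of step 1 of the algorithm described above, building a {ngram: list(document_ids)} lookup for one task/set; the helper names come from janitor.py later in this commit, and the import path is an assumption:

from collections import defaultdict

from lm_eval.decontamination.janitor import Janitor, word_ngrams  # path assumed

janitor = Janitor()  # ngram_n defaults to 13, matching the n_gram_size noted above

def build_lookup(docs: dict, ngrams_n_size: int = 13) -> dict:
    lookup = defaultdict(list)
    for doc_id, text in docs.items():
        for ngram in word_ngrams(janitor.normalize_string(text), ngrams_n_size):
            lookup[ngram].append(doc_id)
    return lookup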
import re
import string
import timeit
import pickle
import traceback
from pprint import pprint
from typing import Iterator, Sequence, TypeVar
# This is a cpp module. Compile janitor_util.cpp with:
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
......@@ -16,10 +16,11 @@ except Exception:
traceback.print_exc()
JANITOR_CPP = False
T = TypeVar("T")
# Implementation from nltk source
# https://www.nltk.org/_modules/nltk/util.html
def form_ngrams(sequence, n):
def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[tuple[T, ...]]:
history = []
while n > 1:
# PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
......@@ -36,7 +37,7 @@ def form_ngrams(sequence, n):
del history[0]
def word_ngrams(s, n):
def word_ngrams(s: str, n: int) -> Iterator[str]:
"""Splits a string into ngram words"""
tokens = s.split() # not a generator :(
ngram_seqs = form_ngrams(iter(tokens), n)
......@@ -68,14 +69,14 @@ def word_ngrams(s, n):
# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
def split_indices(s):
def split_indices(s: str) -> Iterator[tuple[str, tuple[int, int]]]:
"""Splits a string on whitespaces and records the indices of each in the original string.
@:return generator((word, (start_idx, end_idx)), ...)
"""
return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))
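A hedged usage sketch for the two helpers above; the expected outputs follow directly from the definitions shown (import path assumed):

from lm_eval.decontamination.janitor import split_indices, word_ngrams  # path assumed

print(list(split_indices("a bb ccc")))
# [('a', (0, 0)), ('bb', (2, 3)), ('ccc', (5, 7))]
print(list(word_ngrams("the cat sat down", 2)))
# ['the cat', 'cat sat', 'sat down']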
def word_ngrams_indices(s, n):
def word_ngrams_indices(s: str, n: int) -> Iterator[tuple[str, tuple[int, int]]]:
"""Splits a string into pairs of (ngram words, their start/end indices)"""
tokens_with_indices = split_indices(s)
......@@ -104,16 +105,15 @@ def word_ngrams_indices(s, n):
class Janitor:
# FIXME delete_chars: Should anything else go here? Special chars?
def __init__(
self,
ngram_n=13,
window_to_remove=200,
too_dirty_cutoff=10,
minimum_slice_length=200,
delete_chars=string.punctuation,
):
ngram_n: int = 13,
window_to_remove: int = 200,
too_dirty_cutoff: int = 10,
minimum_slice_length: int = 200,
delete_chars: str = string.punctuation,
) -> None:
self.ngram_n = ngram_n
self.window_to_remove = window_to_remove
self.too_dirty_cutoff = too_dirty_cutoff
......@@ -135,11 +135,11 @@ class Janitor:
# I/O for saving contamination ngrams
##############
def save_contamination_ngrams(self, filename):
def save_contamination_ngrams(self, filename: str) -> None:
with open(filename, "wb") as fp:
pickle.dump(filename, fp)
def load_contamination_ngrams(self, filename):
def load_contamination_ngrams(self, filename: str) -> None:
with open(filename, "rb") as fp:
self.dirt_ngrams = pickle.load(fp)
......@@ -147,7 +147,7 @@ class Janitor:
# Call these :)
##############
def register_contaminant(self, dirt_string):
def register_contaminant(self, dirt_string: str) -> None:
"""Register a string as contamination to be removed, e.g. a test set
This breaks the dirt_string into ngrams to store for future cleaning"""
if JANITOR_CPP:
......@@ -156,7 +156,7 @@ class Janitor:
print("WARNING: Janitor running in python mode")
return self.register_contaminant_python(dirt_string)
def clean(self, dirty_string):
def clean(self, dirty_string: str) -> list[str]:
"""Clean a string (e.g. a training set) by removing all ngrams previously
registered as contaminants. Returns a list of clean chunks, or empty if
the string was too dirty"""
......@@ -166,7 +166,7 @@ class Janitor:
print("WARNING: Janitor running in python mode")
return self.clean_python(dirty_string)
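A hedged usage sketch of the two public entry points above (import path assumed; the strings are placeholders, and real inputs need at least ngram_n words to produce any dirt ngrams):

from lm_eval.decontamination.janitor import Janitor  # path assumed

janitor = Janitor(ngram_n=13)
janitor.register_contaminant("full text of a held-out test document ...")
clean_chunks = janitor.clean("training document that may quote the test set ...")
# clean_chunks is a list of clean substrings; [] means the input was too dirty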
def _split_chunks(self, dirty_string, dirty_parts):
def _split_chunks(self, dirty_string: str, dirty_parts: Sequence[tuple]) -> list[str]:
clean_chunks = []
splice_idx = 0
end = -1
......@@ -189,12 +189,12 @@ class Janitor:
# Fast C++
##############
def register_contaminant_cpp(self, dirt_string):
def register_contaminant_cpp(self, dirt_string) -> None:
self.dirt_ngrams.update(
janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
)
def clean_cpp(self, dirty_string):
def clean_cpp(self, dirty_string: str) -> list[str]:
contamination_indices = janitor_util.clean_ngram_with_indices(
dirty_string, self.delete_chars, self.ngram_n
)
......@@ -204,15 +204,15 @@ class Janitor:
# Slow python
##############
def normalize_string(self, s):
def normalize_string(self, s: str) -> str:
return s.translate(self.translation_table)
def register_contaminant_python(self, dirt_string):
def register_contaminant_python(self, dirt_string: str) -> None:
self.dirt_ngrams.update(
word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
)
def clean_python(self, dirty_string):
def clean_python(self, dirty_string: str) -> list[str]:
contamination_indices = (
(None, *idx_pair)
for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
......
......@@ -42,11 +42,11 @@ def simple_evaluate(
device=None,
use_cache=None,
limit=None,
bootstrap_iters=100000,
check_integrity=False,
bootstrap_iters: int = 100000,
check_integrity: bool = False,
decontamination_ngrams_path=None,
write_out=False,
log_samples=True,
write_out: bool = False,
log_samples: bool = True,
):
"""Instantiate and evaluate a model on a list of tasks.
......@@ -117,7 +117,6 @@ def simple_evaluate(
task_dict = lm_eval.tasks.get_task_dict(tasks)
for task_name in task_dict.keys():
task_obj = task_dict[task_name]
if type(task_obj) == tuple:
group, task_obj = task_obj
......@@ -175,10 +174,10 @@ def evaluate(
lm,
task_dict,
limit=None,
bootstrap_iters=100000,
bootstrap_iters: int = 100000,
decontamination_ngrams_path=None,
write_out=False,
log_samples=True,
write_out: bool = False,
log_samples: bool = True,
):
"""Instantiate and evaluate a model on a list of tasks.
......@@ -224,7 +223,6 @@ def evaluate(
# get lists of each type of request
for task_name, task in task_dict.items():
if type(task) == tuple:
group, task = task
task_groups[task_name] = group
......@@ -349,7 +347,6 @@ def evaluate(
# if multigpu, then gather data across all ranks
# first gather logged samples across all ranks
for task_name, task_samples in list(samples.items()):
full_samples = [None] * lm.world_size
torch.distributed.all_gather_object(full_samples, task_samples)
......@@ -358,7 +355,6 @@ def evaluate(
# then collect metrics across all ranks
vals_torch = collections.defaultdict(list)
for (task_name, key, metric), items in vals.items():
numitem = 0
if type(items[0]) == tuple:
numitem = len(items[0])
......
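A hedged, self-contained sketch of the multi-rank gather step above; it assumes torch.distributed has already been initialized, as it is inside the evaluator at this point, and the function name is illustrative:

import torch.distributed as dist

def gather_samples(local_samples: list, world_size: int) -> list:
    # Collective call: every rank contributes its local samples and receives all of them.
    full = [None] * world_size
    dist.all_gather_object(full, local_samples)
    return [s for rank_samples in full for s in rank_samples]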
......@@ -9,7 +9,7 @@ class DecontaminationFilter(Filter):
name = "track_decontamination"
def __init__(self, path):
def __init__(self, path) -> None:
"""
TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
......@@ -17,7 +17,7 @@ class DecontaminationFilter(Filter):
"""
self._decontam_results = None
def apply(self, reps):
def apply(self, reps) -> None:
"""
Return {"no_contamination", "only_contamination"} keys for the 2 different subsets
"""
......
......@@ -6,7 +6,9 @@ from lm_eval.api.filter import Filter
class RegexFilter(Filter):
""" """
def __init__(self, regex_pattern=r"#### (\-?[0-9\.\,]+)", fallback="[invalid]"):
def __init__(
self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", fallback: str = "[invalid]"
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
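A hedged illustration of the default pattern above applied to a GSM8K-style answer line (self-contained, standard library only):

import re

pattern = re.compile(r"#### (\-?[0-9\.\,]+)")
match = pattern.search("... so the total is 42.\n#### 42")
print(match.group(1) if match else "[invalid]")  # -> 42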
......@@ -41,12 +43,11 @@ class RegexFilter(Filter):
class WhitespaceFilter(Filter):
""" """
def __init__(self):
def __init__(self) -> None:
pass
def apply(self, resps):
def filter_set(inst):
filtered_resp = []
for resp in inst:
if resp.startswith(" "):
......
......@@ -4,7 +4,7 @@ from lm_eval.api.filter import Filter
class TakeFirstFilter(Filter):
def __init__(self):
def __init__(self) -> None:
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
......@@ -17,8 +17,7 @@ class TakeFirstFilter(Filter):
class TakeKFilter(Filter):
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
self.k = kwargs.pop("k")
super().__init__(*args, **kwargs)
......@@ -32,7 +31,7 @@ class TakeKFilter(Filter):
class MajorityVoteFilter(Filter):
def __init__(self):
def __init__(self) -> None:
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
......
......@@ -76,7 +76,7 @@ class AnthropicLM(LM):
max_tokens_to_sample: int = 256,
temperature: float = 0, # defaults to 1
**kwargs, # top_p, top_k, etc.
):
) -> None:
"""Anthropic API wrapper.
:param model: str
......@@ -135,11 +135,10 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
def tok_decode(self, tokens: List[int]) -> str:
return self.tokenizer.decode(tokens)
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
raise NotImplementedError("No support for logits.")
def greedy_until(self, requests) -> List[str]:
if not requests:
return []
......
......@@ -5,7 +5,7 @@ from lm_eval.api.registry import register_model
@register_model("dummy")
class DummyLM(LM):
def __init__(self):
def __init__(self) -> None:
super().__init__()
@classmethod
......
......@@ -90,7 +90,7 @@ class HFLM(LM):
bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
gptq: Optional[Union[bool, str]] = False,
gptq_use_triton: Optional[bool] = False,
):
) -> None:
super().__init__()
assert isinstance(device, str)
......@@ -334,7 +334,7 @@ class HFLM(LM):
return self._DEFAULT_MAX_LENGTH
@property
def max_gen_toks(self):
def max_gen_toks(self) -> int:
return 256
@property
......@@ -353,7 +353,7 @@ class HFLM(LM):
def world_size(self):
return self._world_size
def _detect_batch_size(self, requests=None, pos=0):
def _detect_batch_size(self, requests=None, pos: int = 0):
if requests:
_, context_enc, continuation_enc = requests[pos]
max_length = len(
......@@ -419,7 +419,7 @@ class HFLM(LM):
return encoding
def tok_batch_encode(
self, strings: List[str], padding_side="left", left_truncate_len=None
self, strings: List[str], padding_side: str = "left", left_truncate_len=None
):
# encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
old_padding_side = self.tokenizer.padding_side
......@@ -595,7 +595,9 @@ class HFLM(LM):
return loglikelihoods
def _loglikelihood_tokens(self, requests, disable_tqdm=False, override_bs=None):
def _loglikelihood_tokens(
self, requests, disable_tqdm: bool = False, override_bs=None
):
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = []
......
......@@ -69,7 +69,7 @@ class OpenaiCompletionsLM(LM):
engine: str = "text-davinci-003",
truncate: bool = False,
batch_size: int = 1,
):
) -> None:
"""
:param engine: str
......@@ -99,12 +99,12 @@ class OpenaiCompletionsLM(LM):
return self.end_of_text_token_id
@property
def max_length(self):
def max_length(self) -> int:
# Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
return 2048
@property
def max_gen_toks(self):
def max_gen_toks(self) -> int:
return 256
@property
......@@ -152,7 +152,7 @@ class OpenaiCompletionsLM(LM):
return self._loglikelihood_tokens(new_reqs)
def _loglikelihood_tokens(
self, requests, disable_tqdm=False
self, requests, disable_tqdm: bool = False
) -> List[Tuple[float, bool]]:
res = []
......
......@@ -41,7 +41,7 @@ def textsynth_completion(**kwargs):
@register_model("textsynth")
class TextSynthLM(LM):
def __init__(self, engine, truncate=False):
def __init__(self, engine, truncate: bool = False) -> None:
"""
:param engine: str
TextSynth API engine (e.g. `gptj_6B`)
......@@ -62,12 +62,12 @@ class TextSynthLM(LM):
raise NotImplementedError()
@property
def max_length(self):
def max_length(self) -> int:
# NOTE: Turn on truncation to avoid errors on long inputs.
return 2048
@property
def max_gen_toks(self):
def max_gen_toks(self) -> int:
return 256
@property
......
......@@ -5,7 +5,7 @@ from lm_eval.logger import eval_logger
# Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name.
# This allows us to access prompts
PROMPT_REGISTRY = {
PROMPT_REGISTRY: dict[str, dict[str, str]] = {
"qa-basic": {
"question-newline-answer": "Question: {{question}}\nAnswer:",
"q-newline-a": "Q: {{question}}\nA:",
......@@ -13,7 +13,7 @@ PROMPT_REGISTRY = {
}
def get_prompt(prompt_id: str, dataset_name=None, subset_name=None):
def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None):
# unpack prompt name
category_name, prompt_name = prompt_id.split(":")
if subset_name is None:
......
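A hedged usage sketch for `get_prompt` above; that it returns the registered template string for registry categories is an assumption from the PROMPT_REGISTRY structure shown (import path assumed):

from lm_eval.prompts import get_prompt  # path assumed

template = get_prompt("qa-basic:q-newline-a")
# assumed -> "Q: {{question}}\nA:" per the registry entry above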