Unverified commit 8f448eed, authored by Hailey Schoelkopf, committed by GitHub

Merge pull request #809 from ethanhs/mypy

Add mypy baseline config
parents cc7828dd 4721379e
@@ -43,3 +43,9 @@ repos:
             .*\.json|ignore.txt
         )$
         args: [--check-filenames, --check-hidden, --ignore-words=ignore.txt]
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.5.1
+    hooks:
+      - id: mypy
+        additional_dependencies: [".[sentencepiece,multilingual,promptsource,gptq]", "types-PyYAML", "types-requests"]
+        exclude: ^tests/.*$
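A note on why this hook drives the annotation changes below: by default mypy does not type-check the body of a function that has no annotations at all, so adding even a bare `-> None` return type opts the function into checking. A minimal sketch (illustrative code, not from this PR):

```python
def untyped(x):
    return "count: " + x  # body is skipped: the function has no annotations

def typed(x: int) -> None:
    return "count: " + x  # mypy errors: str + int, and a return in a -> None function
```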
@@ -14,7 +14,7 @@ class Filter:
     """

-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         """
         Can define custom behavior here, if an individual instantiation of a Filter class should have state.
         """
@@ -41,7 +41,7 @@ class FilterEnsemble:
     name: str
     filters: List[Filter]

-    def apply(self, instances: List[Instance], docs: List[Dataset]):
+    def apply(self, instances: List[Instance], docs: List[Dataset]) -> None:
         resps = [
             inst.resps for inst in instances
...
@@ -19,7 +19,7 @@ class Instance:
     doc_id: str = None
     repeats: str = None

-    def __post_init__(self):
+    def __post_init__(self) -> None:
         # unpack metadata field
         self.task_name, self.doc_id, self.repeats = self.metadata
...
@@ -302,7 +302,7 @@ def _sacreformat(refs, preds):
 class _bootstrap_internal:
-    def __init__(self, f, n):
+    def __init__(self, f, n) -> None:
         self.f = f
         self.n = n
...
@@ -13,7 +13,7 @@ from lm_eval.logger import eval_logger
 class LM(abc.ABC):
-    def __init__(self):
+    def __init__(self) -> None:
         """Defines the interface that should be implemented by all LM subclasses.
         LMs are assumed to take text (strings) as input and yield strings as output
         (inputs/outputs should be tokenization-agnostic.)
@@ -133,7 +133,7 @@ class LM(abc.ABC):
         # not support multi-device parallelism nor expect it.
         return self._world_size

-    def set_cache_hook(self, cache_hook):
+    def set_cache_hook(self, cache_hook) -> None:
         self.cache_hook = cache_hook
@@ -144,14 +144,14 @@ def hash_args(attr, args):
 class CacheHook:
-    def __init__(self, cachinglm):
+    def __init__(self, cachinglm) -> None:
         if cachinglm is None:
             self.dbdict = None
             return
         self.dbdict = cachinglm.dbdict

-    def add_partial(self, attr, req, res):
+    def add_partial(self, attr, req, res) -> None:
         if self.dbdict is None:
             return
         hsh = hash_args(attr, req)
@@ -159,7 +159,7 @@ class CacheHook:
 class CachingLM:
-    def __init__(self, lm, cache_db):
+    def __init__(self, lm, cache_db) -> None:
         """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

         :param lm: LM
...
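The CachingLM docstring above describes a read-through cache. A hedged usage sketch (the cache path is illustrative):

```python
# Wrap any LM instance so repeated requests are served from the on-disk
# cache file instead of re-querying the model; misses fall through to `lm`.
cached_lm = CachingLM(lm, "lm_cache/my_model.db")
```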
 class Sampler:
-    def __init__(self, docs, task, fewshot_indices=None, rnd=None):
+    def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
         self.rnd = rnd
         assert self.rnd, "must pass rnd to FewShotSampler!"
@@ -19,7 +18,6 @@ class Sampler:
             self.docs = self.docs.select(fewshot_indices)

     def get_context(self, doc, num_fewshot):
-
         # draw an extra fewshot sample if using same split as evaluating on
         n_samples = (
             num_fewshot + 1
@@ -74,7 +72,7 @@ class Sampler:
 class BalancedSampler(Sampler):
-    def sample(self, n):
+    def sample(self, n) -> None:
         """
         TODO: this should return approximately class-balanced samples from our fewshot examples.
         TODO: what order should they be in? maybe random?
@@ -84,7 +82,7 @@ class BalancedSampler(Sampler):
 class ManualSampler(Sampler):
-    def sample(self, n):
+    def sample(self, n) -> None:
         """ """
         pass
...
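The comment in `get_context` above ("draw an extra fewshot sample if using same split as evaluating on") guards against the evaluated document appearing in its own few-shot context. Sketched in isolation (hypothetical helper, not code from this diff):

```python
import random

def draw_fewshot(docs: list, doc, num_fewshot: int, same_split: bool, rnd: random.Random) -> list:
    # Over-draw by one when the fewshot pool shares a split with the
    # evaluated doc, then drop the doc itself if it happened to be drawn.
    n_samples = num_fewshot + 1 if same_split else num_fewshot
    drawn = rnd.sample(docs, n_samples)
    return [d for d in drawn if d != doc][:num_fewshot]
```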
@@ -88,8 +88,8 @@ class TaskConfig(dict):
     metadata: str = None  # by default, not used in the code. allows for users to pass arbitrary info to tasks

-    def __post_init__(self):
+    def __post_init__(self) -> None:
         if "." in self.dataset_path:
             import inspect
             from importlib import import_module
@@ -177,7 +177,7 @@ class Task(abc.ABC):
         cache_dir=None,
         download_mode=None,
         config=None,
-    ):
+    ) -> None:
         """
         :param data_dir: str
             Stores the path to a local folder containing the `Task`'s data files.
@@ -188,7 +188,6 @@ class Task(abc.ABC):
             HuggingFace `datasets` API with the default cache directory located at:
                 `~/.cache/huggingface/datasets`
             NOTE: You can change the cache location globally for a given process
-            by setting the shell environment variable, `HF_DATASETS_CACHE`,
             to another directory:
                 `export HF_DATASETS_CACHE="/path/to/another/directory"`
         :param download_mode: datasets.DownloadMode
@@ -219,7 +218,7 @@ class Task(abc.ABC):
             list(self.fewshot_docs()), self, rnd=random.Random(1234)
         )

-    def download(self, data_dir=None, cache_dir=None, download_mode=None):
+    def download(self, data_dir=None, cache_dir=None, download_mode=None) -> None:
         """Downloads and returns the task dataset.
         Override this method to download the dataset from a custom API.
@@ -328,7 +327,7 @@ class Task(abc.ABC):
             return rnd.sample(self._training_docs, k)

-    def doc_to_decontamination_query(self, doc):
+    def doc_to_decontamination_query(self, doc) -> None:
         print(
             "Override doc_to_decontamination_query with document specific decontamination query."
         )
@@ -342,7 +341,7 @@ class Task(abc.ABC):
     def doc_to_target(self, doc):
         pass

-    def build_all_requests(self, limit=None, rank=None, world_size=None):
+    def build_all_requests(self, limit=None, rank=None, world_size=None) -> None:
         """Build a set of Instances for a task, and store them in task.instances"""
         if self.has_test_docs():
             docs = self.test_docs()
@@ -478,7 +477,6 @@ class Task(abc.ABC):
         return labeled_examples + str(example)

     def apply_filters(self):
-
         if hasattr(self, "_filters"):
             for f in self._filters:
                 f.apply(self._instances)
@@ -504,7 +502,7 @@ class ConfigurableTask(Task):
     def __init__(
         self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None
-    ):  # TODO no super() call here
+    ) -> None:  # TODO no super() call here
         # Get pre-configured attributes
         self._config = self.CONFIG
@@ -576,7 +574,6 @@ class ConfigurableTask(Task):
                     "aggregation"
                 ]
             else:
-
                 INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
                 metric_agg = get_default_aggregation(metric_name)
                 eval_logger.warning(
@@ -689,8 +686,7 @@ class ConfigurableTask(Task):
                     f'Both target_delimiter and target choice: "{choice}" does not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
                 )

-    def download(self, dataset_kwargs=None):
-
+    def download(self, dataset_kwargs=None) -> None:
         self.dataset = datasets.load_dataset(
             path=self.DATASET_PATH,
             name=self.DATASET_NAME,
@@ -782,7 +778,6 @@ class ConfigurableTask(Task):
         return doc

     def doc_to_text(self, doc):
-
         if self.prompt is not None:
             doc_to_text = self.prompt
         else:
@@ -817,7 +812,6 @@ class ConfigurableTask(Task):
             raise TypeError

     def doc_to_target(self, doc: dict) -> Union[int, str, list]:
-
         if self.prompt is not None:
             doc_to_target = self.prompt
         else:
@@ -859,7 +853,6 @@ class ConfigurableTask(Task):
             raise TypeError

     def doc_to_choice(self, doc: Any) -> List[str]:
-
         if self.prompt is not None:
             doc_to_choice = self.prompt
         elif self._config.doc_to_choice is None:
@@ -903,13 +896,11 @@ class ConfigurableTask(Task):
     def construct_requests(
         self, doc: dict, ctx: str, **kwargs
     ) -> Union[List[Instance], Instance]:
-
         if self.OUTPUT_TYPE == "loglikelihood":
             arguments = (ctx, self.doc_to_target(doc))
         elif self.OUTPUT_TYPE == "loglikelihood_rolling":
             arguments = (self.doc_to_target(doc),)
-
         elif self.OUTPUT_TYPE == "multiple_choice":
             choices = self.doc_to_choice(doc)
             target_delimiter = self._config.target_delimiter
             if self.multiple_input:
@@ -960,7 +951,6 @@ class ConfigurableTask(Task):
         )

     def process_results(self, doc, results):
-
         if callable(self._config.process_results):
             return self._config.process_results(doc, results)
@@ -995,7 +985,6 @@ class ConfigurableTask(Task):
                 ),
             }
         elif self.OUTPUT_TYPE == "multiple_choice":
-
             lls, is_greedy = zip(*results)
             # retrieve choices in List[str] form, to compute choice lengths, etc.
@@ -1067,7 +1056,6 @@ class ConfigurableTask(Task):
                 result_dict["acc_mutual_info"] = acc_mutual_info
         elif self.OUTPUT_TYPE == "greedy_until":
-
             gold = self.doc_to_target(doc)
             if self._config.doc_to_choice is not None:
                 # If you set doc_to_choice,
@@ -1197,7 +1185,7 @@ class PerplexityTask(Task):
     def doc_to_decontamination_query(self, doc):
         return doc

-    def doc_to_text(self, doc):
+    def doc_to_text(self, doc) -> str:
         return ""

     def doc_to_target(self, doc):
...
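For reference, the `HF_DATASETS_CACHE` relocation that the docstring above mentions must happen before `datasets` is imported, since the default cache directory is resolved at import time. A small sketch with an illustrative path:

```python
import os

# Point the HuggingFace datasets cache somewhere else for this process.
os.environ["HF_DATASETS_CACHE"] = "/path/to/another/directory"

import datasets  # noqa: E402  -- imported after the env var on purpose
```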
@@ -11,8 +11,7 @@ from lm_eval.api.registry import (
 )

-def include_benchmarks(task_dir):
-
+def include_benchmarks(task_dir: str) -> None:
     for root, subdirs, file_list in os.walk(task_dir):
         if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0):
             for f in file_list:
...
 import os
+from typing import Any

 import zstandard
 import json
 import jsonlines
@@ -9,7 +10,7 @@ import tqdm
 from pathlib import Path

-def json_serial(obj):
+def json_serial(obj: Any) -> str:
     """JSON serializer for objects not serializable by default json code"""

     if isinstance(obj, (datetime.datetime,)):
@@ -19,7 +20,7 @@ def json_serial(obj):
 # Modified version of lm_dataformat Archive for single file.
 class Archive:
-    def __init__(self, file_path, compression_level=3):
+    def __init__(self, file_path: str, compression_level: int = 3) -> None:
         self.file_path = file_path
         dir_name = os.path.dirname(file_path)
         if dir_name:
@@ -28,7 +29,7 @@ class Archive:
         self.cctx = zstandard.ZstdCompressor(level=compression_level)
         self.compressor = self.cctx.stream_writer(self.fh)

-    def add_data(self, data, meta={}):
+    def add_data(self, data, meta={}) -> None:
         self.compressor.write(
             json.dumps({"text": data, "meta": meta}, default=json_serial).encode(
                 "UTF-8"
@@ -36,7 +37,7 @@ class Archive:
             + b"\n"
         )

-    def commit(self):
+    def commit(self) -> None:
         self.compressor.flush(zstandard.FLUSH_FRAME)
         self.fh.flush()
         self.fh.close()
@@ -44,10 +45,16 @@ class Archive:
 # Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
 class Reader:
-    def __init__(self):
+    def __init__(self) -> None:
         pass

-    def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner="\n\n"):
+    def read(
+        self,
+        file,
+        get_meta: bool = False,
+        autojoin_paragraphs: bool = True,
+        para_joiner: str = "\n\n",
+    ):
         with open(file, "rb") as fh:
             self.fh = fh
             cctx = zstandard.ZstdDecompressor()
@@ -72,7 +79,7 @@ class Reader:
 class TextArchive:
-    def __init__(self, file_path, mode="rb+"):
+    def __init__(self, file_path, mode: str = "rb+") -> None:
         self.file_path = file_path
         dir_name = os.path.dirname(file_path)
         if dir_name:
@@ -83,21 +90,21 @@ class TextArchive:
         self.fh = open(self.file_path, mode)

-    def add_data(self, data):
+    def add_data(self, data) -> None:
         self.fh.write(data.encode("UTF-8") + b"\n")

-    def commit(self):
+    def commit(self) -> None:
         self.fh.flush()
         self.fh.close()

 class TextReader:
-    def __init__(self, file_path):
+    def __init__(self, file_path) -> None:
         self.file_path = file_path

     # Optimized mmap read with infrequent tqdm updates to maintain speed
     # Tested up to 250MB/s.
-    def read_tqdm(self, update_frequency=10000):
+    def read_tqdm(self, update_frequency: int = 10000):
         current_file_position = 0
         line_counter = 0
         with open(self.file_path, "r") as fh, tqdm.tqdm(
@@ -149,7 +156,7 @@ class TextReader:
 # Optimized for speed. Decompresses the archive in shell before
 # using the mmap'd TextReader.
 class ZStdTextReader:
-    def __init__(self, file):
+    def __init__(self, file) -> None:
         self.file = file

     def read_tqdm(self):
...
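For orientation, a round trip through the `Archive`/`Reader` pair above (the file name is illustrative, and `Reader.read` is assumed to yield the stored `"text"` fields, as in upstream `lm_dataformat`):

```python
archive = Archive("data/example.jsonl.zst")  # zstd-compressed JSON-lines file
archive.add_data("some document text", meta={"source": "example"})
archive.commit()  # flush the zstd frame and close the file handle

reader = Reader()
for text in reader.read("data/example.jsonl.zst"):
    print(text)  # -> "some document text"
```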
@@ -11,7 +11,7 @@ from .archiver import ZStdTextReader

 # Was used for testing the evaluator decoupled from the full logic below
-def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
+def get_train_overlap_stub(docs: dict, ngrams_path: str, ngrams_n_size: str):
     simulated_overlap = 0.1
     contaminated = int(len(docs) * simulated_overlap)
     return random.sample(range(len(docs)), contaminated)
@@ -25,6 +25,7 @@ def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
 # scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst"
 # files. These should exist in the "ngrams_path" provided to this function.
+# Algorithm:
 # 1. Build lookups for each dataset {ngram: list(document_ids)}
 # 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
@@ -33,7 +34,7 @@ def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
 # 4. Strip the task_set from the dictionary keys and return
 #
 # We cache the task+set lookups as well as the overlaps.
-def get_train_overlap(docs_by_task_set, ngrams_path, limit):
+def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> dict:
     # return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)

     info_dict_path = os.path.join(ngrams_path, "info.json")
@@ -46,7 +47,7 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
     print("Building Lookups...")
     start = time.perf_counter()

-    def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit):
+    def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit) -> str:
         return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps"

     lookups = {}
...
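The numbered comment above is the heart of the overlap check. A minimal sketch of steps 1-2 (hypothetical helper, not code from this commit):

```python
from collections import defaultdict

def merge_lookups(lookups: dict) -> dict:
    # Step 2: fold per-(task, set) lookups of {ngram: [doc_id, ...]} into a
    # single {ngram: [(task_name, task_set, doc_ids), ...]} dictionary, so
    # one pass over the training data can flag every affected document.
    merged = defaultdict(list)
    for (task_name, task_set), lookup in lookups.items():
        for ngram, doc_ids in lookup.items():
            merged[ngram].append((task_name, task_set, doc_ids))
    return dict(merged)
```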
 import re
 import string
-import timeit
 import pickle
 import traceback
 from pprint import pprint
+from typing import Iterator, Sequence, TypeVar

 # This is a cpp module. Compile janitor_util.cpp with:
 # c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
@@ -16,10 +16,12 @@ except Exception:
     traceback.print_exc()
     JANITOR_CPP = False

+T = TypeVar("T")

 # Implementation from nltk source
 # https://www.nltk.org/_modules/nltk/util.html
-def form_ngrams(sequence, n):
+def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[tuple[T, ...]]:
     history = []
     while n > 1:
         # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
@@ -36,7 +38,7 @@ def form_ngrams(sequence, n):
         del history[0]

-def word_ngrams(s, n):
+def word_ngrams(s: str, n: int) -> Iterator[str]:
     """Splits a string into ngram words"""
     tokens = s.split()  # not a generator :(
     ngram_seqs = form_ngrams(iter(tokens), n)
@@ -68,14 +70,14 @@ def word_ngrams(s, n):
 # https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
-def split_indices(s):
+def split_indices(s: str) -> Iterator[tuple[str, tuple[int, int]]]:
     """Splits a string on whitespaces and records the indices of each in the original string.
     @:return generator((word, (start_idx, end_idx)), ...)
     """
     return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))

-def word_ngrams_indices(s, n):
+def word_ngrams_indices(s: str, n: int) -> Iterator[tuple[str, tuple[int, int]]]:
     """Splits a string into pairs of (ngram words, their start/end indices)"""
     tokens_with_indices = split_indices(s)
@@ -104,16 +106,15 @@ def word_ngrams_indices(s, n):
 class Janitor:
     # FIXME delete_chars: Should anything else go here? Special chars?
     def __init__(
         self,
-        ngram_n=13,
-        window_to_remove=200,
-        too_dirty_cutoff=10,
-        minimum_slice_length=200,
-        delete_chars=string.punctuation,
-    ):
+        ngram_n: int = 13,
+        window_to_remove: int = 200,
+        too_dirty_cutoff: int = 10,
+        minimum_slice_length: int = 200,
+        delete_chars: str = string.punctuation,
+    ) -> None:
         self.ngram_n = ngram_n
         self.window_to_remove = window_to_remove
         self.too_dirty_cutoff = too_dirty_cutoff
@@ -135,11 +136,11 @@ class Janitor:
     # I/O for saving contamination ngrams
     ##############

-    def save_contamination_ngrams(self, filename):
+    def save_contamination_ngrams(self, filename: str) -> None:
         with open(filename, "wb") as fp:
             pickle.dump(filename, fp)

-    def load_contamination_ngrams(self, filename):
+    def load_contamination_ngrams(self, filename: str) -> None:
         with open(filename, "rb") as fp:
             self.dirt_ngrams = pickle.load(fp)
@@ -147,7 +148,7 @@ class Janitor:
     # Call these :)
     ##############

-    def register_contaminant(self, dirt_string):
+    def register_contaminant(self, dirt_string: str) -> None:
         """Register a string as contamination to be removed, e.g. a test set
         This breaks the dirt_string into ngrams to store for future cleaning"""
         if JANITOR_CPP:
@@ -156,7 +157,7 @@ class Janitor:
         print("WARNING: Janitor running in python mode")
         return self.register_contaminant_python(dirt_string)

-    def clean(self, dirty_string):
+    def clean(self, dirty_string: str) -> list[str]:
         """Clean a string (e.g. a training set) by removing all ngrams previously
         registered as contaminants. Returns a list of clean chunks, or empty if
         the string was too dirty"""
@@ -166,7 +167,9 @@ class Janitor:
         print("WARNING: Janitor running in python mode")
         return self.clean_python(dirty_string)

-    def _split_chunks(self, dirty_string, dirty_parts):
+    def _split_chunks(
+        self, dirty_string: str, dirty_parts: Sequence[tuple]
+    ) -> list[str]:
         clean_chunks = []
         splice_idx = 0
         end = -1
@@ -189,12 +192,12 @@ class Janitor:
     # Fast C++
     ##############

-    def register_contaminant_cpp(self, dirt_string):
+    def register_contaminant_cpp(self, dirt_string) -> None:
         self.dirt_ngrams.update(
             janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
         )

-    def clean_cpp(self, dirty_string):
+    def clean_cpp(self, dirty_string: str) -> list[str]:
         contamination_indices = janitor_util.clean_ngram_with_indices(
             dirty_string, self.delete_chars, self.ngram_n
         )
@@ -204,15 +207,15 @@ class Janitor:
     # Slow python
     ##############

-    def normalize_string(self, s):
+    def normalize_string(self, s: str) -> str:
         return s.translate(self.translation_table)

-    def register_contaminant_python(self, dirt_string):
+    def register_contaminant_python(self, dirt_string: str) -> None:
         self.dirt_ngrams.update(
             word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
         )

-    def clean_python(self, dirty_string):
+    def clean_python(self, dirty_string: str) -> list[str]:
         contamination_indices = (
             (None, *idx_pair)
             for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
...
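One portability note on the annotations above: the builtin-generic spellings `tuple[T, ...]` and `list[str]` are only valid at runtime on Python 3.9+, unless postponed evaluation of annotations is enabled. A sketch of both the guard and the generic behavior:

```python
from __future__ import annotations  # makes tuple[T, ...] safe on Python 3.7/3.8

from typing import Iterator, TypeVar

T = TypeVar("T")

def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[tuple[T, ...]]:
    """Generic over the element type: strings in, tuples of strings out."""
    ...

# e.g. list(form_ngrams(iter("abcd"), 2)) yields
# [('a', 'b'), ('b', 'c'), ('c', 'd')] for the real implementation.
```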
@@ -42,11 +42,11 @@ def simple_evaluate(
     device=None,
     use_cache=None,
     limit=None,
-    bootstrap_iters=100000,
-    check_integrity=False,
+    bootstrap_iters: int = 100000,
+    check_integrity: bool = False,
     decontamination_ngrams_path=None,
-    write_out=False,
-    log_samples=True,
+    write_out: bool = False,
+    log_samples: bool = True,
 ):
     """Instantiate and evaluate a model on a list of tasks.
@@ -117,7 +117,6 @@ def simple_evaluate(
     task_dict = lm_eval.tasks.get_task_dict(tasks)

     for task_name in task_dict.keys():
-
         task_obj = task_dict[task_name]
         if type(task_obj) == tuple:
             group, task_obj = task_obj
@@ -175,10 +174,10 @@ def evaluate(
     lm,
     task_dict,
     limit=None,
-    bootstrap_iters=100000,
+    bootstrap_iters: int = 100000,
     decontamination_ngrams_path=None,
-    write_out=False,
-    log_samples=True,
+    write_out: bool = False,
+    log_samples: bool = True,
 ):
     """Instantiate and evaluate a model on a list of tasks.
@@ -223,7 +222,6 @@ def evaluate(
     # get lists of each type of request
     for task_name, task in task_dict.items():
-
         if type(task) == tuple:
             group, task = task
             task_groups[task_name] = group
@@ -349,7 +347,6 @@ def evaluate(
         # if multigpu, then gather data across all ranks
         # first gather logged samples across all ranks
         for task_name, task_samples in list(samples.items()):
-
             full_samples = [None] * lm.world_size
             torch.distributed.all_gather_object(full_samples, task_samples)
@@ -358,7 +355,6 @@ def evaluate(
         # then collect metrics across all ranks
         vals_torch = collections.defaultdict(list)
         for (task_name, key, metric), items in vals.items():
-
             numitem = 0
             if type(items[0]) == tuple:
                 numitem = len(items[0])
...
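A hedged sketch of calling the entry point with the newly typed keyword arguments (the model and task names are placeholders, and the positional parameters are assumed from the surrounding signature):

```python
from lm_eval.evaluator import simple_evaluate

results = simple_evaluate(
    model="hf",                # placeholder model name
    tasks=["lambada_openai"],  # placeholder task list
    bootstrap_iters=100000,    # int, as annotated above
    check_integrity=False,     # bool
    write_out=False,           # bool
    log_samples=True,          # bool
)
```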
@@ -9,7 +9,7 @@ class DecontaminationFilter(Filter):
     name = "track_decontamination"

-    def __init__(self, path):
+    def __init__(self, path) -> None:
         """
         TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
@@ -17,7 +17,7 @@ class DecontaminationFilter(Filter):
         """
         self._decontam_results = None

-    def apply(self, reps, docs):
+    def apply(self, resps, docs) -> None:
         """
         Return {"no_contamination", "only_contamination"} keys for the 2 different subsets
         """
...
@@ -6,7 +6,9 @@ from lm_eval.api.filter import Filter
 class RegexFilter(Filter):
     """ """

-    def __init__(self, regex_pattern=r"#### (\-?[0-9\.\,]+)", fallback="[invalid]"):
+    def __init__(
+        self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", fallback: str = "[invalid]"
+    ) -> None:
         """
         pass a string `regex` to run `re.compile(r"regex")` on.
         `fallback` defines the output returned if no matches for the regex are located.
@@ -41,12 +43,11 @@ class RegexFilter(Filter):
 class WhitespaceFilter(Filter):
     """ """

-    def __init__(self):
+    def __init__(self) -> None:
         pass

     def apply(self, resps, docs):
-
         def filter_set(inst):
             filtered_resp = []
             for resp in inst:
                 if resp.startswith(" "):
...
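The default pattern in `RegexFilter` above follows the `#### <number>` answer convention used by math word-problem sets such as GSM8K; applied by hand (illustrative response text):

```python
import re

pattern = re.compile(r"#### (\-?[0-9\.\,]+)")  # default regex_pattern above
fallback = "[invalid]"                         # default fallback above

resp = "so she has 16 - 3 - 4 = 9 eggs left.\n#### 9"
match = pattern.search(resp)
print(match.group(1) if match else fallback)  # -> 9
```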
@@ -4,7 +4,7 @@ from lm_eval.api.filter import Filter
 class TakeFirstFilter(Filter):
-    def __init__(self):
+    def __init__(self) -> None:
         """
         Can define custom behavior here, if an individual instantiation of a Filter class should have state.
         """
@@ -17,8 +17,7 @@ class TakeFirstFilter(Filter):
 class TakeKFilter(Filter):
-    def __init__(self, *args, **kwargs):
-
+    def __init__(self, *args, **kwargs) -> None:
         self.k = kwargs.pop("k")

         super().__init__(*args, **kwargs)
@@ -32,7 +31,7 @@ class TakeKFilter(Filter):
 class MajorityVoteFilter(Filter):
-    def __init__(self):
+    def __init__(self) -> None:
         """
         Can define custom behavior here, if an individual instantiation of a Filter class should have state.
         """
...
@@ -76,7 +76,7 @@ class AnthropicLM(LM):
         max_tokens_to_sample: int = 256,
         temperature: float = 0,  # defaults to 1
         **kwargs,  # top_p, top_k, etc.
-    ):
+    ) -> None:
         """Anthropic API wrapper.

         :param model: str
@@ -135,11 +135,10 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
     def tok_decode(self, tokens: List[int]) -> str:
         return self.tokenizer.decode(tokens)

-    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
+    def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
         raise NotImplementedError("No support for logits.")

     def greedy_until(self, requests) -> List[str]:
-
         if not requests:
             return []
...
@@ -5,7 +5,7 @@ from lm_eval.api.registry import register_model
 @register_model("dummy")
 class DummyLM(LM):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()

     @classmethod
...
@@ -94,7 +94,7 @@ class HFLM(LM):
         bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
         gptq: Optional[Union[bool, str]] = False,
         gptq_use_triton: Optional[bool] = False,
-    ):
+    ) -> None:
         super().__init__()
         assert isinstance(device, str)
@@ -347,7 +347,7 @@ class HFLM(LM):
         return self._DEFAULT_MAX_LENGTH

     @property
-    def max_gen_toks(self):
+    def max_gen_toks(self) -> int:
         return 256

     @property
@@ -366,7 +366,7 @@ class HFLM(LM):
     def world_size(self):
         return self._world_size

-    def _detect_batch_size(self, requests=None, pos=0):
+    def _detect_batch_size(self, requests=None, pos: int = 0):
         if requests:
             _, context_enc, continuation_enc = requests[pos]
             max_length = len(
@@ -432,11 +432,11 @@ class HFLM(LM):
         return encoding

     def tok_batch_encode(
         self,
         strings: List[str],
-        padding_side="left",
-        left_truncate_len=None,
-        truncation=False,
+        padding_side: str = "left",
+        left_truncate_len: int = None,
+        truncation: bool = False,
     ):
         # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
         old_padding_side = self.tokenizer.padding_side
@@ -613,7 +613,9 @@ class HFLM(LM):
         return loglikelihoods

-    def _loglikelihood_tokens(self, requests, disable_tqdm=False, override_bs=None):
+    def _loglikelihood_tokens(
+        self, requests, disable_tqdm: bool = False, override_bs=None
+    ):
         # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []
...
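The `padding_side="left"` default in `tok_batch_encode` above matters for decoder-only generation: with left padding, every row's final position holds a real token, which is where generation continues. A toy illustration (assumes the `transformers` GPT-2 tokenizer):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token   # GPT-2 ships without a pad token
tok.padding_side = "left"       # pad on the left, real tokens end-aligned

enc = tok(
    ["a short prompt", "a noticeably longer prompt than the first"],
    return_tensors="pt",
    padding=True,
)
# Every sequence now ends at the same column, so greedy decoding can
# append new tokens directly after the true end of each prompt.
```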
@@ -69,7 +69,7 @@ class OpenaiCompletionsLM(LM):
         engine: str = "text-davinci-003",
         truncate: bool = False,
         batch_size: int = 1,
-    ):
+    ) -> None:
         """
         :param engine: str
@@ -99,12 +99,12 @@ class OpenaiCompletionsLM(LM):
         return self.end_of_text_token_id

     @property
-    def max_length(self):
+    def max_length(self) -> int:
         # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
         return 2048

     @property
-    def max_gen_toks(self):
+    def max_gen_toks(self) -> int:
         return 256

     @property
@@ -152,7 +152,7 @@ class OpenaiCompletionsLM(LM):
         return self._loglikelihood_tokens(new_reqs)

     def _loglikelihood_tokens(
-        self, requests, disable_tqdm=False
+        self, requests, disable_tqdm: bool = False
     ) -> List[Tuple[float, bool]]:
         res = []
...
@@ -41,7 +41,7 @@ def textsynth_completion(**kwargs):
 @register_model("textsynth")
 class TextSynthLM(LM):
-    def __init__(self, engine, truncate=False):
+    def __init__(self, engine, truncate: bool = False) -> None:
         """
         :param engine: str
             TextSynth API engine (e.g. `gptj_6B`)
@@ -62,12 +62,12 @@ class TextSynthLM(LM):
         raise NotImplementedError()

     @property
-    def max_length(self):
+    def max_length(self) -> int:
         # NOTE: Turn on truncation to avoid errors on long inputs.
         return 2048

     @property
-    def max_gen_toks(self):
+    def max_gen_toks(self) -> int:
         return 256

     @property
...