Commit fc69d84f authored by Ethan Smith's avatar Ethan Smith
Browse files

Add suggestions from autotyping

This adds a bunch of simple annotations suggested by https://github.com/JelleZijlstra/autotyping.
parent da85f290
...@@ -15,7 +15,7 @@ from lm_eval.api.registry import ( ...@@ -15,7 +15,7 @@ from lm_eval.api.registry import (
) )
def register_configurable_task(config): def register_configurable_task(config: dict[str, str]) -> int:
SubClass = type( SubClass = type(
config["task"] + "ConfigurableTask", config["task"] + "ConfigurableTask",
(ConfigurableTask,), (ConfigurableTask,),
...@@ -38,7 +38,7 @@ def register_configurable_task(config): ...@@ -38,7 +38,7 @@ def register_configurable_task(config):
return 0 return 0
def check_prompt_config(config): def check_prompt_config(config: dict[str, str]) -> List[dict[str, str]]:
all_configs = [] all_configs = []
if "use_prompt" in config: if "use_prompt" in config:
prompt_list = prompts.load_prompt_list( prompt_list = prompts.load_prompt_list(
...@@ -69,14 +69,14 @@ def check_prompt_config(config): ...@@ -69,14 +69,14 @@ def check_prompt_config(config):
return all_configs return all_configs
def get_task_name_from_config(task_config: dict[str, str]) -> str:
    """Derive a unique task name from a task's dataset configuration.

    Returns ``<dataset_path>_<dataset_name>`` when a dataset name is
    present, otherwise just ``<dataset_path>``.
    """
    name = task_config["dataset_path"]
    if "dataset_name" in task_config:
        name = "{}_{}".format(name, task_config["dataset_name"])
    return name
def include_task_folder(task_dir): def include_task_folder(task_dir: str) -> None:
""" """
Calling this function Calling this function
""" """
......
def doc_to_text(doc): def doc_to_text(doc) -> str:
return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format( return "{}\nQuestion: {} True, False or Neither?\nAnswer:".format(
doc["premise"], doc["premise"],
doc["hypothesis"].strip() doc["hypothesis"].strip()
......
...@@ -15,7 +15,7 @@ def _preproc_doc(doc): ...@@ -15,7 +15,7 @@ def _preproc_doc(doc):
return doc return doc
def doc_to_text(doc) -> str:
    """Render a two-scenario preference question as a prompt string.

    The document is first normalized by the module's ``_preproc_doc``
    helper; the two entries of ``doc["scenarios"]`` become Scenario 1/2.
    """
    processed = _preproc_doc(doc)
    first, second = processed["scenarios"][0], processed["scenarios"][1]
    return (
        f"Scenario 1: {first}\nScenario 2: {second}\n"
        "Question: Is Scenario 1 preferable?\nAnswer:"
    )
......
def doc_to_text(doc) -> str:
    """Format a PubMedQA-style example as an abstract + question prompt.

    Joins the abstract's context passages with newlines and appends the
    question, leaving the answer slot open for the model to complete.
    """
    ctxs = "\n".join(doc["context"]["contexts"])
    # Fix: the original also passed doc["final_decision"] as a third
    # format argument; the template has only two placeholders, so
    # str.format silently ignored it. Dropping it is behavior-preserving.
    return "Abstract: {}\nQuestion: {}\nAnswer:".format(ctxs, doc["question"])
def doc_to_target(doc) -> str:
    """Return the gold answer (``final_decision``) prefixed with a space."""
    return f" {doc['final_decision']}"
......
...@@ -10,7 +10,7 @@ import collections ...@@ -10,7 +10,7 @@ import collections
import importlib.util import importlib.util
import fnmatch import fnmatch
from typing import List, Literal, Union from typing import Iterator, List, Literal, Union
import gc import gc
import torch import torch
...@@ -65,7 +65,7 @@ def join_iters(iters): ...@@ -65,7 +65,7 @@ def join_iters(iters):
yield from iter yield from iter
def chunks(iter, n=0, fn=None): def chunks(iter, n: int = 0, fn=None):
arr = [] arr = []
for i, x in enumerate(iter): for i, x in enumerate(iter):
arr.append(x) arr.append(x)
...@@ -87,11 +87,11 @@ def group(arr, fn): ...@@ -87,11 +87,11 @@ def group(arr, fn):
class MultiChoice: class MultiChoice:
def __init__(self, choices): def __init__(self, choices) -> None:
self.choices = choices self.choices = choices
# Simple wildcard support (linux filename patterns) # Simple wildcard support (linux filename patterns)
def __contains__(self, values): def __contains__(self, values) -> bool:
for value in values.split(","): for value in values.split(","):
if len(fnmatch.filter(self.choices, value)) == 0: if len(fnmatch.filter(self.choices, value)) == 0:
eval_logger.info(f"Available tasks to choose:") eval_logger.info(f"Available tasks to choose:")
...@@ -100,7 +100,7 @@ class MultiChoice: ...@@ -100,7 +100,7 @@ class MultiChoice:
raise ValueError("'{}' is not in task list".format(value)) raise ValueError("'{}' is not in task list".format(value))
return True return True
def __iter__(self): def __iter__(self) -> Iterator:
for choice in self.choices: for choice in self.choices:
yield choice yield choice
...@@ -108,7 +108,6 @@ class MultiChoice: ...@@ -108,7 +108,6 @@ class MultiChoice:
# Returns a list containing all values of the source_list that # Returns a list containing all values of the source_list that
# match at least one of the patterns # match at least one of the patterns
def pattern_match(patterns, source_list): def pattern_match(patterns, source_list):
if type(patterns) == str: if type(patterns) == str:
patterns = [patterns] patterns = [patterns]
...@@ -177,7 +176,7 @@ def make_disjoint_window(pair): ...@@ -177,7 +176,7 @@ def make_disjoint_window(pair):
class Reorderer: class Reorderer:
def __init__(self, arr, fn): def __init__(self, arr, fn) -> None:
self.size = len(arr) self.size = len(arr)
arr = list(enumerate(arr)) arr = list(enumerate(arr))
arr = group(arr, lambda x: fn(x[1])) arr = group(arr, lambda x: fn(x[1]))
...@@ -212,7 +211,7 @@ class Grouper: ...@@ -212,7 +211,7 @@ class Grouper:
objects in `arr` satisfying `key == fn(ob)`. objects in `arr` satisfying `key == fn(ob)`.
""" """
def __init__(self, arr, fn): def __init__(self, arr, fn) -> None:
# self.orig_arr = arr # self.orig_arr = arr
self.size = len(arr) self.size = len(arr)
arr = list(enumerate(arr)) arr = list(enumerate(arr))
...@@ -263,7 +262,7 @@ class Grouper: ...@@ -263,7 +262,7 @@ class Grouper:
return res return res
def make_table(result_dict, column="results"): def make_table(result_dict, column: str = "results"):
"""Generate table of results.""" """Generate table of results."""
from pytablewriter import MarkdownTableWriter, LatexTableWriter from pytablewriter import MarkdownTableWriter, LatexTableWriter
...@@ -393,7 +392,6 @@ def get_git_commit_hash(): ...@@ -393,7 +392,6 @@ def get_git_commit_hash():
def import_function(loader, node): def import_function(loader, node):
function_name = loader.construct_scalar(node) function_name = loader.construct_scalar(node)
yaml_path = os.path.dirname(loader.name) yaml_path = os.path.dirname(loader.name)
...@@ -428,7 +426,6 @@ def load_yaml_config(yaml_path): ...@@ -428,7 +426,6 @@ def load_yaml_config(yaml_path):
include_path.reverse() include_path.reverse()
final_yaml_config = {} final_yaml_config = {}
for path in include_path: for path in include_path:
# Assumes that path is a full path. # Assumes that path is a full path.
# If not found, assume the included yaml # If not found, assume the included yaml
# is in the same dir as the original yaml # is in the same dir as the original yaml
...@@ -447,7 +444,7 @@ def load_yaml_config(yaml_path): ...@@ -447,7 +444,7 @@ def load_yaml_config(yaml_path):
return yaml_config return yaml_config
def regex_replace(string, pattern, repl, count: int = 0):
    """Custom Jinja filter wrapping ``re.sub``.

    Replaces matches of *pattern* in *string* with *repl*; a *count* of 0
    (the default) replaces every occurrence.
    """
    compiled = re.compile(pattern)
    return compiled.sub(repl, string, count=count)
...@@ -521,7 +518,7 @@ def pad_and_concat( ...@@ -521,7 +518,7 @@ def pad_and_concat(
return torch.cat(tensors, dim=0) return torch.cat(tensors, dim=0)
def clear_torch_cache() -> None:
    """Free memory: run Python garbage collection first, then release
    PyTorch's cached CUDA allocations back to the device.

    The order matters — collecting Python references first lets
    ``empty_cache`` reclaim the tensors those references held.
    """
    gc.collect()
    torch.cuda.empty_cache()
...@@ -546,7 +543,7 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria): ...@@ -546,7 +543,7 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
tokenizer: transformers.PreTrainedTokenizer, tokenizer: transformers.PreTrainedTokenizer,
initial_decoder_input_length: int, initial_decoder_input_length: int,
batch_size: int, batch_size: int,
): ) -> None:
self.initial_decoder_input_length = initial_decoder_input_length self.initial_decoder_input_length = initial_decoder_input_length
self.done_tracker = [False] * batch_size self.done_tracker = [False] * batch_size
self.sequence = sequence self.sequence = sequence
......
...@@ -15,7 +15,7 @@ from lm_eval.tasks import include_task_folder ...@@ -15,7 +15,7 @@ from lm_eval.tasks import include_task_folder
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
def parse_args(): def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True, help="Name of model e.g. `hf`") parser.add_argument("--model", required=True, help="Name of model e.g. `hf`")
parser.add_argument( parser.add_argument(
...@@ -98,7 +98,7 @@ def parse_args(): ...@@ -98,7 +98,7 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def main(): def main() -> None:
args = parse_args() args = parse_args()
if args.limit: if args.limit:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment