Commit e0311cd5 authored by baberabb's avatar baberabb
Browse files

Merge remote-tracking branch 'origin/big-refactor' into big-refactor_python_final

parents 96c60cf6 f86d6874
......@@ -10,7 +10,7 @@ import collections
import importlib.util
import fnmatch
from typing import List, Literal, Union
from typing import Iterator, List, Literal, Union
import gc
import torch
......@@ -65,7 +65,7 @@ def join_iters(iters):
yield from iter
def chunks(iter, n=0, fn=None):
def chunks(iter, n: int = 0, fn=None):
arr = []
for i, x in enumerate(iter):
arr.append(x)
......@@ -87,11 +87,11 @@ def group(arr, fn):
class MultiChoice:
def __init__(self, choices):
def __init__(self, choices) -> None:
self.choices = choices
# Simple wildcard support (linux filename patterns)
def __contains__(self, values):
def __contains__(self, values) -> bool:
for value in values.split(","):
if len(fnmatch.filter(self.choices, value)) == 0:
eval_logger.info(f"Available tasks to choose:")
......@@ -100,7 +100,7 @@ class MultiChoice:
raise ValueError("'{}' is not in task list".format(value))
return True
def __iter__(self):
def __iter__(self) -> Iterator:
for choice in self.choices:
yield choice
......@@ -108,7 +108,6 @@ class MultiChoice:
# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
if type(patterns) == str:
patterns = [patterns]
......@@ -177,7 +176,7 @@ def make_disjoint_window(pair):
class Reorderer:
def __init__(self, arr, fn):
def __init__(self, arr, fn) -> None:
self.size = len(arr)
arr = list(enumerate(arr))
arr = group(arr, lambda x: fn(x[1]))
......@@ -212,7 +211,7 @@ class Grouper:
objects in `arr` satisfying `key == fn(ob)`.
"""
def __init__(self, arr, fn):
def __init__(self, arr, fn) -> None:
# self.orig_arr = arr
self.size = len(arr)
arr = list(enumerate(arr))
......@@ -263,7 +262,7 @@ class Grouper:
return res
def make_table(result_dict, column="results"):
def make_table(result_dict, column: str = "results"):
"""Generate table of results."""
from pytablewriter import MarkdownTableWriter, LatexTableWriter
......@@ -393,7 +392,6 @@ def get_git_commit_hash():
def import_function(loader, node):
function_name = loader.construct_scalar(node)
yaml_path = os.path.dirname(loader.name)
......@@ -428,7 +426,6 @@ def load_yaml_config(yaml_path):
include_path.reverse()
final_yaml_config = {}
for path in include_path:
# Assumes that path is a full path.
# If not found, assume the included yaml
# is in the same dir as the original yaml
......@@ -447,7 +444,7 @@ def load_yaml_config(yaml_path):
return yaml_config
def regex_replace(string, pattern, repl, count=0):
def regex_replace(string, pattern, repl, count: int = 0):
"""Implements the `re.sub` function as a custom Jinja filter."""
return re.sub(pattern, repl, string, count=count)
......@@ -521,7 +518,7 @@ def pad_and_concat(
return torch.cat(tensors, dim=0)
def clear_torch_cache():
def clear_torch_cache() -> None:
gc.collect()
torch.cuda.empty_cache()
......@@ -546,7 +543,7 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
tokenizer: transformers.PreTrainedTokenizer,
initial_decoder_input_length: int,
batch_size: int,
):
) -> None:
self.initial_decoder_input_length = initial_decoder_input_length
self.done_tracker = [False] * batch_size
self.sequence = sequence
......
......@@ -9,24 +9,26 @@ from pathlib import Path
from lm_eval import evaluator, utils
from lm_eval.api.registry import ALL_TASKS
from lm_eval.logger import eval_logger
from lm_eval.logger import eval_logger, SPACING
from lm_eval.tasks import include_task_folder
from lm_eval.benchmarks import include_benchmarks
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def parse_args():
parser = argparse.ArgumentParser()
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--model", required=True, help="Name of model e.g. `hf`")
parser.add_argument(
"--tasks",
default=None,
help="Available Tasks:\n - {}".format("\n - ".join(sorted(ALL_TASKS))),
)
parser.add_argument(
"--model_args",
default="",
help="String arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
)
parser.add_argument(
"--tasks", default=None # , choices=utils.MultiChoice(sorted(ALL_TASKS))
)
parser.add_argument(
"--num_fewshot",
type=int,
......@@ -99,7 +101,7 @@ def parse_args():
return parser.parse_args()
def main():
def main() -> None:
args = parse_args()
if args.limit:
......@@ -126,10 +128,21 @@ def main():
else:
tasks_list = args.tasks.split(",")
task_names = utils.pattern_match(tasks_list, ALL_TASKS)
task_missing = []
for task in [task for task in tasks_list if task not in task_names]:
if os.path.isfile(task):
config = utils.load_yaml_config(task)
task_names.append(config)
else:
task_missing.append(task)
if task_missing != []:
missing = ", ".join(task_missing)
eval_logger.error(
f"Tasks were not found: {missing}\n"
f"{SPACING}Try `lm-eval -h` for list of available tasks",
)
raise ValueError(f"Tasks {missing} were not found.")
if args.output_path:
path = Path(args.output_path)
......
[mypy]
python_version = 3.9
show_traceback = True
check_untyped_defs = True
no_implicit_reexport = True
warn_unreachable = True
warn_unused_configs = True
warn_unused_ignores = True
warn_redundant_casts = True
# We ignore errors everywhere to gradually add type annotations
[mypy-lm_eval.*]
ignore_errors = True
[mypy-lm_eval.api.*]
ignore_errors = True
[mypy-lm_eval.prompts.*]
ignore_errors = True
[mypy-lm_eval.models.*]
ignore_errors = True
[mypy-scripts.*]
ignore_errors = True
[mypy-main]
ignore_errors = True
......@@ -15,7 +15,7 @@ extras_require = {
],
"testing": ["pytest", "pytest-cov", "pytest-xdist"],
"multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1"],
"sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"],
"sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1", "pycountry"],
"promptsource": [
"promptsource @ git+https://github.com/bigscience-workshop/promptsource.git#egg=promptsource"
],
......@@ -53,7 +53,7 @@ setuptools.setup(
],
python_requires=">=3.9",
install_requires=[
"accelerate>=0.18.0",
"accelerate>=0.21.0",
"evaluate",
"datasets>=2.0.0",
"evaluate>=0.4.0",
......@@ -62,10 +62,9 @@ setuptools.setup(
"omegaconf>=2.2",
"peft>=0.2.0",
"pybind11>=2.6.2",
"pycountry",
"pytablewriter",
"rouge-score>=0.0.4",
"sacrebleu==1.5.0",
"sacrebleu>=1.5.0",
"scikit-learn>=0.24.1",
"sqlitedict",
"torch>=1.8",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment