Commit e0311cd5 authored by baberabb's avatar baberabb
Browse files

Merge remote-tracking branch 'origin/big-refactor' into big-refactor_python_final

parents 96c60cf6 f86d6874
...@@ -10,7 +10,7 @@ import collections ...@@ -10,7 +10,7 @@ import collections
import importlib.util import importlib.util
import fnmatch import fnmatch
from typing import List, Literal, Union from typing import Iterator, List, Literal, Union
import gc import gc
import torch import torch
...@@ -65,7 +65,7 @@ def join_iters(iters): ...@@ -65,7 +65,7 @@ def join_iters(iters):
yield from iter yield from iter
def chunks(iter, n=0, fn=None): def chunks(iter, n: int = 0, fn=None):
arr = [] arr = []
for i, x in enumerate(iter): for i, x in enumerate(iter):
arr.append(x) arr.append(x)
...@@ -87,11 +87,11 @@ def group(arr, fn): ...@@ -87,11 +87,11 @@ def group(arr, fn):
class MultiChoice: class MultiChoice:
def __init__(self, choices): def __init__(self, choices) -> None:
self.choices = choices self.choices = choices
# Simple wildcard support (linux filename patterns) # Simple wildcard support (linux filename patterns)
def __contains__(self, values): def __contains__(self, values) -> bool:
for value in values.split(","): for value in values.split(","):
if len(fnmatch.filter(self.choices, value)) == 0: if len(fnmatch.filter(self.choices, value)) == 0:
eval_logger.info(f"Available tasks to choose:") eval_logger.info(f"Available tasks to choose:")
...@@ -100,7 +100,7 @@ class MultiChoice: ...@@ -100,7 +100,7 @@ class MultiChoice:
raise ValueError("'{}' is not in task list".format(value)) raise ValueError("'{}' is not in task list".format(value))
return True return True
def __iter__(self): def __iter__(self) -> Iterator:
for choice in self.choices: for choice in self.choices:
yield choice yield choice
...@@ -108,7 +108,6 @@ class MultiChoice: ...@@ -108,7 +108,6 @@ class MultiChoice:
# Returns a list containing all values of the source_list that # Returns a list containing all values of the source_list that
# match at least one of the patterns # match at least one of the patterns
def pattern_match(patterns, source_list): def pattern_match(patterns, source_list):
if type(patterns) == str: if type(patterns) == str:
patterns = [patterns] patterns = [patterns]
...@@ -177,7 +176,7 @@ def make_disjoint_window(pair): ...@@ -177,7 +176,7 @@ def make_disjoint_window(pair):
class Reorderer: class Reorderer:
def __init__(self, arr, fn): def __init__(self, arr, fn) -> None:
self.size = len(arr) self.size = len(arr)
arr = list(enumerate(arr)) arr = list(enumerate(arr))
arr = group(arr, lambda x: fn(x[1])) arr = group(arr, lambda x: fn(x[1]))
...@@ -212,7 +211,7 @@ class Grouper: ...@@ -212,7 +211,7 @@ class Grouper:
objects in `arr` satisfying `key == fn(ob)`. objects in `arr` satisfying `key == fn(ob)`.
""" """
def __init__(self, arr, fn): def __init__(self, arr, fn) -> None:
# self.orig_arr = arr # self.orig_arr = arr
self.size = len(arr) self.size = len(arr)
arr = list(enumerate(arr)) arr = list(enumerate(arr))
...@@ -263,7 +262,7 @@ class Grouper: ...@@ -263,7 +262,7 @@ class Grouper:
return res return res
def make_table(result_dict, column="results"): def make_table(result_dict, column: str = "results"):
"""Generate table of results.""" """Generate table of results."""
from pytablewriter import MarkdownTableWriter, LatexTableWriter from pytablewriter import MarkdownTableWriter, LatexTableWriter
...@@ -393,7 +392,6 @@ def get_git_commit_hash(): ...@@ -393,7 +392,6 @@ def get_git_commit_hash():
def import_function(loader, node): def import_function(loader, node):
function_name = loader.construct_scalar(node) function_name = loader.construct_scalar(node)
yaml_path = os.path.dirname(loader.name) yaml_path = os.path.dirname(loader.name)
...@@ -428,7 +426,6 @@ def load_yaml_config(yaml_path): ...@@ -428,7 +426,6 @@ def load_yaml_config(yaml_path):
include_path.reverse() include_path.reverse()
final_yaml_config = {} final_yaml_config = {}
for path in include_path: for path in include_path:
# Assumes that path is a full path. # Assumes that path is a full path.
# If not found, assume the included yaml # If not found, assume the included yaml
# is in the same dir as the original yaml # is in the same dir as the original yaml
...@@ -447,7 +444,7 @@ def load_yaml_config(yaml_path): ...@@ -447,7 +444,7 @@ def load_yaml_config(yaml_path):
return yaml_config return yaml_config
def regex_replace(string, pattern, repl, count=0): def regex_replace(string, pattern, repl, count: int = 0):
"""Implements the `re.sub` function as a custom Jinja filter.""" """Implements the `re.sub` function as a custom Jinja filter."""
return re.sub(pattern, repl, string, count=count) return re.sub(pattern, repl, string, count=count)
...@@ -521,7 +518,7 @@ def pad_and_concat( ...@@ -521,7 +518,7 @@ def pad_and_concat(
return torch.cat(tensors, dim=0) return torch.cat(tensors, dim=0)
def clear_torch_cache(): def clear_torch_cache() -> None:
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
...@@ -546,7 +543,7 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria): ...@@ -546,7 +543,7 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
tokenizer: transformers.PreTrainedTokenizer, tokenizer: transformers.PreTrainedTokenizer,
initial_decoder_input_length: int, initial_decoder_input_length: int,
batch_size: int, batch_size: int,
): ) -> None:
self.initial_decoder_input_length = initial_decoder_input_length self.initial_decoder_input_length = initial_decoder_input_length
self.done_tracker = [False] * batch_size self.done_tracker = [False] * batch_size
self.sequence = sequence self.sequence = sequence
......
...@@ -9,24 +9,26 @@ from pathlib import Path ...@@ -9,24 +9,26 @@ from pathlib import Path
from lm_eval import evaluator, utils from lm_eval import evaluator, utils
from lm_eval.api.registry import ALL_TASKS from lm_eval.api.registry import ALL_TASKS
from lm_eval.logger import eval_logger from lm_eval.logger import eval_logger, SPACING
from lm_eval.tasks import include_task_folder from lm_eval.tasks import include_task_folder
from lm_eval.benchmarks import include_benchmarks from lm_eval.benchmarks import include_benchmarks
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
def parse_args(): def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--model", required=True, help="Name of model e.g. `hf`") parser.add_argument("--model", required=True, help="Name of model e.g. `hf`")
parser.add_argument(
"--tasks",
default=None,
help="Available Tasks:\n - {}".format("\n - ".join(sorted(ALL_TASKS))),
)
parser.add_argument( parser.add_argument(
"--model_args", "--model_args",
default="", default="",
help="String arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`", help="String arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
) )
parser.add_argument(
"--tasks", default=None # , choices=utils.MultiChoice(sorted(ALL_TASKS))
)
parser.add_argument( parser.add_argument(
"--num_fewshot", "--num_fewshot",
type=int, type=int,
...@@ -99,7 +101,7 @@ def parse_args(): ...@@ -99,7 +101,7 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def main(): def main() -> None:
args = parse_args() args = parse_args()
if args.limit: if args.limit:
...@@ -126,10 +128,21 @@ def main(): ...@@ -126,10 +128,21 @@ def main():
else: else:
tasks_list = args.tasks.split(",") tasks_list = args.tasks.split(",")
task_names = utils.pattern_match(tasks_list, ALL_TASKS) task_names = utils.pattern_match(tasks_list, ALL_TASKS)
task_missing = []
for task in [task for task in tasks_list if task not in task_names]: for task in [task for task in tasks_list if task not in task_names]:
if os.path.isfile(task): if os.path.isfile(task):
config = utils.load_yaml_config(task) config = utils.load_yaml_config(task)
task_names.append(config) task_names.append(config)
else:
task_missing.append(task)
if task_missing != []:
missing = ", ".join(task_missing)
eval_logger.error(
f"Tasks were not found: {missing}\n"
f"{SPACING}Try `lm-eval -h` for list of available tasks",
)
raise ValueError(f"Tasks {missing} were not found.")
if args.output_path: if args.output_path:
path = Path(args.output_path) path = Path(args.output_path)
......
[mypy]
python_version = 3.9
show_traceback = True
check_untyped_defs = True
no_implicit_reexport = True
warn_unreachable = True
warn_unused_configs = True
warn_unused_ignores = True
warn_redundant_casts = True
# We ignore errors everywhere to gradually add type annotations
[mypy-lm_eval.*]
ignore_errors = True
[mypy-lm_eval.api.*]
ignore_errors = True
[mypy-lm_eval.prompts.*]
ignore_errors = True
[mypy-lm_eval.models.*]
ignore_errors = True
[mypy-scripts.*]
ignore_errors = True
[mypy-main]
ignore_errors = True
...@@ -15,7 +15,7 @@ extras_require = { ...@@ -15,7 +15,7 @@ extras_require = {
], ],
"testing": ["pytest", "pytest-cov", "pytest-xdist"], "testing": ["pytest", "pytest-cov", "pytest-xdist"],
"multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1"], "multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1"],
"sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"], "sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1", "pycountry"],
"promptsource": [ "promptsource": [
"promptsource @ git+https://github.com/bigscience-workshop/promptsource.git#egg=promptsource" "promptsource @ git+https://github.com/bigscience-workshop/promptsource.git#egg=promptsource"
], ],
...@@ -53,7 +53,7 @@ setuptools.setup( ...@@ -53,7 +53,7 @@ setuptools.setup(
], ],
python_requires=">=3.9", python_requires=">=3.9",
install_requires=[ install_requires=[
"accelerate>=0.18.0", "accelerate>=0.21.0",
"evaluate", "evaluate",
"datasets>=2.0.0", "datasets>=2.0.0",
"evaluate>=0.4.0", "evaluate>=0.4.0",
...@@ -62,10 +62,9 @@ setuptools.setup( ...@@ -62,10 +62,9 @@ setuptools.setup(
"omegaconf>=2.2", "omegaconf>=2.2",
"peft>=0.2.0", "peft>=0.2.0",
"pybind11>=2.6.2", "pybind11>=2.6.2",
"pycountry",
"pytablewriter", "pytablewriter",
"rouge-score>=0.0.4", "rouge-score>=0.0.4",
"sacrebleu==1.5.0", "sacrebleu>=1.5.0",
"scikit-learn>=0.24.1", "scikit-learn>=0.24.1",
"sqlitedict", "sqlitedict",
"torch>=1.8", "torch>=1.8",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment