import collections
import fnmatch
import functools
import importlib.util
import inspect
import logging
import os
import pathlib
import re
import subprocess
import sys
from itertools import islice
from typing import (
    Any,
    Callable,
    List,
)

import numpy as np
import yaml
from jinja2 import BaseLoader, Environment, StrictUndefined

logging.basicConfig(
    format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)
eval_logger = logging.getLogger("lm-eval")

SPACING = " " * 47


def escaped_split(text, sep_char, maxsplit=-1):
    """Split `text` into a list on occurrences of the given separation
    character `sep_char`. The separation character may be escaped by a
    backslash to avoid splitting at that location.

    The separation character must be a string of size 1.

    If `maxsplit` is given, at most `maxsplit` splits are done (thus,
    the list will have at most `maxsplit + 1` elements). If `maxsplit`
    is not specified or less than 0, then there is no limit on the
    number of splits (all possible splits are made).

    NOTE(review): when `maxsplit == 0` this returns the *string*
    unchanged rather than a one-element list — callers appear to rely
    on this quirk, so it is preserved.
    """
    assert (
        len(sep_char) == 1
    ), "separation string must be a single character for escaped splitting"

    if maxsplit == 0:
        return text
    maxsplit = max(0, maxsplit)

    # Negative lookbehind: do not split where the separator is preceded
    # by a backslash (i.e. escaped).
    # NOTE(review): the tail of this statement was garbled in the source
    # as received; reconstructed from the escaping contract documented
    # above — confirm against upstream.
    return re.split(r"(?<!\\)" + sep_char, text, maxsplit)


def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
    """Yield rolling (input_tokens, pred_tokens) windows over `token_list`.

    NOTE(review): the original `def` line and docstring head were lost in
    extraction; the signature is reconstructed from the four parameter
    names used in the body — confirm against upstream.

    :param token_list: list of tokens to be PREDICTED
    :param prefix_token: dummy token (e.g. <eos>) so the first token has
        something to condition on
    :param max_seq_len: max sequence length of the model (or the one we
        want to use)
    :param context_len: amount of desired token context for prediction;
        must be at least 1
    :return: generator of tuples (input_tokens, pred_tokens)

    Note: Score only the last len(pred_tokens) logits of the LM.
    """
    assert 1 <= context_len <= max_seq_len
    if not token_list:
        return
    # +1 offset, going from input->preds
    pred_len = max_seq_len - context_len + 1
    predicted = 0

    # Special handling for first window: predict all tokens
    first_seq_len = min(max_seq_len, len(token_list))
    yield (
        [prefix_token] + token_list[: first_seq_len - 1],
        token_list[:first_seq_len],
    )
    predicted += first_seq_len

    while predicted < len(token_list):
        window_pred_len = min(len(token_list) - predicted, pred_len)
        window_end = predicted + window_pred_len

        yield (
            token_list[window_end - max_seq_len - 1 : window_end - 1],
            token_list[window_end - window_pred_len : window_end],
        )
        predicted += window_pred_len


def make_disjoint_window(pair):
    """Takes output from get_rolling_token_windows and makes the context
    not overlap with the continuation."""
    a, b = pair
    return a[: len(a) - (len(b) - 1)], b


class Reorderer:
    def __init__(self, arr: List[Any], fn: Callable) -> None:
        """Reorder an array according to some function

        Args:
            arr (List[Any]): The initial array
            fn (Callable[[Any], Any]): A function to determine the priority of elements
        """
        self.size = len(arr)
        arr = list(enumerate(arr))
        # NOTE(review): `group` is defined elsewhere in this module (not
        # visible in this chunk); it buckets items by the given key fn.
        arr = group(arr, lambda x: fn(x[1]))
        # arr = [([y[0] for y in x], x[0][1]) for x in arr]
        # TODO: overhaul reorderer. It currently grouped requests by content
        # but we don't want this
        arr = [([y[0]], x[0][1]) for x in arr for y in x]
        arr.sort(key=lambda x: fn(x[1]))

        self.arr = arr

    def get_reordered(self):
        """Gets the reordered array

        Returns:
            List[Any]: The reordered array
        """
        return [x[1] for x in self.arr]

    def get_original(self, newarr):
        """Restores the original order of a new array based on the old array's order

        Args:
            newarr (List[Any]): The array to be restored

        Returns:
            List[Any]: The array restored to the original order
        """
        res = [None] * self.size
        cov = [False] * self.size

        for (inds, _), v in zip(self.arr, newarr):
            for ind in inds:
                res[ind] = v
                cov[ind] = True

        assert all(cov)

        return res


def make_table(result_dict, column: str = "results"):
    """Generate table of results."""
    from pytablewriter import LatexTableWriter, MarkdownTableWriter

    # NOTE(review): any `column` other than "results"/"groups" leaves
    # `column_name` unbound and raises NameError — preserved as-is.
    if column == "results":
        column_name = "Tasks"
    elif column == "groups":
        column_name = "Groups"

    all_headers = [
        column_name,
        "Version",
        "Filter",
        "n-shot",
        "Metric",
        "Value",
        "",
        "Stderr",
    ]

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = all_headers
    latex_writer.headers = all_headers

    values = []

    for k, dic in result_dict[column].items():
        version = result_dict["versions"][k]
        n = str(result_dict["n-shot"][k])

        if "alias" in dic:
            k = dic.pop("alias")

        for mf, v in dic.items():
            m, _, f = mf.partition(",")
            if m.endswith("_stderr"):
                continue

            if m + "_stderr" + "," + f in dic:
                se = dic[m + "_stderr" + "," + f]
                if se != "N/A":
                    se = "%.4f" % se
                values.append([k, version, f, n, m, "%.4f" % v, "±", se])
            else:
                values.append([k, version, f, n, m, "%.4f" % v, "", ""])
            # Only print the task name / version on the first metric row.
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()


def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`)
    to the wrapped function, `fn`.
    """

    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        # A bound method receives `self` positionally, so exactly one
        # positional arg is expected there; otherwise none are.
        # (Parenthesized for clarity; identical to the original parse of
        # `len(args) != 1 if inspect.ismethod(fn) else 0`.)
        if len(args) != (1 if inspect.ismethod(fn) else 0):
            print(
                f"WARNING: using {fn.__name__} with positional arguments is "
                "deprecated and will be disallowed in a future version of "
                "lm-evaluation-harness!"
            )
        return fn(*args, **kwargs)

    return _wrapper


@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
    """
    Search upward in the directory tree to a maximum of three layers
    to find and return the package root (containing the 'tests' folder)
    """
    cur_path = start_path.resolve()
    max_layers = 3
    for _ in range(max_layers):
        if (cur_path / "tests" / "test_version_stable.py").exists():
            return cur_path
        else:
            cur_path = cur_path.parent.resolve()
    # BUG FIX: the original concatenated two f-strings without a space,
    # producing "... upwardsof <path>".
    raise FileNotFoundError(
        f"Unable to find package root within {max_layers} upwards of {start_path}"
    )


@positional_deprecated
def run_task_tests(task_list: List[str]):
    """
    Find the package root and run the tests for the given tasks
    """
    import pytest

    package_root = find_test_root(start_path=pathlib.Path(__file__))
    task_string = " or ".join(task_list)
    args = [
        f"{package_root}/tests/test_version_stable.py",
        f"--rootdir={package_root}",
        "-k",
        f"{task_string}",
    ]
    sys.path.append(str(package_root))
    pytest_return_val = pytest.main(args)
    if pytest_return_val:
        raise ValueError(
            f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
        )


def get_git_commit_hash():
    """
    Gets the git commit hash of your current repo (if it exists).
    Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42
    """
    try:
        git_hash = subprocess.check_output(["git", "describe", "--always"]).strip()
        git_hash = git_hash.decode()
    # BUG FIX: the original `except A or B:` only caught
    # subprocess.CalledProcessError — `A or B` evaluates to `A` — so a
    # missing git binary crashed instead of being handled. A tuple
    # catches both.
    except (subprocess.CalledProcessError, FileNotFoundError):
        # FileNotFoundError occurs when git not installed on system
        git_hash = None
    return git_hash


def ignore_constructor(loader, node):
    """YAML constructor used in "simple" mode: return the raw node
    instead of importing the referenced function."""
    return node


def import_function(loader, node):
    """YAML `!function` constructor: import `pkg.mod.fn` from a .py file
    located relative to the YAML file's directory and return the
    function object."""
    function_name = loader.construct_scalar(node)
    yaml_path = os.path.dirname(loader.name)

    *module_name, function_name = function_name.split(".")
    # Starred assignment always yields a list, so join unconditionally
    # (the original guarded this with an always-true isinstance check).
    module_name = ".".join(module_name)
    module_path = os.path.normpath(os.path.join(yaml_path, "{}.py".format(module_name)))

    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    function = getattr(module, function_name)
    return function


def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"):
    """Load a YAML task config, recursively merging any `include`d files.

    :param yaml_path: path to the YAML file (used when `yaml_config` is None,
        and to resolve relative `include` paths)
    :param yaml_config: an already-parsed config dict, if available
    :param yaml_dir: directory for resolving relative includes; defaults to
        the directory of `yaml_path`
    :param mode: "full" resolves `!function` tags by importing the target;
        "simple" leaves them as raw YAML nodes
    :return: the merged configuration dict (later includes and the config
        itself override earlier includes)
    :raises ValueError: if `mode` is not "simple" or "full" (the original
        failed later with an UnboundLocalError)
    """
    if mode == "simple":
        constructor_fn = ignore_constructor
    elif mode == "full":
        constructor_fn = import_function
    else:
        raise ValueError(f"Unknown yaml loading mode: {mode}")

    # Add the import_function constructor to the YAML loader
    yaml.add_constructor("!function", constructor_fn)
    if yaml_config is None:
        with open(yaml_path, "rb") as file:
            yaml_config = yaml.full_load(file)

    if yaml_dir is None:
        yaml_dir = os.path.dirname(yaml_path)

    assert yaml_dir is not None

    if "include" in yaml_config:
        include_path = yaml_config["include"]
        del yaml_config["include"]

        if isinstance(include_path, str):
            include_path = [include_path]

        # Load from the last one first
        include_path.reverse()
        final_yaml_config = {}
        for path in include_path:
            # Assumes that path is a full path.
            # If not found, assume the included yaml is in the same dir
            # as the original yaml.
            if not os.path.isfile(path):
                path = os.path.join(yaml_dir, path)

            # Failures while loading an included config propagate to the
            # caller (the original wrapped this in a try/except that
            # immediately re-raised, which was a no-op).
            included_yaml_config = load_yaml_config(yaml_path=path, mode=mode)
            final_yaml_config.update(included_yaml_config)

        final_yaml_config.update(yaml_config)
        return final_yaml_config

    return yaml_config


def regex_replace(string, pattern, repl, count: int = 0):
    """Implements the `re.sub` function as a custom Jinja filter."""
    return re.sub(pattern, repl, string, count=count)


env = Environment(loader=BaseLoader, undefined=StrictUndefined)
env.filters["regex_replace"] = regex_replace


def apply_template(template: str, doc: dict) -> str:
    """Render the Jinja `template` string with the fields of `doc` as
    template variables (StrictUndefined: missing fields raise)."""
    rtemplate = env.from_string(template)
    return rtemplate.render(**doc)


def create_iterator(raw_iterator, rank, world_size, limit=None):
    """
    Method for creating a (potentially) sliced and limited
    iterator from a raw document iterator. Used for splitting data
    among ranks in multigpu setting or only pulling a sample of documents
    """
    return islice(raw_iterator, rank, limit, world_size)


# Multi-token stopping criteria
# from more_itertools