import collections
import fnmatch
import functools
import hashlib
import importlib.util
import inspect
import json
import logging
import os
import re
from dataclasses import asdict, is_dataclass
from itertools import islice
from typing import Any, Callable, Iterable, Iterator, List, Optional, Tuple, Union

import numpy as np
import yaml
from jinja2 import BaseLoader, Environment, StrictUndefined


logging.basicConfig(
    format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)
eval_logger = logging.getLogger("lm-eval")

SPACING = " " * 47

HIGHER_IS_BETTER_SYMBOLS = {
    True: "↑",
    False: "↓",
}


def hash_string(string: str) -> str:
    return hashlib.sha256(string.encode("utf-8")).hexdigest()


def escaped_split(text, sep_char, maxsplit=-1):
    """Split text into a list on occurrences of the given separation
    character `sep_char`. The separation character may be escaped by a
    backslash to avoid splitting at that location.

    The separation character must be a string of size 1.

    If `maxsplit` is given, at most `maxsplit` splits are done (thus,
    the list will have at most `maxsplit + 1` elements). If `maxsplit`
    is not specified or less than 0, then there is no limit on the
    number of splits (all possible splits are made).
    """
    assert (
        len(sep_char) == 1
    ), "separation string must be a single character for escaped splitting"

    if maxsplit == 0:
        return text
    maxsplit = max(0, maxsplit)

    # Split on the separator only when it is not preceded by a backslash.
    return re.split(r"(?<!\\)" + sep_char, text, maxsplit)


def group(arr, fn):
    """Groups elements of `arr` by `fn(element)`. Used by `Reorderer` below."""
    res = collections.defaultdict(list)

    for ob in arr:
        res[fn(ob)].append(ob)

    return list(res.values())


def get_file_task_name(filename: str) -> str:
    """
    Given the sample results filenames, extracts and returns the task name.
    """
    return filename[filename.find("_") + 1 : filename.rfind("_")]


def get_file_datetime(filename: str) -> str:
    """
    Given the results and sample results filenames, extracts and returns the datetime.
    """
    return filename[filename.rfind("_") + 1 :].replace(".json", "")


def sanitize_model_name(model_name: str) -> str:
    """
    Given the model name, returns a sanitized version of it.
    """
    return re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", model_name)


def sanitize_task_name(task_name: str) -> str:
    """
    Given the task name, returns a sanitized version of it.
    """
    return re.sub(r"\W", "_", task_name)


def get_latest_filename(filenames: List[str]) -> str:
    """
    Given a list of filenames, returns the filename with the latest datetime.
    """
    return max(filenames, key=get_file_datetime)


def get_results_filenames(filenames: List[str]) -> List[str]:
    """
    Extracts filenames that correspond to aggregated results.
    """
    return [f for f in filenames if "/results_" in f and ".json" in f]


def get_sample_results_filenames(filenames: List[str]) -> List[str]:
    """
    Extracts filenames that correspond to sample results.
    """
    return [f for f in filenames if "/samples_" in f and ".json" in f]
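# Illustrative sketch, not part of the module API: the file names below are
# hypothetical, but follow the "<prefix>_<task>_<datetime>.json" shape that the
# filename helpers above expect.
def _example_filename_helpers() -> None:
    files = [
        "out/samples_lambada_2024-01-01T00-00-00.json",
        "out/samples_lambada_2024-02-01T00-00-00.json",
        "out/results_2024-02-01T00-00-00.json",
    ]
    sample_files = get_sample_results_filenames(files)
    # The task name sits between the first and last underscore.
    assert get_file_task_name(sample_files[0]) == "lambada"
    # "Latest" is decided by lexicographic comparison of the datetime suffix.
    assert get_latest_filename(sample_files) == sample_files[1]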
def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
    """
    - context_len allows for a rolling window context, allowing each prediction
      window to potentially condition on some context

    :param token_list: list
        List of tokens to be PREDICTED
    :param max_seq_len: int
        max_seq_len of model (or max_seq_len we want to use)
    :param context_len: int
        Amount of desired token context for prediction. Needs to be at least 1.
    :param prefix_token: token
        Dummy token like <|endoftext|> so the first token has something to condition on
    :return: generator
        Generator of tuples
            (input_tokens, pred_tokens)
        Note: Score only the last len(pred_tokens) logits of the LM
    """
    assert 1 <= context_len <= max_seq_len
    if not token_list:
        return
    # +1 offset, going from input->preds
    pred_len = max_seq_len - context_len + 1
    predicted = 0

    # Special handling for first window: predict all tokens
    first_seq_len = min(max_seq_len, len(token_list))
    yield ([prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len])
    predicted += first_seq_len

    while predicted < len(token_list):
        window_pred_len = min(len(token_list) - predicted, pred_len)
        window_end = predicted + window_pred_len

        yield (
            token_list[window_end - max_seq_len - 1 : window_end - 1],
            token_list[window_end - window_pred_len : window_end],
        )
        predicted += window_pred_len


def make_disjoint_window(pair):
    """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
    a, b = pair
    return a[: len(a) - (len(b) - 1)], b


class EnhancedJSONEncoder(json.JSONEncoder):
    """
    Provides a proper json encoding for the loggers and trackers json dumps.
    Notably manages the json encoding of dataclasses.
    """

    def default(self, o):
        if is_dataclass(o):
            return asdict(o)
        return super().default(o)


class Reorderer:
    def __init__(self, arr: List[Any], fn: Callable) -> None:
        """Reorder an array according to some function

        Args:
            arr (List[Any]): The initial array
            fn (Callable[[Any], Any]): A function to determine the priority of elements
        """
        self.size = len(arr)
        arr = list(enumerate(arr))
        arr = group(arr, lambda x: fn(x[1]))
        # arr = [([y[0] for y in x], x[0][1]) for x in arr]
        # TODO: overhaul reorderer. It currently groups requests by content, but we don't want this
        arr = [([y[0]], x[0][1]) for x in arr for y in x]
        arr.sort(key=lambda x: fn(x[1]))

        self.arr = arr

    def get_reordered(self):
        """Gets the reordered array

        Returns:
            List[Any]: The reordered array
        """
        return [x[1] for x in self.arr]

    def get_original(self, newarr):
        """Restores the original order of a new array based on the old array's order

        Args:
            newarr (List[Any]): The array to be restored

        Returns:
            List[Any]: The array restored to the original order
        """
        res = [None] * self.size
        cov = [False] * self.size

        for (inds, _), v in zip(self.arr, newarr):
            for ind in inds:
                res[ind] = v
                cov[ind] = True

        assert all(cov)

        return res
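# Illustrative sketch, not part of the module API: rolling windows over a
# 10-token sequence with made-up token ids, max_seq_len=5, context_len=2,
# and -1 standing in for the prefix token.
def _example_rolling_windows() -> None:
    windows = list(
        get_rolling_token_windows(
            list(range(10)), prefix_token=-1, max_seq_len=5, context_len=2
        )
    )
    # The first window conditions on the prefix token and predicts a full window.
    assert windows[0] == ([-1, 0, 1, 2, 3], [0, 1, 2, 3, 4])
    # Later windows condition on `context_len` earlier tokens; make_disjoint_window
    # strips the overlap between context and continuation.
    assert windows[1] == ([3, 4, 5, 6, 7], [5, 6, 7, 8])
    assert make_disjoint_window(windows[1]) == ([3, 4], [5, 6, 7, 8])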
def make_table(result_dict, column: str = "results", sort_results: bool = True):
    """Generate table of results."""
    from pytablewriter import LatexTableWriter, MarkdownTableWriter

    if column == "results":
        column_name = "Tasks"
    elif column == "groups":
        column_name = "Groups"

    all_headers = [
        column_name,
        "Version",
        "Filter",
        "n-shot",
        "Metric",
        "",
        "Value",
        "",
        "Stderr",
    ]

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = all_headers
    latex_writer.headers = all_headers

    values = []

    keys = result_dict[column].keys()
    if sort_results:
        # sort entries alphabetically
        keys = sorted(keys)
    for k in keys:
        dic = result_dict[column][k]
        version = result_dict["versions"].get(k, "N/A")
        n = str(result_dict["n-shot"][k])
        higher_is_better = result_dict.get("higher_is_better", {}).get(k, {})

        if "alias" in dic:
            k = dic.pop("alias")

        metric_items = dic.items()
        if sort_results:
            metric_items = sorted(metric_items)

        for mf, v in metric_items:
            m, _, f = mf.partition(",")
            if m.endswith("_stderr"):
                continue

            hib = HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(m), "")

            if m + "_stderr" + "," + f in dic:
                se = dic[m + "_stderr" + "," + f]
                if se != "N/A":
                    se = "%.4f" % se
                if isinstance(v, dict):
                    for v_key, v_v in v.items():
                        values.append(
                            [k, version, f, n, m + "_" + v_key, hib, "%.4f" % v_v, "±", se]
                        )
                else:
                    values.append([k, version, f, n, m, hib, "%.4f" % v, "±", se])
            else:
                values.append([k, version, f, n, m, hib, "%.4f" % v, "", ""])
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()


def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`)
    to the wrapped function, `fn`.
    """

    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        # Bound methods receive `self` positionally, so allow exactly one
        # positional arg for methods and none for plain functions.
        if len(args) != (1 if inspect.ismethod(fn) else 0):
            print(
                f"WARNING: using {fn.__name__} with positional arguments is "
                "deprecated and will be disallowed in a future version of "
                "lm-evaluation-harness!"
            )
        return fn(*args, **kwargs)

    return _wrapper


def ignore_constructor(loader, node):
    return node


def import_function(loader, node):
    function_name = loader.construct_scalar(node)
    yaml_path = os.path.dirname(loader.name)

    # The starred target is always a list, e.g. "pkg.mod.fn" -> ["pkg", "mod"], "fn".
    *module_name, function_name = function_name.split(".")
    module_name = ".".join(module_name)
    module_path = os.path.normpath(os.path.join(yaml_path, "{}.py".format(module_name)))

    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    function = getattr(module, function_name)
    return function


def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"):
    if mode == "simple":
        constructor_fn = ignore_constructor
    elif mode == "full":
        constructor_fn = import_function

    # Add the import_function constructor to the YAML loader
    yaml.add_constructor("!function", constructor_fn)
    if yaml_config is None:
        with open(yaml_path, "rb") as file:
            yaml_config = yaml.full_load(file)

    if yaml_dir is None:
        yaml_dir = os.path.dirname(yaml_path)

    assert yaml_dir is not None

    if "include" in yaml_config:
        include_path = yaml_config["include"]
        del yaml_config["include"]

        if isinstance(include_path, str):
            include_path = [include_path]

        # Load from the last one first
        include_path.reverse()
        final_yaml_config = {}
        for path in include_path:
            # Assumes that path is a full path.
            # If not found, assume the included yaml
            # is in the same dir as the original yaml
            if not os.path.isfile(path):
                path = os.path.join(yaml_dir, path)

            try:
                included_yaml_config = load_yaml_config(yaml_path=path, mode=mode)
                final_yaml_config.update(included_yaml_config)
            except Exception as ex:
                # Surface include failures to the caller rather than skipping them.
                raise ex

        final_yaml_config.update(yaml_config)
        return final_yaml_config

    return yaml_config


def regex_replace(string, pattern, repl, count: int = 0):
    """Implements the `re.sub` function as a custom Jinja filter."""
    return re.sub(pattern, repl, string, count=count)


env = Environment(loader=BaseLoader, undefined=StrictUndefined)
env.filters["regex_replace"] = regex_replace


def apply_template(template: str, doc: dict) -> str:
    rtemplate = env.from_string(template)
    return rtemplate.render(**doc)


def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None):
    """
    Method for creating a (potentially) sliced and limited
    iterator from a raw document iterator. Used for splitting data
    among ranks in multigpu setting or only pulling a sample of documents
    """
    return islice(raw_iterator, rank, limit, world_size)
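# Illustrative sketch, not part of the module API: a made-up template/doc pair
# showing the custom `regex_replace` filter registered on `env` above.
def _example_apply_template() -> None:
    rendered = apply_template(
        "Q: {{ question | regex_replace(' +', ' ') }}\nA:",
        {"question": "What  is  2+2?"},
    )
    # Runs of spaces in the doc field are collapsed by the filter.
    assert rendered == "Q: What is 2+2?\nA:"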
class Collator:
    """
    A class for reordering and batching elements of an array.

    This class allows for sorting an array based on a provided sorting function,
    grouping elements based on a grouping function, and generating batches from
    the sorted and grouped data.
    """

    def __init__(
        self,
        arr: List,
        sort_fn: Callable,
        group_fn: Callable = lambda x: x[1],
        grouping: bool = False,
    ) -> None:
        self.grouping = grouping
        self.fn = sort_fn
        self.group_fn = lambda x: group_fn(x[1])  # first index are enumerated indices
        self.reorder_indices: List = []
        self.size = len(arr)
        self.arr_with_indices: Iterable[Any] = tuple(enumerate(arr))  # [indices, (arr)]
        if self.grouping:
            self.group_by_index()

    def group_by_index(self) -> None:
        self.arr_with_indices = self.group(
            self.arr_with_indices, fn=self.group_fn, values=False
        )

    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
        """
        Generates and yields batches from the reordered array.

        Parameters:
        - n (int): The size of each batch. Defaults to 1.
        - batch_fn (Optional[Callable[[int, Iterable], int]]): A function to determine the size of each batch. Defaults to None.

        Yields:
        Batches (lists) of reordered elements.
        """
        if self.grouping:
            for (
                key,
                values,
            ) in self.arr_with_indices.items():  # type: ignore
                values = self._reorder(values)
                batch = self.get_chunks(values, n=n, fn=batch_fn)
                yield from batch
        else:
            values = self._reorder(self.arr_with_indices)  # type: ignore
            batch = self.get_chunks(values, n=n, fn=batch_fn)
            yield from batch

    def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> Iterator:
        """
        Reorders the elements in the array based on the sorting function.

        Parameters:
        - arr (Union[List, Tuple[Tuple[int, Any], ...]]): The array or iterable to be reordered.

        Yields:
        The reordered elements one by one.
        """
        arr = sorted(arr, key=lambda x: self.fn(x[1]))
        self.reorder_indices.extend([x[0] for x in arr])
        yield from [x[1] for x in arr]

    def get_original(self, newarr: List) -> List:
        """
        Restores the original order of elements from the reordered list.

        Parameters:
        - newarr (List): The reordered array.

        Returns:
        List: The array with elements restored to their original order.
        """
        res = [None] * self.size
        cov = [False] * self.size

        for ind, v in zip(self.reorder_indices, newarr):
            res[ind] = v
            cov[ind] = True

        assert all(cov)

        return res

    def __len__(self):
        return self.size

    @staticmethod
    def group(arr: Iterable, fn: Callable, values: bool = False) -> Iterable:
        """
        Groups elements of an iterable based on a provided function.

        Parameters:
        - arr (Iterable): The iterable to be grouped.
        - fn (Callable): The function to determine the grouping.
        - values (bool): If True, returns only the grouped values. Defaults to False.

        Returns:
        Iterable: A dict of grouped elements, or its values if `values` is True.
        """
        res = collections.defaultdict(list)
        for ob in arr:
            try:
                # Dict-valued keys are converted to hashable tuples so they can
                # serve as dictionary keys.
                hashable_dict = tuple(
                    (
                        key,
                        tuple(value)
                        if isinstance(value, collections.abc.Iterable)
                        else value,
                    )
                    for key, value in sorted(fn(ob).items())
                )
                res[hashable_dict].append(ob)
            except (TypeError, AttributeError):
                # Non-dict keys are used directly.
                res[fn(ob)].append(ob)
        if not values:
            return res
        return res.values()
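    # Illustrative sketch (doctest-style, values made up): a round trip
    # through sort -> batch -> restore-original-order.
    #
    #   >>> c = Collator(["bb", "a", "ccc"], sort_fn=len)
    #   >>> batches = list(c.get_batched(n=2))  # sorted by length, chunks of 2
    #   >>> batches
    #   [['a', 'bb'], ['ccc']]
    #   >>> c.get_original([len(x) for b in batches for x in b])
    #   [2, 1, 3]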
    @staticmethod
    def get_chunks(_iter, n: int = 0, fn=None):
        """
        Divides an iterable into chunks of specified size or based on a given function.
        Useful for batching

        Parameters:
        - _iter: The input iterable to be divided into chunks.
        - n: An integer representing the size of each chunk. Default is 0.
        - fn: A function that takes the current index and the iterable as
          arguments and returns the size of the chunk. Default is None.

        Returns:
        An iterator that yields chunks of the input iterable.

        Example usage:
        ```
        data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        for chunk in Collator.get_chunks(data, n=3):
            print(chunk)
        ```
        Output:
        ```
        [1, 2, 3]
        [4, 5, 6]
        [7, 8, 9]
        [10]
        ```
        """
        arr = []
        _iter = tuple(_iter)
        for i, x in enumerate(_iter):
            arr.append(x)
            if len(arr) == (fn(i, _iter) if fn else n):
                yield arr
                arr = []

        if arr:
            yield arr
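# Illustrative sketch, not part of the module API: the `fn` hook of
# `Collator.get_chunks` lets callers choose batch sizes dynamically; the
# constant cap of 4 below stands in for e.g. a token-budget heuristic.
def _example_dynamic_chunks() -> None:
    chunks = list(Collator.get_chunks(range(10), fn=lambda idx, it: 4))
    # A new chunk is emitted whenever the accumulator reaches fn's value.
    assert chunks == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]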