import collections
import fnmatch
import functools
import gc
import importlib.util
import inspect
import os
import pathlib
import re
import subprocess
import sys
from itertools import islice
from typing import Iterator, List, Literal, Union

import torch
import transformers
import yaml
from jinja2 import BaseLoader, Environment, StrictUndefined

from lm_eval.logger import eval_logger


def escaped_split(text, sep_char, maxsplit=-1):
    """Split text into a list on occurrences of the given separation
    character `sep_char`. The separation character may be escaped by a
    backslash to avoid splitting at that location.

    The separation character must be a string of size 1.

    If `maxsplit` is given, at most `maxsplit` splits are done (thus,
    the list will have at most `maxsplit + 1` elements). If `maxsplit`
    is not specified or less than 0, then there is no limit on the
    number of splits (all possible splits are made).
    """
    assert (
        len(sep_char) == 1
    ), "separation string must be a single character for escaped splitting"

    if maxsplit == 0:
        return text
    maxsplit = max(0, maxsplit)

    # Split on `sep_char` only when it is not preceded by a backslash.
    return re.split(r"(?<!\\)" + sep_char, text, maxsplit)


def group(arr, fn):
    # Group the elements of `arr` into lists keyed by `fn(element)`,
    # returned in first-seen order.
    res = collections.defaultdict(list)

    for ob in arr:
        res[fn(ob)].append(ob)

    return list(res.values())


class MultiChoice:
    def __init__(self, choices) -> None:
        self.choices = choices

    # Simple wildcard support (linux filename patterns)
    def __contains__(self, values) -> bool:
        for value in values.split(","):
            if len(fnmatch.filter(self.choices, value)) == 0:
                eval_logger.info("Available tasks to choose:")
                for choice in self.choices:
                    eval_logger.info(f"  - {choice}")
                raise ValueError("'{}' is not in task list".format(value))
        return True

    def __iter__(self) -> Iterator:
        for choice in self.choices:
            yield choice


# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
    if isinstance(patterns, str):
        patterns = [patterns]

    task_names = set()
    for pattern in patterns:
        for matching in fnmatch.filter(source_list, pattern):
            task_names.add(matching)
    return sorted(list(task_names))


def general_detokenize(string):
    string = string.replace(" n't", "n't")
    string = string.replace(" )", ")")
    string = string.replace("( ", "(")
    string = string.replace('" ', '"')
    string = string.replace(' "', '"')
    string = re.sub(r" (['.,])", r"\1", string)
    return string
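# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): `_example_text_helpers`
# is a hypothetical helper added purely to document the expected behaviour of
# `escaped_split` and `general_detokenize` above. It is never called at import
# time.
# ---------------------------------------------------------------------------
def _example_text_helpers() -> None:
    # A backslash-escaped separator is not split on.
    assert escaped_split(r"a,b\,c,d", ",") == ["a", r"b\,c", "d"]
    # Detached punctuation is re-attached to the preceding token.
    assert general_detokenize('He said , " hello . "') == 'He said, "hello."'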
def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
    """
    - context_len allows for a rolling window context, allowing each prediction
    window to potentially condition on some context

    :param token_list: list
        List of tokens to be PREDICTED
    :param max_seq_len: int
        max_seq_len of model (or max_seq_len we want to use)
    :param context_len: int
        Amount of desired token context for prediction. Needs to be at least 1.
    :param prefix_token: token
        Dummy token like <s> so that the first token has something to condition on
    :return: generator
        Generator of tuples
            (input_tokens, pred_tokens)
        Note: Score only the last len(pred_tokens) logits of the LM
    """
    assert 1 <= context_len <= max_seq_len
    if not token_list:
        return
    # +1 offset, going from input->preds
    pred_len = max_seq_len - context_len + 1
    predicted = 0

    # Special handling for first window: predict all tokens
    first_seq_len = min(max_seq_len, len(token_list))
    yield (
        [prefix_token] + token_list[: first_seq_len - 1],
        token_list[:first_seq_len],
    )
    predicted += first_seq_len

    while predicted < len(token_list):
        window_pred_len = min(len(token_list) - predicted, pred_len)
        window_end = predicted + window_pred_len

        yield (
            token_list[window_end - max_seq_len - 1 : window_end - 1],
            token_list[window_end - window_pred_len : window_end],
        )
        predicted += window_pred_len


def make_disjoint_window(pair):
    """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
    a, b = pair
    return a[: len(a) - (len(b) - 1)], b


class Reorderer:
    def __init__(self, arr, fn) -> None:
        self.size = len(arr)
        arr = list(enumerate(arr))
        arr = group(arr, lambda x: fn(x[1]))
        # arr = [([y[0] for y in x], x[0][1]) for x in arr]
        # TODO: overhaul reorderer. It currently groups requests by content,
        # but we don't want this.
        arr = [([y[0]], x[0][1]) for x in arr for y in x]
        arr.sort(key=lambda x: fn(x[1]))

        self.arr = arr

    def get_reordered(self):
        return [x[1] for x in self.arr]

    def get_original(self, newarr):
        res = [None] * self.size
        cov = [False] * self.size

        for (inds, _), v in zip(self.arr, newarr):
            for ind in inds:
                res[ind] = v
                cov[ind] = True

        assert all(cov)

        return res
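# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a hypothetical
# `_example_rolling_windows` helper showing how `get_rolling_token_windows`
# and `make_disjoint_window` behave for a toy input. It is never called at
# import time.
# ---------------------------------------------------------------------------
def _example_rolling_windows() -> None:
    windows = list(
        get_rolling_token_windows(
            token_list=[1, 2, 3, 4, 5, 6],
            prefix_token=0,
            max_seq_len=4,
            context_len=2,
        )
    )
    # The first window predicts as many tokens as fit; later windows re-use
    # earlier tokens as context so that every token is predicted exactly once.
    assert windows == [([0, 1, 2, 3], [1, 2, 3, 4]), ([2, 3, 4, 5], [5, 6])]
    # make_disjoint_window trims the overlap between context and continuation.
    assert [make_disjoint_window(w) for w in windows] == [
        ([0], [1, 2, 3, 4]),
        ([2, 3, 4], [5, 6]),
    ]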
class Grouper:
    """
    Takes an array `arr` and a function `fn`, and returns a dictionary whose
    keys are `fn(ob)` for each `ob` in `arr` and whose values are lists of all
    objects in `arr` satisfying `key == fn(ob)`.
    """

    def __init__(self, arr, fn) -> None:
        # self.orig_arr = arr
        self.size = len(arr)
        arr = list(enumerate(arr))

        def group_return_dict(arr, fn):
            res = collections.defaultdict(list)

            for ob in arr:
                res[fn(ob)].append(ob)
            return res

        arr = group_return_dict(arr, lambda x: fn(x[1]))

        # self.arr maps each key to a list of (original index, element) tuples
        self.arr = arr
        self._grouped = None

    def get_grouped(self):
        # return the contents but not indices for our grouped dict.
        if self._grouped:
            return self._grouped
        grouped = {}
        for key in self.arr.keys():
            # drop the index from each element of self.arr
            grouped[key] = [y[1] for y in self.arr[key]]
        self._grouped = grouped
        return grouped

    def get_original(self, grouped_dict):
        # take in a grouped dictionary with e.g. results for each key listed
        # in the same order as the instances in `self.arr`, and
        # return the results in the same (single list) order as `self.orig_arr`.
        res = [None] * self.size
        cov = [False] * self.size
        # orig = [None] * self.size

        assert grouped_dict.keys() == self.arr.keys()

        for key in grouped_dict.keys():
            for (ind, _), v in zip(self.arr[key], grouped_dict[key]):
                res[ind] = v
                cov[ind] = True
                # orig[ind] = _

        assert all(cov)
        # assert orig == self.orig_arr

        return res


def make_table(result_dict, column: str = "results"):
    """Generate table of results."""
    from pytablewriter import LatexTableWriter, MarkdownTableWriter

    if column == "results":
        column_name = "Tasks"
    elif column == "groups":
        column_name = "Groups"

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = [
        column_name,
        "Version",
        "Filter",
        "Metric",
        "Value",
        "",
        "Stderr",
    ]
    latex_writer.headers = [
        column_name,
        "Version",
        "Filter",
        "Metric",
        "Value",
        "",
        "Stderr",
    ]

    values = []

    for k, dic in result_dict[column].items():
        version = result_dict["versions"][k]
        for mf, v in dic.items():
            m, _, f = mf.partition(",")
            if m.endswith("_stderr"):
                continue

            if m + "_stderr" + "," + f in dic:
                se = dic[m + "_stderr" + "," + f]
                values.append([k, version, f, m, "%.4f" % v, "±", "%.4f" % se])
            else:
                values.append([k, version, f, m, "%.4f" % v, "", ""])
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()


def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`) to the
    wrapped function, `fn`.
    """

    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        if len(args) != (1 if inspect.ismethod(fn) else 0):
            print(
                f"WARNING: using {fn.__name__} with positional arguments is "
                "deprecated and will be disallowed in a future version of "
                "lm-evaluation-harness!"
            )
        return fn(*args, **kwargs)

    return _wrapper


@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
    """
    Search upward in the directory tree to a maximum of three layers
    to find and return the package root (containing the 'tests' folder)
    """
    cur_path = start_path.resolve()
    max_layers = 3
    for _ in range(max_layers):
        if (cur_path / "tests" / "test_version_stable.py").exists():
            return cur_path
        else:
            cur_path = cur_path.parent.resolve()
    raise FileNotFoundError(
        f"Unable to find package root within {max_layers} layers upwards of {start_path}"
    )


@positional_deprecated
def run_task_tests(task_list: List[str]):
    """
    Find the package root and run the tests for the given tasks
    """
    import pytest

    package_root = find_test_root(start_path=pathlib.Path(__file__))
    task_string = " or ".join(task_list)
    args = [
        f"{package_root}/tests/test_version_stable.py",
        f"--rootdir={package_root}",
        "-k",
        f"{task_string}",
    ]
    sys.path.append(str(package_root))
    pytest_return_val = pytest.main(args)
    if pytest_return_val:
        raise ValueError(
            f"Not all tests for the specified tasks ({task_list}) ran successfully! "
            f"Error code: {pytest_return_val}"
        )
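# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a hypothetical
# `_example_grouper_round_trip` helper showing the intended
# group -> process -> restore-order workflow of `Grouper`. It is never called
# at import time.
# ---------------------------------------------------------------------------
def _example_grouper_round_trip() -> None:
    grouper = Grouper(["ab", "c", "de", "f"], fn=len)
    assert grouper.get_grouped() == {2: ["ab", "de"], 1: ["c", "f"]}
    # Process each group independently ...
    processed = {
        key: [s.upper() for s in vals] for key, vals in grouper.get_grouped().items()
    }
    # ... then map the per-group results back to the original flat order.
    assert grouper.get_original(processed) == ["AB", "C", "DE", "F"]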
def get_git_commit_hash():
    """
    Gets the git commit hash of your current repo (if it exists).
    Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42
    """
    try:
        git_hash = subprocess.check_output(["git", "describe", "--always"]).strip()
        git_hash = git_hash.decode()
    except (subprocess.CalledProcessError, FileNotFoundError):
        # FileNotFoundError occurs when git is not installed on the system
        git_hash = None
    return git_hash


def import_function(loader, node):
    function_name = loader.construct_scalar(node)
    yaml_path = os.path.dirname(loader.name)

    *module_name, function_name = function_name.split(".")
    if isinstance(module_name, list):
        module_name = ".".join(module_name)
    module_path = os.path.normpath(os.path.join(yaml_path, "{}.py".format(module_name)))

    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    function = getattr(module, function_name)
    return function


# Add the import_function constructor to the YAML loader
yaml.add_constructor("!function", import_function)


def load_yaml_config(yaml_path):
    with open(yaml_path, "rb") as file:
        yaml_config = yaml.full_load(file)
        yaml_dir = os.path.dirname(yaml_path)

        if "include" in yaml_config:
            include_path = yaml_config["include"]
            del yaml_config["include"]

            if isinstance(include_path, str):
                include_path = [include_path]

            # Load from the last one first
            include_path.reverse()
            final_yaml_config = {}
            for path in include_path:
                # Assumes that path is a full path.
                # If not found, assume the included yaml
                # is in the same dir as the original yaml
                if not os.path.isfile(path):
                    path = os.path.normpath(os.path.join(yaml_dir, path))

                try:
                    included_yaml_config = load_yaml_config(path)
                    final_yaml_config.update(included_yaml_config)
                except Exception as ex:
                    # If an included config fails to load, surface the error
                    raise ex

            final_yaml_config.update(yaml_config)
            return final_yaml_config

        return yaml_config


def regex_replace(string, pattern, repl, count: int = 0):
    """Implements the `re.sub` function as a custom Jinja filter."""
    return re.sub(pattern, repl, string, count=count)


env = Environment(loader=BaseLoader, undefined=StrictUndefined)
env.filters["regex_replace"] = regex_replace


def apply_template(template: str, doc: dict) -> str:
    rtemplate = env.from_string(template)
    return rtemplate.render(**doc)


def create_iterator(raw_iterator, rank, world_size, limit=None):
    """
    Method for creating a (potentially) sliced and limited
    iterator from a raw document iterator. Used for splitting data
    among ranks in a multigpu setting, or for only pulling a sample of documents.
    """
    return islice(raw_iterator, rank, limit, world_size)
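# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a hypothetical
# `_example_rank_slicing_and_templates` helper showing how `create_iterator`
# splits documents across ranks and how `apply_template` uses the custom
# `regex_replace` Jinja filter. It is never called at import time.
# ---------------------------------------------------------------------------
def _example_rank_slicing_and_templates() -> None:
    # With world_size=2, rank 0 sees the even-indexed documents ...
    assert list(create_iterator(iter(range(10)), rank=0, world_size=2)) == [0, 2, 4, 6, 8]
    # ... and rank 1 sees the odd-indexed ones.
    assert list(create_iterator(iter(range(10)), rank=1, world_size=2)) == [1, 3, 5, 7, 9]
    # Templates are rendered against a document dict; `regex_replace` wraps `re.sub`.
    assert apply_template("{{ question | regex_replace(' +', ' ') }}", {"question": "a  b"}) == "a b"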
""" assert ( padding_side == "left" or padding_side == "right" ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'" for i, tensor in enumerate(tensors): if len(tensor.shape) == 2: tensor = tensor.squeeze(0) # squeeze, in case passed [1, seq] size tensor_len = tensor.shape[0] if tensor_len < max_length: if padding_side == "right": # right-pad tensors[i] = torch.cat( [ tensor, # [seq] torch.zeros( max_length - tensor_len, dtype=torch.long, device=tensor.device, ), # [padding_length - seq] ], dim=0, ).unsqueeze(0) else: # left-pad tensors[i] = torch.cat( [ torch.zeros( max_length - tensor_len, dtype=torch.long, device=tensor.device, ), # [padding_length - seq] tensor, # [seq] ], dim=0, ).unsqueeze(0) else: tensors[i] = tensor.unsqueeze(0) return torch.cat(tensors, dim=0) def clear_torch_cache() -> None: gc.collect() torch.cuda.empty_cache() def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype: """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig""" if isinstance(dtype, str) and dtype != "auto": # Convert `str` args torch dtype: `float16` -> `torch.float16` _torch_dtype = getattr(torch, dtype) else: _torch_dtype = dtype return _torch_dtype # Multi-token stopping criteria class MultiTokenEOSCriteria(transformers.StoppingCriteria): """Criteria to stop on the specified multi-token sequence.""" def __init__( self, sequence: str, tokenizer: transformers.PreTrainedTokenizer, initial_decoder_input_length: int, batch_size: int, ) -> None: self.initial_decoder_input_length = initial_decoder_input_length self.done_tracker = [False] * batch_size self.sequence = sequence self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False) self.sequence_id_len = len(self.sequence_ids) self.tokenizer = tokenizer def __call__(self, input_ids, scores, **kwargs) -> bool: # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][ :, -self.sequence_id_len : ] lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch) for i, done in enumerate(self.done_tracker): if not done: self.done_tracker[i] = self.sequence in lookback_tokens_batch[i] return False not in self.done_tracker def stop_sequences_criteria( tokenizer: transformers.PreTrainedTokenizer, stop_sequences: List[str], initial_decoder_input_length: int, batch_size: int, ) -> transformers.StoppingCriteriaList: return transformers.StoppingCriteriaList( [ *[ MultiTokenEOSCriteria( sequence, tokenizer, initial_decoder_input_length, batch_size ) for sequence in stop_sequences ], ] )