import os
import re
import collections


class ExitCodeError(Exception):
    """Raised by `sh` when the executed command reports a non-zero status."""


def sh(x):
    """Run the command string `x` via `os.system`.

    :param x: str
        Command line handed to the system shell as-is.
    :raises ExitCodeError: if `os.system` returns a non-zero status.
    """
    if os.system(x):
        raise ExitCodeError()


def simple_parse_args_string(args_string):
    """
    Parses something like
        args1=val1,arg2=val2
    Into a dictionary

    Each comma-separated segment is split on its FIRST "=" only, so values
    may themselves contain "=" (e.g. "path=a=b" -> {"path": "a=b"}).
    Keys and values are kept as strings; no whitespace is stripped from the
    individual segments.

    :param args_string: str
    :return: dict mapping str -> str (empty if the input is blank)
    :raises ValueError: if a segment contains no "=" at all
    """
    args_string = args_string.strip()
    if not args_string:
        return {}
    args_dict = {}
    for arg in args_string.split(","):
        key, sep, value = arg.partition("=")
        if not sep:
            raise ValueError(
                f"Expected 'key=value' but got {arg!r} in {args_string!r}"
            )
        args_dict[key] = value
    return args_dict


def join_iters(iters):
    """Lazily yield every element of every iterable in `iters`, in order."""
    for it in iters:
        yield from it


def chunks(iter, n):
    """Yield successive lists of up to `n` items taken from `iter`.

    Every chunk has exactly `n` items except possibly the last one, which
    holds the remainder. Nothing is yielded for an empty input.
    """
    arr = []
    for x in iter:
        arr.append(x)
        if len(arr) == n:
            yield arr
            arr = []

    # Flush the trailing partial chunk, if any.
    if arr:
        yield arr


def group(arr, fn):
    """Group the items of `arr` by the key `fn(item)`.

    :param arr: iterable of items
    :param fn: callable returning a hashable key for an item
    :return: list of lists, one per distinct key, in order of first
        appearance of each key; items keep their input order within a group
    """
    res = collections.defaultdict(list)

    for ob in arr:
        res[fn(ob)].append(ob)

    return list(res.values())


def general_detokenize(string):
    """Undo a few common tokenization artifacts in `string`.

    Re-attaches "n't", removes the space inside parentheses and around
    double quotes, and removes the space before an apostrophe, period or
    comma.
    """
    string = string.replace(" n't", "n't")
    string = string.replace(" )", ")")
    string = string.replace("( ", "(")
    string = string.replace("\" ", "\"")
    string = string.replace(" \"", "\"")
    string = re.sub(r" (['.,])", r"\1", string)
    return string


def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
    """
    - context_len allows for a rolling window context, allowing each prediction window to potentially
      condition on some context

    :param token_list: list
        List of tokens to be PREDICTED
    :param prefix_token: token
        Dummy token (such as an end-of-text token) so the first token has
        something to condition on
    :param max_seq_len: int
        max_seq_len of model (or max_seq_len we want to use)
    :param context_len: int
        Amount of desired token context for prediction. Needs to be at least 1.
    :return: generator
        Generator of tuples
            (input_tokens, pred_tokens)
        Note: Score only the last len(pred_tokens) logits of the LM

    Every token of `token_list` appears in exactly one `pred_tokens`.
    Nothing is yielded when `token_list` is empty.
    """
    assert 1 <= context_len <= max_seq_len
    if not token_list:
        return
    # +1 offset, going from input->preds
    pred_len = max_seq_len - context_len + 1
    predicted = 0

    # Special handling for first window: predict all tokens
    first_seq_len = min(max_seq_len, len(token_list))
    yield (
        [prefix_token] + token_list[:first_seq_len - 1],
        token_list[:first_seq_len],
    )
    predicted += first_seq_len

    while predicted < len(token_list):
        window_pred_len = min(len(token_list) - predicted, pred_len)
        window_end = predicted + window_pred_len

        # The input is the max_seq_len tokens that end one position before
        # window_end (inputs are shifted by one relative to predictions).
        # The start index cannot go negative: the loop is only entered after
        # the first window predicted max_seq_len tokens.
        yield (
            token_list[window_end - max_seq_len - 1:window_end - 1],
            token_list[window_end - window_pred_len:window_end],
        )
        predicted += window_pred_len


def make_disjoint_window(pair):
    """
    Takes output from get_rolling_token_windows and makes the context not overlap with the continuation

    The last len(b) - 1 tokens of the input `a` are dropped, since they are
    the leading tokens of the continuation `b`.

    :param pair: tuple (input_tokens, pred_tokens)
    :return: tuple (context_tokens, pred_tokens)
    """
    a, b = pair
    # Slice with an explicit end index: the negative form a[:-(len(b) - 1)]
    # degenerates to a[:0] == [] when len(b) == 1 and would drop all context.
    return a[:len(a) - (len(b) - 1)], b


class Reorderer:
    """Deduplicate and sort items by a key, then map results back.

    Items of `arr` whose `fn(item)` keys are equal are collapsed into a
    single representative (the first one seen). `get_reordered` returns the
    representatives sorted by key; `get_original` takes one result per
    representative and expands them back to the original positions.
    """

    def __init__(self, arr, fn):
        """
        :param arr: sequence of items
        :param fn: callable returning a hashable, orderable key for an item
        """
        self.size = len(arr)
        arr = list(enumerate(arr))
        arr = group(arr, lambda x: fn(x[1]))
        # Each entry: (original indices sharing the key, representative item)
        arr = [
            ([y[0] for y in x], x[0][1]) for x in arr
        ]
        arr.sort(key=lambda x: fn(x[1]))

        self.arr = arr

    def get_reordered(self):
        """Return the representative items in sorted-by-key order."""
        return [x[1] for x in self.arr]

    def get_original(self, newarr):
        """Scatter `newarr` back to the original ordering.

        :param newarr: iterable with one value per item of `get_reordered()`,
            in the same order
        :return: list of length `len(arr)` where each original position holds
            the value of its representative
        """
        res = [None] * self.size
        cov = [False] * self.size

        for (inds, _), v in zip(self.arr, newarr):
            for ind in inds:
                res[ind] = v
                cov[ind] = True

        # Every original position must have received a value.
        assert all(cov)

        return res