utils.py 12.1 KB
Newer Older
1
2
3
4
import collections
import fnmatch
import functools
import importlib.util
5
import inspect
6
7
8
9
10
import logging
import os
import re
import sys
from itertools import islice
11
12
from pathlib import Path
from typing import Any, Callable, List
13

14
import yaml
15
from jinja2 import BaseLoader, Environment, StrictUndefined
sdtblck's avatar
sdtblck committed
16

lintangsutawika's avatar
lintangsutawika committed
17

18
19
20
21
22
# Configure the root logger as a side effect of importing this module:
# INFO level, with each record prefixed by a timestamp (millisecond
# resolution), the level name, and the emitting file name and line number.
logging.basicConfig(
    format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)
23
# Module-level logger registered under the name "lm-eval".
eval_logger = logging.getLogger("lm-eval")
sdtblck's avatar
sdtblck committed
24

25
# A string of 47 space characters.
# NOTE(review): presumably used to indent continuation lines of log output;
# its consumers are not visible here — confirm before changing the width.
SPACING = " " * 47
sdtblck's avatar
sdtblck committed
26
27


28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def escaped_split(text, sep_char, maxsplit=-1):
    """Split text into a list on occurrences of the given separation
    character `sep_char`. The separation character may be escaped by a
    backslash to avoid splitting at that location.

    The separation character must be a string of size 1. It is always
    matched literally, even if it is a regex metacharacter such as
    `.` or `|`.

    If `maxsplit` is given, at most `maxsplit` splits are done (thus,
    the list will have at most `maxsplit + 1` elements). If `maxsplit`
    is not specified or less than 0, then there is no limit on the
    number of splits (all possible splits are made). If `maxsplit` is 0,
    no split is made and a one-element list holding `text` is returned.

    Note that the escaping backslash is kept in the returned pieces.
    """
    assert (
        len(sep_char) == 1
    ), "separation string must be a single character for escaped splitting"

    if maxsplit == 0:
        # Same contract as `str.split(sep, 0)`: no split, but still a list.
        return [text]
    # `re.split` interprets 0 as "no limit", so fold negatives onto 0.
    maxsplit = max(0, maxsplit)

    # The negative lookbehind skips separators preceded by a backslash;
    # `re.escape` keeps a metacharacter separator from being read as regex.
    return re.split(r"(?<!\\)" + re.escape(sep_char), text, maxsplit=maxsplit)


haileyschoelkopf's avatar
haileyschoelkopf committed
51
52
53
54
55
def handle_arg_string(arg):
    """Convert a command-line style string into a Python value.

    Conversion is attempted in this order:
      1. "true" / "false" (case-insensitive) -> bool
      2. a string for which `str.isnumeric()` holds -> int
      3. anything `float()` accepts -> float
      4. otherwise the string is returned unchanged

    :param arg: str
        The raw value to interpret
    :return: bool | int | float | str
    """
    lowered = arg.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    if arg.isnumeric():
        # `isnumeric` is also true for characters such as "²" or "½" that
        # `int` cannot parse; fall through instead of raising for those.
        try:
            return int(arg)
        except ValueError:
            pass
    try:
        return float(arg)
    except ValueError:
        return arg
haileyschoelkopf's avatar
haileyschoelkopf committed
62
63


Jason Phang's avatar
gpt3  
Jason Phang committed
64
65
66
67
68
69
def simple_parse_args_string(args_string):
    """
    Parses something like
        args1=val1,arg2=val2
    Into a dictionary

    Surrounding whitespace is stripped and empty items (e.g. from a
    trailing comma) are skipped. Each item is split on its FIRST "=" only,
    so a value may itself contain "=". Values are converted with
    `handle_arg_string`; keys stay strings.

    Raises ValueError if a non-empty item contains no "=".
    """
    args_string = args_string.strip()
    if not args_string:
        return {}

    args_dict = {}
    for arg in args_string.split(","):
        if not arg:
            continue
        # maxsplit=1 keeps "key=a=b" as ("key", "a=b") instead of failing
        # to unpack three pieces into two names.
        key, value = arg.split("=", 1)
        args_dict[key] = handle_arg_string(value)
    return args_dict
Leo Gao's avatar
Leo Gao committed
78

Fabrizio Milo's avatar
Fabrizio Milo committed
79

Leo Gao's avatar
Leo Gao committed
80
81
def join_iters(iters):
    """Lazily yield the items of each iterable in `iters`, one iterable after another."""
    for sub_iterable in iters:
        yield from sub_iterable
Leo Gao's avatar
Leo Gao committed
83
84


85
86
87
88
89
def group(arr, fn):
    """Bucket the items of `arr` by the key `fn(item)`.

    Returns a list of lists: one list per distinct key, ordered by the
    first appearance of that key, with items in their input order.
    Keys must be hashable.
    """
    buckets = {}
    for item in arr:
        buckets.setdefault(fn(item), []).append(item)
    return list(buckets.values())

Fabrizio Milo's avatar
Fabrizio Milo committed
93

gakada's avatar
gakada committed
94
95
96
# Returns a sorted, de-duplicated list of every value of source_list that
# matches at least one of the fnmatch-style patterns. A single pattern may
# be passed as a bare string.
def pattern_match(patterns, source_list):
    if isinstance(patterns, str):
        patterns = [patterns]

    matched = {
        name for pattern in patterns for name in fnmatch.filter(source_list, pattern)
    }
    return sorted(matched)


Leo Gao's avatar
Leo Gao committed
107
108
109
110
def general_detokenize(string):
    """Undo common tokenizer spacing artifacts in `string`.

    Applied in order: re-attach " n't", drop the space before ")" and
    after "(", drop spaces adjacent to double quotes, then drop the space
    before an apostrophe, period, or comma.
    """
    literal_fixes = (
        (" n't", "n't"),
        (" )", ")"),
        ("( ", "("),
        ('" ', '"'),
        (' "', '"'),
    )
    for old, new in literal_fixes:
        string = string.replace(old, new)
    return re.sub(r" (['.,])", r"\1", string)


Jason Phang's avatar
Jason Phang committed
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
    """
    Cut `token_list` into consecutive (input_tokens, pred_tokens) windows so
    that every token is predicted exactly once.

    - The first window predicts as many tokens as fit (up to max_seq_len),
      with `prefix_token` prepended to its input.
    - Each later window predicts up to `max_seq_len - context_len + 1` new
      tokens and takes as input the tokens immediately preceding the
      window's last predicted token (a slice spanning max_seq_len positions).

    :param token_list: list
        List of tokens to be PREDICTED
    :param prefix_token: token
        Dummy token like <eos> so the first token has something to condition on
    :param max_seq_len: int
        max_seq_len of model (or max_seq_len we want to use)
    :param context_len: int
        Amount of desired token context for prediction. Needs to be at least 1
        and at most max_seq_len.
    :return: generator
        Generator of tuples
            (input_tokens, pred_tokens)
        Note: Score only the last len(pred_tokens) logits of the LM
    """
    assert 1 <= context_len <= max_seq_len
    if not token_list:
        return

    total = len(token_list)
    # +1 offset, going from input->preds
    stride = max_seq_len - context_len + 1

    # First window: predict everything that fits, conditioned on the prefix.
    head_len = min(max_seq_len, total)
    yield ([prefix_token] + token_list[: head_len - 1], token_list[:head_len])

    # Remaining windows advance by `stride`; the last one may be shorter.
    for start in range(head_len, total, stride):
        end = min(start + stride, total)
        yield (
            token_list[end - max_seq_len - 1 : end - 1],
            token_list[start:end],
        )

Fabrizio Milo's avatar
Fabrizio Milo committed
157

Leo Gao's avatar
Leo Gao committed
158
def make_disjoint_window(pair):
    """Trim the context of a (context, continuation) pair from get_rolling_token_windows so it no longer overlaps the continuation.

    The last `len(continuation) - 1` context tokens are dropped; the
    continuation is returned untouched.
    """
    context, continuation = pair
    overlap = len(continuation) - 1
    return context[: len(context) - overlap], continuation
Fabrizio Milo's avatar
Fabrizio Milo committed
162

Jason Phang's avatar
Jason Phang committed
163

164
class Reorderer:
    """Sort a list by a key function and later map per-item results back
    to the original positions."""

    def __init__(self, arr: List[Any], fn: Callable) -> None:
        """Reorder an array according to some function

        Args:
            arr (List[Any]): The initial array
            fn (Callable[[Any], Any]): A function to determine the priority of elements
        """
        self.size = len(arr)
        # Every entry is ([original_index], element). Each element keeps its
        # OWN value: pairing indices with the first element of their
        # key-group would silently replace distinct items that merely share
        # a sort key.
        # TODO: overhaul reorderer. The one-index list is kept only so that
        # `get_original` can keep iterating over `inds`.
        indexed = [([idx], item) for idx, item in enumerate(arr)]
        # list.sort is stable, so items with equal keys stay in input order.
        indexed.sort(key=lambda entry: fn(entry[1]))

        self.arr = indexed

    def get_reordered(self):
        """Gets the reordered array

        Returns:
            List[Any]: The reordered array
        """
        return [entry[1] for entry in self.arr]

    def get_original(self, newarr):
        """Restores the original order of a new array based on the old array's order

        Args:
            newarr (List[Any]): The array to be restored; item i must
                correspond to item i of `get_reordered()`

        Returns:
            List[Any]: The array restored to the original order

        Raises:
            AssertionError: if `newarr` does not cover every original position
        """
        res = [None] * self.size
        cov = [False] * self.size

        for (inds, _), v in zip(self.arr, newarr):
            for ind in inds:
                res[ind] = v
                cov[ind] = True

        # Every original slot must have received a value.
        assert all(cov)

        return res

Fabrizio Milo's avatar
Fabrizio Milo committed
211

Ethan Smith's avatar
Ethan Smith committed
212
def make_table(result_dict, column: str = "results"):
    """Generate table of results.

    Args:
        result_dict: Mapping that is read at `result_dict[column]`
            (row name -> {"<metric>,<filter>": value, ...}),
            `result_dict["versions"]` and `result_dict["n-shot"]` (both
            keyed by row name). An optional "alias" entry in a row's dict
            replaces the row name in the table. `result_dict` is not
            modified.
        column: "results" (first header "Tasks") or "groups" (first header
            "Groups").

    Returns:
        str: The table rendered as Markdown.

    Raises:
        ValueError: if `column` is neither "results" nor "groups".
    """
    if column == "results":
        column_name = "Tasks"
    elif column == "groups":
        column_name = "Groups"
    else:
        raise ValueError(
            f"column must be 'results' or 'groups', got {column!r}"
        )

    from pytablewriter import LatexTableWriter, MarkdownTableWriter

    all_headers = [
        column_name,
        "Version",
        "Filter",
        "n-shot",
        "Metric",
        "Value",
        "",
        "Stderr",
    ]

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = all_headers
    latex_writer.headers = all_headers

    values = []

    for k, dic in result_dict[column].items():
        version = result_dict["versions"][k]
        n = str(result_dict["n-shot"][k])

        # Read the alias without popping it, so the caller's dict is left
        # intact and a second call renders the same table.
        k = dic.get("alias", k)

        for mf, v in dic.items():
            if mf == "alias":
                # Display name, not a metric.
                continue
            m, _, f = mf.partition(",")
            if m.endswith("_stderr"):
                # Stderr values are emitted alongside their metric below.
                continue

            stderr_key = m + "_stderr" + "," + f
            if stderr_key in dic:
                se = dic[stderr_key]
                if se != "N/A":
                    se = "%.4f" % se
                values.append([k, version, f, n, m, "%.4f" % v, "±", se])
            else:
                values.append([k, version, f, n, m, "%.4f" % v, "", ""])
            # Only the first row of each task shows its name and version.
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()


269
270
def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`) to the
    wrapped function, `fn`.

    The wrapped function is always called with the arguments it was given;
    the decorator only prints a warning. For a plain function the warning
    fires whenever any positional argument is passed. When
    `inspect.ismethod(fn)` is true, it fires whenever the number of
    positional arguments is not exactly 1.
    """
    # The parentheses matter: without them the expression parses as
    # `(len(args) != 1) if ismethod(fn) else 0`, which never warns for
    # plain functions.
    allowed_positional = 1 if inspect.ismethod(fn) else 0

    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        if len(args) != allowed_positional:
            print(
                f"WARNING: using {fn.__name__} with positional arguments is "
                "deprecated and will be disallowed in a future version of "
                "lm-evaluation-harness!"
            )
        return fn(*args, **kwargs)

    return _wrapper
Stephen Hogg's avatar
Stephen Hogg committed
286

Fabrizio Milo's avatar
Fabrizio Milo committed
287

Stephen Hogg's avatar
Stephen Hogg committed
288
@positional_deprecated
def find_test_root(start_path: Path) -> Path:
    """
    Search upward in the directory tree to a maximum of three layers
    to find and return the package root (containing the 'tests' folder)

    A directory qualifies when it contains `tests/test_version_stable.py`.
    The resolved `start_path` itself is the first candidate, followed by
    its parents.

    Raises FileNotFoundError if none of the candidates qualifies.
    """
    cur_path = start_path.resolve()
    max_layers = 3
    for _ in range(max_layers):
        if (cur_path / "tests" / "test_version_stable.py").exists():
            return cur_path
        cur_path = cur_path.parent.resolve()
    # Single f-string: the previous two-part concatenation dropped the
    # space between "upwards" and "of".
    raise FileNotFoundError(
        f"Unable to find package root within {max_layers} layers upwards of {start_path}"
    )

Stephen Hogg's avatar
Stephen Hogg committed
305
306

@positional_deprecated
def run_task_tests(task_list: List[str]):
    """
    Locate the package root and run `tests/test_version_stable.py` through
    pytest, restricted (via `-k`) to the given task names.

    The package root is appended to `sys.path` before pytest is invoked.
    Raises ValueError when pytest returns a non-zero (truthy) exit code.
    """
    import pytest

    root = find_test_root(start_path=Path(__file__))
    # pytest's -k expression: match any of the requested tasks.
    selector = " or ".join(task_list)
    pytest_args = [
        f"{root}/tests/test_version_stable.py",
        f"--rootdir={root}",
        "-k",
        f"{selector}",
    ]
    sys.path.append(str(root))

    exit_code = pytest.main(pytest_args)
    if not exit_code:
        return
    raise ValueError(
        f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {exit_code}"
    )
327
328


329
330
331
332
def ignore_constructor(loader, node):
    """YAML constructor that returns the raw `node` unchanged.

    `loader` is accepted to satisfy the constructor signature but is unused.
    """
    return node


lintangsutawika's avatar
lintangsutawika committed
333
334
335
336
def import_function(loader, node):
    """YAML constructor that resolves a dotted "module.function" scalar to a callable.

    Everything before the last dot names a module file (`<module>.py`),
    looked up in the directory of the YAML file being loaded (taken from
    `loader.name`). The module is executed and the attribute named after
    the last dot is returned.
    """
    dotted_name = loader.construct_scalar(node)
    config_dir = os.path.dirname(loader.name)

    # Split on the last dot: left part is the module, right part the attribute.
    module_name, _, function_name = dotted_name.rpartition(".")
    module_file = os.path.normpath(os.path.join(config_dir, f"{module_name}.py"))

    spec = importlib.util.spec_from_file_location(module_name, module_file)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    return getattr(module, function_name)

lintangsutawika's avatar
lintangsutawika committed
349

350
351
352
353
354
def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"):
    """Load a YAML config and merge in the files listed under its `include` key.

    Args:
        yaml_path: Path of the YAML file. Read only when `yaml_config` is
            None; also used to derive `yaml_dir` when that is not given.
        yaml_config: An already-parsed config dict. Its "include" key, if
            present, is removed from it.
        yaml_dir: Directory against which include paths that are not
            existing files are resolved. Defaults to the directory of
            `yaml_path`.
        mode: "full" resolves `!function` tags with `import_function`;
            "simple" leaves them as raw nodes via `ignore_constructor`.

    Returns:
        dict: The config. With includes, the included files are merged
        with earlier-listed includes overriding later-listed ones, and the
        including file overriding all of them.

    Raises:
        ValueError: if `mode` is unknown, or if includes must be resolved
            but neither `yaml_dir` nor `yaml_path` was given.
    """
    if mode == "simple":
        constructor_fn = ignore_constructor
    elif mode == "full":
        constructor_fn = import_function
    else:
        raise ValueError(f"mode must be 'simple' or 'full', got {mode!r}")

    # Add the `!function` constructor to the YAML loader.
    # NOTE(review): this registers on PyYAML's loader classes, so the most
    # recent `mode` also applies to any later yaml load in the process.
    yaml.add_constructor("!function", constructor_fn)
    if yaml_config is None:
        with open(yaml_path, "rb") as file:
            # NOTE(review): full_load is meant for trusted config files only.
            yaml_config = yaml.full_load(file)

    if "include" not in yaml_config:
        return yaml_config

    # A base directory is only needed once there are includes to resolve.
    if yaml_dir is None:
        if yaml_path is None:
            raise ValueError(
                "yaml_dir or yaml_path is required to resolve 'include' entries"
            )
        yaml_dir = os.path.dirname(yaml_path)

    include_path = yaml_config.pop("include")
    if isinstance(include_path, str):
        include_path = [include_path]

    final_yaml_config = {}
    # Load from the last one first, so earlier entries win on conflicts.
    # `reversed` leaves the list object itself untouched.
    for path in reversed(include_path):
        # Assumes that path is a full path.
        # If not found, assume the included yaml
        # is in the same dir as the original yaml
        if not os.path.isfile(path):
            path = os.path.join(yaml_dir, path)

        # Failures while loading an include propagate to the caller.
        final_yaml_config.update(load_yaml_config(yaml_path=path, mode=mode))

    final_yaml_config.update(yaml_config)
    return final_yaml_config
lintangsutawika's avatar
lintangsutawika committed
394
395


Ethan Smith's avatar
Ethan Smith committed
396
def regex_replace(string, pattern, repl, count: int = 0):
    """Jinja filter wrapping regex substitution.

    Replaces matches of `pattern` in `string` with `repl`; `count` limits
    the number of replacements (0 means replace all).
    """
    compiled = re.compile(pattern)
    return compiled.sub(repl, string, count=count)
lintangsutawika's avatar
lintangsutawika committed
399

lintangsutawika's avatar
lintangsutawika committed
400

401
# Shared Jinja2 environment used to compile templates from strings.
# StrictUndefined is passed so that referencing an undefined variable is an
# error rather than rendering as empty text.
env = Environment(loader=BaseLoader, undefined=StrictUndefined)
402
# Make `regex_replace` available inside templates as the filter of the same name.
env.filters["regex_replace"] = regex_replace
403
404


baberabb's avatar
baberabb committed
405
def apply_template(template: str, doc: dict) -> str:
    """Render the Jinja `template` string with the entries of `doc` as template variables."""
    return env.from_string(template).render(**doc)
408
409


410
411
412
413
def create_iterator(raw_iterator, rank, world_size, limit=None):
    """
    Build a (potentially) sliced and limited iterator over a raw document
    iterator: start at index `rank`, take every `world_size`-th item, and
    stop before index `limit` (no stop when `limit` is None). Used for
    splitting data among ranks in a multi-GPU setting or for pulling only a
    sample of documents.
    """
    start, stop, step = rank, limit, world_size
    return islice(raw_iterator, start, stop, step)
417
418


haileyschoelkopf's avatar
haileyschoelkopf committed
419
# Multi-token stopping criteria
baberabb's avatar
baberabb committed
420
421
422


# from more_itertools