import collections
import fnmatch
import functools
import gc
import importlib.util
import inspect
import logging
import os
import pathlib
import re
import subprocess
import sys
import time
from functools import wraps
from itertools import islice
from typing import (
    Any,
    Callable,
    Iterable,
    Iterator,
    List,
    Literal,
    Optional,
    Tuple,
    Type,
    Union,
)

import torch
import transformers
import numpy as np
import yaml
from jinja2 import BaseLoader, Environment, StrictUndefined


logging.basicConfig(
    format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)
eval_logger = logging.getLogger("lm-eval")

SPACING = " " * 47


def escaped_split(text, sep_char, maxsplit=-1):
    """Split text into a list on occurrences of the given separation
    character `sep_char`. The separation character may be escaped by a
    backslash to avoid splitting at that location.

    The separation character must be a string of size 1.

    If `maxsplit` is given, at most `maxsplit` splits are done (thus,
    the list will have at most `maxsplit + 1` elements). If `maxsplit`
    is not specified or less than 0, then there is no limit on the
    number of splits (all possible splits are made).
    """
    assert (
        len(sep_char) == 1
    ), "separation string must be a single character for escaped splitting"

    if maxsplit == 0:
        # return a list for a consistent return type
        return [text]
    maxsplit = max(0, maxsplit)

    # escape the separator so regex-special characters split literally
    return re.split(r"(?<!\\)" + re.escape(sep_char), text, maxsplit=maxsplit)
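
# Illustrative examples (not part of the API): a backslash-escaped separator
# is not split on, and `maxsplit` caps the number of splits.
#
#   escaped_split(r"a:b\:c:d", ":")           # -> ["a", "b\\:c", "d"]
#   escaped_split("a:b:c", ":", maxsplit=1)   # -> ["a", "b:c"]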


def handle_arg_string(arg):
    if arg.lower() == "true":
        return True
    elif arg.lower() == "false":
        return False
    elif arg.isnumeric():
        return int(arg)
    try:
        return float(arg)
    except ValueError:
        return arg


def simple_parse_args_string(args_string):
    """
    Parses something like
        args1=val1,arg2=val2
    into a dictionary.
    """
    args_string = args_string.strip()
    if not args_string:
        return {}
    arg_list = [arg for arg in args_string.split(",") if arg]
    args_dict = {
        k: handle_arg_string(v) for k, v in [arg.split("=") for arg in arg_list]
    }
    return args_dict
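
# For example (hypothetical argument string; `handle_arg_string` coerces
# booleans and integers, and leaves everything else a string):
#
#   simple_parse_args_string("pretrained=EleutherAI/pythia-160m,batch_size=8,trust_remote_code=True")
#   # -> {"pretrained": "EleutherAI/pythia-160m", "batch_size": 8, "trust_remote_code": True}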


def join_iters(iters):
    for it in iters:
        yield from it


def chunks(_iter, n: int = 0, fn=None):
    """
    Divides an iterable into chunks of specified size or based on a given function.
    Useful for batching.

    Parameters:
    - _iter: The input iterable to be divided into chunks.
    - n: An integer representing the size of each chunk. Default is 0.
    - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.

    Returns:
    An iterator that yields chunks of the input iterable.

    Example usage:
    ```
    data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    for chunk in chunks(data, 3):
        print(chunk)
    ```
    Output:
    ```
    [1, 2, 3]
    [4, 5, 6]
    [7, 8, 9]
    [10]
    ```
    """
    arr = []
    for i, x in enumerate(_iter):
        arr.append(x)
        if len(arr) == (fn(i, _iter) if fn else n):
            yield arr
            arr = []

    if arr:
        yield arr


def group(arr, fn):
    res = collections.defaultdict(list)

    for ob in arr:
        res[fn(ob)].append(ob)

    return list(res.values())
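
# For instance (illustrative only), grouping integers by parity preserves
# first-seen key order:
#
#   group([1, 2, 3, 4], lambda x: x % 2)   # -> [[1, 3], [2, 4]]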


class MultiChoice:
    def __init__(self, choices) -> None:
        self.choices = choices

    # Simple wildcard support (linux filename patterns)
    def __contains__(self, values) -> bool:
        for value in values.split(","):
            if len(fnmatch.filter(self.choices, value)) == 0:
                eval_logger.info("Available tasks to choose:")
                for choice in self.choices:
                    eval_logger.info(f"  - {choice}")
                raise ValueError("'{}' is not in task list".format(value))
        return True

    def __iter__(self) -> Iterator:
        for choice in self.choices:
            yield choice
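
# Note the unconventional membership test: `in` validates a comma-separated,
# wildcard-capable task string, raising on unknown names instead of returning
# False. Illustrative only:
#
#   tasks = MultiChoice(["arc_easy", "arc_challenge", "hellaswag"])
#   "arc_*,hellaswag" in tasks   # -> True
#   "winogrande" in tasks        # raises ValueError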


# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
    if isinstance(patterns, str):
        patterns = [patterns]

    task_names = set()
    for pattern in patterns:
        for matching in fnmatch.filter(source_list, pattern):
            task_names.add(matching)
    return sorted(list(task_names))
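
# e.g. (illustrative): pattern_match("hellaswag*", ["hellaswag", "hellaswag_de", "arc_easy"])
# returns ["hellaswag", "hellaswag_de"].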


def softmax(x):
    """Compute softmax values for each sets of scores in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()
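
# Quick sanity check (illustrative): softmax(np.array([0.0, 0.0])) -> array([0.5, 0.5]).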


def general_detokenize(string):
    string = string.replace(" n't", "n't")
    string = string.replace(" )", ")")
    string = string.replace("( ", "(")
    string = string.replace('" ', '"')
    string = string.replace(' "', '"')
    string = re.sub(r" (['.,])", r"\1", string)
    return string
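
# For example (illustrative):
#   general_detokenize("it 's a test , ( really )")   # -> "it's a test, (really)"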


def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
    """
    - context_len allows for a rolling window context, allowing each prediction window to potentially
      condition on some context

    :param token_list: list
        List of tokens to be PREDICTED
    :param max_seq_len: int
        max_seq_len of model (or max_seq_len we want to use)
    :param context_len: int
        Amount of desired token context for prediction. Needs to be at least 1.
    :param prefix_token: token
        Dummy token like <eos> so the first token has something to condition on
    :return: generator
        Generator of tuples
            (input_tokens, pred_tokens)
        Note: Score only the last len(pred_tokens) logits of the LM
    """
    assert 1 <= context_len <= max_seq_len
    if not token_list:
        return
    # +1 offset, going from input->preds
    pred_len = max_seq_len - context_len + 1
    predicted = 0

    # Special handling for first window: predict all tokens
    first_seq_len = min(max_seq_len, len(token_list))
    yield ([prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len])
    predicted += first_seq_len

    while predicted < len(token_list):
        window_pred_len = min(len(token_list) - predicted, pred_len)
        window_end = predicted + window_pred_len

        yield (
            token_list[window_end - max_seq_len - 1 : window_end - 1],
            token_list[window_end - window_pred_len : window_end],
        )
        predicted += window_pred_len


def make_disjoint_window(pair):
    """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
    a, b = pair
    return a[: len(a) - (len(b) - 1)], b
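
# Worked example (illustrative) with prefix_token=0, max_seq_len=4, context_len=2:
#
#   list(get_rolling_token_windows([1, 2, 3, 4, 5, 6, 7, 8], 0, 4, 2))
#   # -> [([0, 1, 2, 3], [1, 2, 3, 4]),   # first window predicts every token
#   #     ([3, 4, 5, 6], [5, 6, 7]),      # later windows keep >=2 tokens of context
#   #     ([4, 5, 6, 7], [8])]
#
#   make_disjoint_window(([3, 4, 5, 6], [5, 6, 7]))   # -> ([3, 4], [5, 6, 7])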


class Reorderer:
    def __init__(self, arr: List[Any], fn: Callable) -> None:
        """Reorder an array according to some function

        Args:
            arr (List[Any]): The initial array
            fn (Callable[[Any], Any]): A function to determine the priority of elements
        """
        self.size = len(arr)
        arr = list(enumerate(arr))
        arr = group(arr, lambda x: fn(x[1]))
        # arr = [([y[0] for y in x], x[0][1]) for x in arr]
        # TODO: overhaul reorderer. It currently groups requests by content, but we don't want this
        arr = [([y[0]], x[0][1]) for x in arr for y in x]
        arr.sort(key=lambda x: fn(x[1]))

        self.arr = arr

    def get_reordered(self):
        """Gets the reordered array

        Returns:
            List[Any]: The reordered array
        """
        return [x[1] for x in self.arr]

    def get_original(self, newarr):
        """Restores the original order of a new array based on the old array's order

        Args:
            newarr (List[Any]): The array to be restored

        Returns:
            List[Any]: The array restored to the original order
        """
        res = [None] * self.size
        cov = [False] * self.size

        for (inds, _), v in zip(self.arr, newarr):
            for ind in inds:
                res[ind] = v
                cov[ind] = True

        assert all(cov)

        return res


class Grouper:
    """
    Takes an array `arr` and function `fn` and returns a dictionary
    with keys `fn(ob)` for each `ob` in `arr`, and with values `self.arr[key]` a list
    of all objects in `arr` satisfying `key == fn(ob)`.
    """

    def __init__(self, arr, fn) -> None:
        # self.orig_arr = arr
        self.size = len(arr)
        arr = list(enumerate(arr))

        def group_return_dict(arr, fn):
            res = collections.defaultdict(list)

            for ob in arr:
                res[fn(ob)].append(ob)
            return res

        arr = group_return_dict(arr, lambda x: fn(x[1]))

        # self.arr has format Dict[Tuple[int, <entry from orig. arr>]]
        self.arr = arr
        self._grouped = None

    def get_grouped(self):
        # return the contents but not indices for our grouped dict.
        if self._grouped:
            return self._grouped
        grouped = {}
        for key in self.arr.keys():
            # drop the index from each element of self.arr
            grouped[key] = [y[1] for y in self.arr[key]]
        self._grouped = grouped
        return grouped

    def get_original(self, grouped_dict):
        # take in a grouped dictionary with e.g. results for each key listed
        # in the same order as the instances in `self.arr`, and
        # return the results in the same (single list) order as `self.orig_arr`.
        res = [None] * self.size
        cov = [False] * self.size
        # orig = [None] * self.size

        assert grouped_dict.keys() == self.arr.keys()

        for key in grouped_dict.keys():
            for (ind, _), v in zip(self.arr[key], grouped_dict[key]):
                res[ind] = v
                cov[ind] = True
                # orig[ind] = _

        assert all(cov)
        # assert orig == self.orig_arr

        return res
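
# Illustrative usage: Grouper groups by `fn` and can invert the grouping.
#
#   g = Grouper(["a", "bb", "cc", "d"], len)
#   g.get_grouped()                                    # -> {1: ["a", "d"], 2: ["bb", "cc"]}
#   g.get_original({1: ["A", "D"], 2: ["BB", "CC"]})   # -> ["A", "BB", "CC", "D"]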


def make_table(result_dict, column: str = "results"):
    """Generate table of results."""
    from pytablewriter import LatexTableWriter, MarkdownTableWriter

    if column == "results":
        column_name = "Tasks"
    elif column == "groups":
        column_name = "Groups"

    all_headers = [
        column_name,
        "Version",
        "Filter",
        "n-shot",
        "Metric",
        "Value",
        "",
        "Stderr",
    ]

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = all_headers
    latex_writer.headers = all_headers

    values = []

    for k, dic in result_dict[column].items():
        version = result_dict["versions"][k]
        n = str(result_dict["n-shot"][k])

        if "alias" in dic:
            k = dic.pop("alias")

        for mf, v in dic.items():
            m, _, f = mf.partition(",")

            if m.endswith("_stderr"):
                continue

            if m + "_stderr" + "," + f in dic:
                se = dic[m + "_stderr" + "," + f]
                if se != "N/A":
                    se = "%.4f" % se
                values.append([k, version, f, n, m, "%.4f" % v, "±", se])
            else:
                values.append([k, version, f, n, m, "%.4f" % v, "", ""])
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()
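
# The expected `result_dict` shape, sketched with hypothetical values:
#
#   make_table({
#       "results": {"arc_easy": {"alias": "arc_easy", "acc,none": 0.75, "acc_stderr,none": 0.01}},
#       "versions": {"arc_easy": 1.0},
#       "n-shot": {"arc_easy": 0},
#   })
#   # -> a Markdown table with the single row: arc_easy | 1.0 | none | 0 | acc | 0.7500 | ± | 0.0100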


def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`) to the
    wrapped function, `fn`.
    """

    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        # parenthesized so the comparison is against 1 for bound methods
        # (which receive `self` positionally) and 0 for plain functions
        if len(args) != (1 if inspect.ismethod(fn) else 0):
            print(
                f"WARNING: using {fn.__name__} with positional arguments is "
                "deprecated and will be disallowed in a future version of "
                "lm-evaluation-harness!"
            )
        return fn(*args, **kwargs)

    return _wrapper


@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
    """
    Search upward in the directory tree to a maximum of three layers
    to find and return the package root (containing the 'tests' folder)
    """
    cur_path = start_path.resolve()
    max_layers = 3
    for _ in range(max_layers):
        if (cur_path / "tests" / "test_version_stable.py").exists():
            return cur_path
        else:
            cur_path = cur_path.parent.resolve()
    raise FileNotFoundError(
        f"Unable to find package root within {max_layers} upwards of {start_path}"
    )


@positional_deprecated
def run_task_tests(task_list: List[str]):
    """
    Find the package root and run the tests for the given tasks
    """
    import pytest

    package_root = find_test_root(start_path=pathlib.Path(__file__))
    task_string = " or ".join(task_list)
    args = [
        f"{package_root}/tests/test_version_stable.py",
        f"--rootdir={package_root}",
        "-k",
        f"{task_string}",
    ]
    sys.path.append(str(package_root))
    pytest_return_val = pytest.main(args)
    if pytest_return_val:
        raise ValueError(
            f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
        )


def get_git_commit_hash():
    """
    Gets the git commit hash of your current repo (if it exists).
    Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42
    """
    try:
        git_hash = subprocess.check_output(["git", "describe", "--always"]).strip()
        git_hash = git_hash.decode()
    # `except A or B:` only catches A; a tuple catches both exceptions
    except (subprocess.CalledProcessError, FileNotFoundError):
        # FileNotFoundError occurs when git not installed on system
        git_hash = None
    return git_hash


def ignore_constructor(loader, node):
    return node


def import_function(loader, node):
    function_name = loader.construct_scalar(node)
    yaml_path = os.path.dirname(loader.name)

    # starred unpacking always yields a list of module-path components
    *module_name, function_name = function_name.split(".")
    module_name = ".".join(module_name)
    module_path = os.path.normpath(os.path.join(yaml_path, "{}.py".format(module_name)))

    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    function = getattr(module, function_name)
    return function


def load_yaml_config(mode="simple", yaml_path=None, yaml_config=None, yaml_dir=None):
    if mode == "simple":
        constructor_fn = ignore_constructor
    elif mode == "full":
        constructor_fn = import_function

    # Add the import_function constructor to the YAML loader
    yaml.add_constructor("!function", constructor_fn)
    if yaml_config is None:
        with open(yaml_path, "rb") as file:
            yaml_config = yaml.full_load(file)

    if yaml_dir is None:
        yaml_dir = os.path.dirname(yaml_path)

    assert yaml_dir is not None

    if "include" in yaml_config:
        include_path = yaml_config["include"]
        del yaml_config["include"]

        if isinstance(include_path, str):
            include_path = [include_path]

        # Load from the last one first
        include_path.reverse()
        final_yaml_config = {}
        for path in include_path:
            # Assumes that path is a full path.
            # If not found, assume the included yaml
            # is in the same dir as the original yaml
            if not os.path.isfile(path):
                path = os.path.join(yaml_dir, path)

            try:
                included_yaml_config = load_yaml_config(mode=mode, yaml_path=path)
                final_yaml_config.update(included_yaml_config)
            except Exception as ex:
                # propagate failures to load an included config
                raise ex

        final_yaml_config.update(yaml_config)
        return final_yaml_config
    return yaml_config
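
# Sketch of the `include` mechanism (hypothetical files): if `task.yaml` contains
# `include: base.yaml` and `num_fewshot: 5`, while `base.yaml` sets
# `num_fewshot: 0`, then load_yaml_config(yaml_path="task.yaml") returns the
# merged dict with the including file winning, i.e. num_fewshot == 5.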


def regex_replace(string, pattern, repl, count: int = 0):
    """Implements the `re.sub` function as a custom Jinja filter."""
    return re.sub(pattern, repl, string, count=count)


env = Environment(loader=BaseLoader, undefined=StrictUndefined)
env.filters["regex_replace"] = regex_replace


def apply_template(template: str, doc: dict) -> str:
    rtemplate = env.from_string(template)
    return rtemplate.render(**doc)
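
# e.g. (illustrative): apply_template("Q: {{question}}\nA:", {"question": "2+2?"})
# renders "Q: 2+2?\nA:". StrictUndefined makes missing fields raise instead of
# rendering as empty strings.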


def create_iterator(raw_iterator, rank, world_size, limit=None):
    """
    Method for creating a (potentially) sliced and limited
    iterator from a raw document iterator. Used for splitting data
    among ranks in multigpu setting or only pulling a sample of documents
    """
    return islice(raw_iterator, rank, limit, world_size)
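
# e.g. (illustrative): rank 0 of a 2-process world takes every other document:
#   list(create_iterator(iter(range(10)), rank=0, world_size=2))   # -> [0, 2, 4, 6, 8]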


def pad_and_concat(
    max_length: int,
    tensors: List[torch.Tensor],
    padding_side: Literal["right", "left"] = "right",
):
    """
    Method for padding a list of tensors given the maximum tensor
    length in the batch. Used for batching inputs and continuations in
    seq2seq models.
    """
    assert (
        padding_side == "left" or padding_side == "right"
    ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"

    for i, tensor in enumerate(tensors):
        if len(tensor.shape) == 2:
            tensor = tensor.squeeze(0)  # squeeze, in case passed [1, seq] size
        tensor_len = tensor.shape[0]
        if tensor_len < max_length:
            if padding_side == "right":
                # right-pad
                tensors[i] = torch.cat(
                    [
                        tensor,  # [seq]
                        torch.zeros(
                            max_length - tensor_len,
                            dtype=torch.long,
                            device=tensor.device,
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                ).unsqueeze(0)
            else:
                # left-pad
                tensors[i] = torch.cat(
                    [
                        torch.zeros(
                            max_length - tensor_len,
                            dtype=torch.long,
                            device=tensor.device,
                        ),  # [padding_length - seq]
                        tensor,  # [seq]
                    ],
                    dim=0,
                ).unsqueeze(0)
        else:
            tensors[i] = tensor.unsqueeze(0)

    return torch.cat(tensors, dim=0)
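
# e.g. (illustrative): right-padding two sequences to length 3 with zeros:
#   pad_and_concat(3, [torch.tensor([1, 2]), torch.tensor([1, 2, 3])])
#   # -> tensor([[1, 2, 0],
#   #            [1, 2, 3]])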


def clear_torch_cache() -> None:
    gc.collect()
    torch.cuda.empty_cache()


def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
    if isinstance(dtype, str) and dtype != "auto":
        # Convert `str` args torch dtype: `float16` -> `torch.float16`
        _torch_dtype = getattr(torch, dtype)
    else:
        _torch_dtype = dtype
    return _torch_dtype
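
# e.g. (illustrative): get_dtype("bfloat16") -> torch.bfloat16, while the
# string "auto" is passed through unchanged.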


# Multi-token stopping criteria
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Criteria to stop on the specified multi-token sequence."""

    def __init__(
        self,
        sequence: str,
        tokenizer: transformers.PreTrainedTokenizer,
        initial_decoder_input_length: int,
        batch_size: int,
    ) -> None:
        self.initial_decoder_input_length = initial_decoder_input_length
        self.done_tracker = [False] * batch_size
        self.sequence = sequence
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        # we look back for 2 more tokens than it takes to encode our stop sequence
        # because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']`
        # and we don't want to mistakenly not stop a generation because our
        # (string) stop sequence was output in a different tokenization

        # NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model,
        # and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized
        # Additionally, in lookback_ids_batch we should prevent ever looking back into the inputs as described.
        self.sequence_id_len = len(self.sequence_ids) + 2
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :]

        lookback_ids_batch = lookback_ids_batch[:, -self.sequence_id_len :]

        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)

        for i, done in enumerate(self.done_tracker):
            if not done:
                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
        return False not in self.done_tracker


def stop_sequences_criteria(
    tokenizer: transformers.PreTrainedTokenizer,
    stop_sequences: List[str],
    initial_decoder_input_length: int,
    batch_size: int,
) -> transformers.StoppingCriteriaList:
    return transformers.StoppingCriteriaList(
        [
            *[
                MultiTokenEOSCriteria(
                    sequence, tokenizer, initial_decoder_input_length, batch_size
                )
                for sequence in stop_sequences
            ],
        ]
    )
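
# Illustrative use with a HF causal LM (`model` and `tokenizer` are assumed to
# exist here, not defined in this module): stop once every sequence in the
# batch has emitted "\n\n".
#
#   context = tokenizer(["The list:\n"], return_tensors="pt").input_ids
#   criteria = stop_sequences_criteria(tokenizer, ["\n\n"], context.shape[1], context.shape[0])
#   model.generate(context, stopping_criteria=criteria, max_new_tokens=64)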


# from more_itertools
def divide(iterable, n) -> List[Iterator]:
    """Divide the elements from *iterable* into *n* parts, maintaining
    order.

        >>> group_1, group_2 = divide([1, 2, 3, 4, 5, 6], 2)
        >>> list(group_1)
        [1, 2, 3]
        >>> list(group_2)
        [4, 5, 6]

    If the length of *iterable* is not evenly divisible by *n*, then the
    length of the returned iterables will not be identical:

        >>> children = divide([1, 2, 3, 4, 5, 6, 7], 3)
        >>> [list(c) for c in children]
        [[1, 2, 3], [4, 5], [6, 7]]

    If the length of the iterable is smaller than n, then the last returned
    iterables will be empty:

        >>> children = divide([1, 2, 3], 5)
        >>> [list(c) for c in children]
        [[1], [2], [3], [], []]

    This function will exhaust the iterable before returning and may require
    significant storage. If order is not important, see :func:`distribute`,
    which does not first pull the iterable into memory.

    """
    if n < 1:
        raise ValueError("n must be at least 1")

    try:
        iterable[:0]
    except TypeError:
        seq = tuple(iterable)
    else:
        seq = iterable

    q, r = divmod(len(seq), n)

    ret = []
    stop = 0
    for i in range(1, n + 1):
        start = stop
        stop += q + 1 if i <= r else q
        ret.append(iter(seq[start:stop]))

    return ret


def retry_on_specific_exceptions(
    on_exceptions: List[Type[Exception]],
    max_retries: Optional[int] = None,
    backoff_time: float = 3.0,
    backoff_multiplier: float = 1.5,
    on_exception_callback: Optional[Callable[[Exception, float], Any]] = None,
):
    """Retry on an LLM Provider's rate limit error with exponential backoff
    For example, to use for OpenAI, do the following:
    ```
    from openai import RateLimitError

    # Recommend specifying max_retries to avoid infinite loops!
    @retry_on_specific_exceptions([RateLimitError], max_retries=3)
    def completion(...):
        # Wrap OpenAI completion function here
        ...
    ```
    """

    def decorator(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            sleep_time = backoff_time
            attempt = 0
            while max_retries is None or attempt < max_retries:
                try:
                    return func(*args, **kwargs)
                except tuple(on_exceptions) as e:
                    if on_exception_callback is not None:
                        on_exception_callback(e, sleep_time)
                    time.sleep(sleep_time)
                    sleep_time *= backoff_multiplier
                    attempt += 1

        return wrapper

    return decorator


class Collator:
    """
    A class for reordering and batching elements of an array.

    This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data.
    """

    def __init__(
        self,
        arr: List,
        sort_fn: Callable,
        group_fn: Callable = lambda x: x[1],
        grouping: bool = False,
    ) -> None:
        self.grouping = grouping
        self.fn = sort_fn
        self.group_fn = lambda x: group_fn(x[1])  # first index are enumerated indices
        self.reorder_indices: List = []
        self.size = len(arr)
        self.arr_with_indices: Iterable[Any] = tuple(enumerate(arr))  # [indices, (arr)]
        if self.grouping is True:
            self.group_by_index()

    def group_by_index(self) -> None:
        self.arr_with_indices = self.group(
            self.arr_with_indices, fn=self.group_fn, values=False
        )

    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
        """
        Generates and yields batches from the reordered array.

        Parameters:
        - n (int): The size of each batch. Defaults to 1.
        - batch_fn (Optional[Callable[[int, Iterable], int]]): A function to determine the size of each batch. Defaults to None.

        Yields:
        Iterator: An iterator over batches of reordered elements.
        """
        if self.grouping:
            for (
                key,
                values,
            ) in self.arr_with_indices.items():  # type: ignore
                values = self._reorder(values)
                batch = self.get_chunks(values, n=n, fn=batch_fn)
                yield from batch
        else:
            values = self._reorder(self.arr_with_indices)  # type: ignore
            batch = self.get_chunks(values, n=n, fn=batch_fn)
            yield from batch

    def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> List:
        """
        Reorders the elements in the array based on the sorting function.

        Parameters:
        - arr (Union[List, Tuple[Tuple[int, Any], ...]]): The array or iterable to be reordered.

        Yields:
        List: Yields reordered elements one by one.
        """
        arr = sorted(arr, key=lambda x: self.fn(x[1]))
        self.reorder_indices.extend([x[0] for x in arr])
        yield from [x[1] for x in arr]

    def get_original(self, newarr: List) -> List:
        """
        Restores the original order of elements from the reordered list.

        Parameters:
        - newarr (List): The reordered array.

        Returns:
        List: The array with elements restored to their original order.
        """
        res = [None] * self.size
        cov = [False] * self.size

        for ind, v in zip(self.reorder_indices, newarr):
            res[ind] = v
            cov[ind] = True

        assert all(cov)

        return res

    def __len__(self):
        return self.size

    @staticmethod
    def group(arr: Iterable, fn: Callable, values: bool = False) -> Iterable:
        """
        Groups elements of an iterable based on a provided function.

        Parameters:
        - arr (Iterable): The iterable to be grouped.
        - fn (Callable): The function to determine the grouping.
        - values (bool): If True, returns the values of the group. Defaults to False.

        Returns:
        Iterable: An iterable of grouped elements.
        """
        res = collections.defaultdict(list)
        for ob in arr:
            try:
                hashable_dict = tuple(
                    (
                        key,
                        tuple(value)
                        if isinstance(value, collections.abc.Iterable)
                        else value,
                    )
                    for key, value in sorted(fn(ob).items())
                )
                res[hashable_dict].append(ob)
            except TypeError:
                res[fn(ob)].append(ob)
        if not values:
            return res
        return res.values()

    @staticmethod
    def get_chunks(_iter, n: int = 0, fn=None):
        """
        Divides an iterable into chunks of specified size or based on a given function.
        Useful for batching

        Parameters:
        - iter: The input iterable to be divided into chunks.
        - n: An integer representing the size of each chunk. Default is 0.
        - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.

        Returns:
        An iterator that yields chunks of the input iterable.

        Example usage:
        ```
        data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        for chunk in chunks(data, 3):
            print(chunk)
        ```
        Output:
        ```
        [1, 2, 3]
        [4, 5, 6]
        [7, 8, 9]
        [10]
        ```
        """
        arr = []
        _iter = tuple(_iter)
        for i, x in enumerate(_iter):
            arr.append(x)
            if len(arr) == (fn(i, _iter) if fn else n):
                yield arr
                arr = []

        if arr:
            yield arr
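

# Illustrative end-to-end use of Collator (hypothetical data): sort requests
# longest-first for efficient batching, then restore the original order.
#
#   c = Collator(["bb", "a", "ccc"], sort_fn=lambda x: (-len(x), x))
#   batches = list(c.get_batched(n=2))   # -> [["ccc", "bb"], ["a"]]
#   flat = [r.upper() for b in batches for r in b]
#   c.get_original(flat)                 # -> ["BB", "A", "CCC"]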