import collections
import fnmatch
import functools
import gc
import importlib.util
import inspect
import logging
import os
import pathlib
import re
import subprocess
import sys
import time
from functools import wraps
from itertools import islice
from typing import (
    Any,
    Callable,
    Iterable,
    Iterator,
    List,
    Literal,
    Optional,
    Tuple,
    Type,
    Union,
)

import torch
import transformers
import numpy as np
import yaml
from jinja2 import BaseLoader, Environment, StrictUndefined


logging.basicConfig(
    format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)
eval_logger = logging.getLogger("lm-eval")
SPACING = " " * 47


def escaped_split(text, sep_char, maxsplit=-1):
    """Split text into a list on occurrences of the given separation
    character `sep_char`. The separation character may be escaped by a
    backslash to avoid splitting at that location.

    The separation character must be a string of size 1.

    If `maxsplit` is given, at most `maxsplit` splits are done (thus,
    the list will have at most `maxsplit + 1` elements). If `maxsplit`
    is not specified or less than 0, then there is no limit on the
    number of splits (all possible splits are made).
    """
    assert (
        len(sep_char) == 1
    ), "separation string must be a single character for escaped splitting"

    if maxsplit == 0:
        return [text]
    maxsplit = max(0, maxsplit)

    # escape the separator so regex metacharacters (e.g. ".") match literally
    return re.split(r"(?<!\\)" + re.escape(sep_char), text, maxsplit=maxsplit)
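
# Illustrative behavior (hypothetical input): an escaped separator survives the
# split, e.g. escaped_split(r"a:b\:c:d", ":") -> ["a", "b\\:c", "d"].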


def handle_arg_string(arg):
    if arg.lower() == "true":
        return True
    elif arg.lower() == "false":
        return False
    elif arg.isnumeric():
        return int(arg)
    try:
        return float(arg)
    except ValueError:
        return arg


def simple_parse_args_string(args_string):
    """
    Parses something like
        arg1=val1,arg2=val2
    Into a dictionary
    """
    args_string = args_string.strip()
    if not args_string:
        return {}
    arg_list = [arg for arg in args_string.split(",") if arg]
    args_dict = {
        # split on the first "=" only, so values may themselves contain "="
        k: handle_arg_string(v) for k, v in [arg.split("=", 1) for arg in arg_list]
    }
    return args_dict
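
# For instance (illustrative): simple_parse_args_string("device=cuda,batch_size=8,use_cache=True")
# -> {"device": "cuda", "batch_size": 8, "use_cache": True}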


def join_iters(iters):
    for iter in iters:
        yield from iter


def chunks(iter, n: int = 0, fn=None):
    """
    Divides an iterable into chunks of specified size or based on a given function.
    Useful for batching

    Parameters:
    - iter: The input iterable to be divided into chunks.
    - n: An integer representing the size of each chunk. Default is 0.
    - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.

    Returns:
    An iterator that yields chunks of the input iterable.

    Example usage:
    ```
    data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    for chunk in chunks(data, 3):
        print(chunk)
    ```
    Output:
    ```
    [1, 2, 3]
    [4, 5, 6]
    [7, 8, 9]
    [10]
    ```
    """
    arr = []
    for i, x in enumerate(iter):
        arr.append(x)
        if len(arr) == (fn(i, iter) if fn else n):
            yield arr
            arr = []

    if arr:
        yield arr

def group(arr, fn):
    res = collections.defaultdict(list)

    for ob in arr:
        res[fn(ob)].append(ob)
    return list(res.values())

class MultiChoice:
    def __init__(self, choices) -> None:
        self.choices = choices

    # Simple wildcard support (linux filename patterns)
    def __contains__(self, values) -> bool:
        for value in values.split(","):
            if len(fnmatch.filter(self.choices, value)) == 0:
                eval_logger.info("Available tasks to choose:")
                for choice in self.choices:
                    eval_logger.info(f"  - {choice}")
                raise ValueError("'{}' is not in task list".format(value))
        return True

    def __iter__(self) -> Iterator:
        for choice in self.choices:
            yield choice


# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
    if isinstance(patterns, str):
        patterns = [patterns]

    task_names = set()
    for pattern in patterns:
        for matching in fnmatch.filter(source_list, pattern):
            task_names.add(matching)
    return sorted(list(task_names))


def softmax(x):
    """Compute softmax values for each sets of scores in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()
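
# Worked example (illustrative): softmax(np.array([1.0, 2.0, 3.0]))
# ≈ [0.0900, 0.2447, 0.6652], which sums to 1.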


def general_detokenize(string):
    string = string.replace(" n't", "n't")
    string = string.replace(" )", ")")
    string = string.replace("( ", "(")
    string = string.replace('" ', '"')
    string = string.replace(' "', '"')
    string = re.sub(r" (['.,])", r"\1", string)
    return string
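
# e.g. (illustrative) general_detokenize("I do n't like it .") -> "I don't like it."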


def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
    """
    - context_len allows for a rolling window context, allowing each prediction window to potentially
      condition on some context

    :param token_list: list
        List of tokens to be PREDICTED
    :param max_seq_len: int
        max_seq_len of model (or max_seq_len we want to use)
    :param context_len: int
        Amount of desired token context for prediction. Needs to be at least 1.
    :param prefix_token: token
        Dummy token like <eos> so the first token has something to condition on
    :return: generator
        Generator of tuples
            (input_tokens, pred_tokens)
        Note: Score only the last len(pred_tokens) logits of the LM
    """
    assert 1 <= context_len <= max_seq_len
    if not token_list:
        return
    # +1 offset, going from input->preds
    pred_len = max_seq_len - context_len + 1
    predicted = 0

    # Special handling for first window: predict all tokens
    first_seq_len = min(max_seq_len, len(token_list))
    yield ([prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len])
    predicted += first_seq_len

    while predicted < len(token_list):
        window_pred_len = min(len(token_list) - predicted, pred_len)
        window_end = predicted + window_pred_len

        yield (
            token_list[window_end - max_seq_len - 1 : window_end - 1],
            token_list[window_end - window_pred_len : window_end],
        )
        predicted += window_pred_len
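
# Worked example (illustrative): with max_seq_len=4, context_len=2, prefix_token=0
# and token_list=[1, 2, 3, 4, 5, 6, 7], the generator yields
#   ([0, 1, 2, 3], [1, 2, 3, 4])  # first window scores every token it can
#   ([3, 4, 5, 6], [5, 6, 7])     # later windows keep >= context_len tokens of context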

def make_disjoint_window(pair):
    """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
    a, b = pair
    return a[: len(a) - (len(b) - 1)], b
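
# e.g. (illustrative) make_disjoint_window(([3, 4, 5, 6], [5, 6, 7]))
# -> ([3, 4], [5, 6, 7]): the context is trimmed so it no longer overlaps the continuation.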


class Reorderer:
    def __init__(self, arr: List[Any], fn: Callable) -> None:
        """Reorder an array according to some function

        Args:
            arr (List[Any]): The initial array
            fn (Callable[[Any], Any]): A function to determine the priority of elements
        """
        self.size = len(arr)
        arr = list(enumerate(arr))
        arr = group(arr, lambda x: fn(x[1]))
        # arr = [([y[0] for y in x], x[0][1]) for x in arr]
        # TODO: overhaul reorderer. It currently grouped requests by content but we don't want this
        arr = [([y[0]], x[0][1]) for x in arr for y in x]
        arr.sort(key=lambda x: fn(x[1]))

        self.arr = arr
    def get_reordered(self):
        """Gets the reordered array

        Returns:
            List[Any]: The reordered array
        """
        return [x[1] for x in self.arr]
    def get_original(self, newarr):
        """Restores the original order of a new array based on the old array's order

        Args:
            newarr (List[Any]): The array to be restored

        Returns:
            List[Any]: The array restored to the original order
        """
        res = [None] * self.size
        cov = [False] * self.size

        for (inds, _), v in zip(self.arr, newarr):
            for ind in inds:
                res[ind] = v
                cov[ind] = True
        assert all(cov)
        return res
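
# Round-trip sketch (illustrative; `score` is a hypothetical callable):
#   ro = Reorderer(reqs, lambda x: len(x))
#   outs = [score(x) for x in ro.get_reordered()]  # run in sorted order
#   outs = ro.get_original(outs)                   # back in the caller's order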

class Grouper:
    """
    takes an array `arr` and function `fn` and returns a dictionary
    with keys fn(ob) for each ob in `arr` and with values `self.arr[key]` a list of all
    objects in `arr` satisfying `key == fn(ob)`.
    """

    def __init__(self, arr, fn) -> None:
        # self.orig_arr = arr
        self.size = len(arr)
        arr = list(enumerate(arr))

        def group_return_dict(arr, fn):
            res = collections.defaultdict(list)

            for ob in arr:
                res[fn(ob)].append(ob)
            return res

        arr = group_return_dict(arr, lambda x: fn(x[1]))

        # self.arr has format Dict[Tuple[int, <entry from orig. arr>]]
        self.arr = arr
        self._grouped = None

    def get_grouped(self):
        # return the contents but not indices for our grouped dict.
        if self._grouped:
            return self._grouped
        grouped = {}
        for key in self.arr.keys():
            # drop the index from each element of self.arr
            grouped[key] = [y[1] for y in self.arr[key]]
        self._grouped = grouped
        return grouped

    def get_original(self, grouped_dict):
        # take in a grouped dictionary with e.g. results for each key listed
        # in the same order as the instances in `self.arr`, and
        # return the results in the same (single list) order as `self.orig_arr`.
        res = [None] * self.size
        cov = [False] * self.size
        # orig = [None] * self.size

        assert grouped_dict.keys() == self.arr.keys()

        for key in grouped_dict.keys():
            for (ind, _), v in zip(self.arr[key], grouped_dict[key]):
                res[ind] = v
                cov[ind] = True
                # orig[ind] = _

        assert all(cov)
        # assert orig == self.orig_arr

        return res
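
# Illustrative round trip (`run` and `reqtype` are hypothetical names):
#   g = Grouper(reqs, lambda r: r.reqtype)
#   results = {key: run(group) for key, group in g.get_grouped().items()}
#   flat = g.get_original(results)  # one result per original element, in order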


def make_table(result_dict, column: str = "results"):
    """Generate table of results."""
    from pytablewriter import LatexTableWriter, MarkdownTableWriter
    if column == "results":
        column_name = "Tasks"
    elif column == "groups":
        column_name = "Groups"
    all_headers = [
        column_name,
        "Version",
        "Filter",
        "n-shot",
        "Metric",
        "Value",
        "",
        "Stderr",
    ]

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = all_headers
    latex_writer.headers = all_headers

    values = []

    for k, dic in result_dict[column].items():
        version = result_dict["versions"][k]
        n = str(result_dict["n-shot"][k])

        if "alias" in dic:
            k = dic.pop("alias")

        for mf, v in dic.items():
            m, _, f = mf.partition(",")
            if m.endswith("_stderr"):
                continue

            if m + "_stderr" + "," + f in dic:
                se = dic[m + "_stderr" + "," + f]
                if se != "N/A":
                    se = "%.4f" % se
                values.append([k, version, f, n, m, "%.4f" % v, "±", se])
            else:
                values.append([k, version, f, n, m, "%.4f" % v, "", ""])
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()
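
# Expected input shape (illustrative, keys abbreviated): make_table consumes e.g.
#   {"results": {"lambada": {"acc,none": 0.75, "acc_stderr,none": 0.01}},
#    "versions": {"lambada": 1.0},
#    "n-shot": {"lambada": 0}}
# and emits one row per (metric, filter) pair, skipping *_stderr entries.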


def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`) to the
    wrapped function, `fn`.
    """
    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        # bound methods always receive `self` positionally, so allow one positional arg
        if len(args) != (1 if inspect.ismethod(fn) else 0):
            print(
                f"WARNING: using {fn.__name__} with positional arguments is "
                "deprecated and will be disallowed in a future version of "
                "lm-evaluation-harness!"
            )
        return fn(*args, **kwargs)
    return _wrapper
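
# Usage sketch (illustrative):
#   @positional_deprecated
#   def evaluate(model=None, tasks=None): ...
#   evaluate(model=m, tasks=t)  # fine
#   evaluate(m, t)              # works, but prints the deprecation warning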


@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
    """
    Search upward in the directory tree to a maximum of three layers
    to find and return the package root (containing the 'tests' folder)
    """
    cur_path = start_path.resolve()
    max_layers = 3
    for _ in range(max_layers):
        if (cur_path / "tests" / "test_version_stable.py").exists():
            return cur_path
        else:
            cur_path = cur_path.parent.resolve()
    raise FileNotFoundError(
        f"Unable to find package root within {max_layers} layers upwards of {start_path}"
    )


@positional_deprecated
def run_task_tests(task_list: List[str]):
    """
    Find the package root and run the tests for the given tasks
    """
    import pytest

    package_root = find_test_root(start_path=pathlib.Path(__file__))
    task_string = " or ".join(task_list)
    args = [
        f"{package_root}/tests/test_version_stable.py",
        f"--rootdir={package_root}",
        "-k",
        f"{task_string}",
    ]
    sys.path.append(str(package_root))
    pytest_return_val = pytest.main(args)
    if pytest_return_val:
        raise ValueError(
            f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
        )


def get_git_commit_hash():
    """
    Gets the git commit hash of your current repo (if it exists).
    Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42
    """
    try:
        git_hash = subprocess.check_output(["git", "describe", "--always"]).strip()
        git_hash = git_hash.decode()
    except (subprocess.CalledProcessError, FileNotFoundError):
        # FileNotFoundError occurs when git not installed on system
        git_hash = None
    return git_hash


def import_function(loader, node):
    function_name = loader.construct_scalar(node)
    yaml_path = os.path.dirname(loader.name)

    *module_name, function_name = function_name.split(".")
    # the starred assignment above always binds module_name to a list
    module_name = ".".join(module_name)
    module_path = os.path.normpath(os.path.join(yaml_path, "{}.py".format(module_name)))

    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    function = getattr(module, function_name)
    return function
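
# Illustrative YAML usage (names hypothetical): a task-config line such as
#   process_docs: !function utils.process_docs
# resolves to the `process_docs` attribute of the `utils.py` that sits next to
# the YAML file being loaded.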

def ignore_constructor(loader, node):
    return node


def simple_load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None):
    yaml.add_constructor("!function", ignore_constructor)
    with open(yaml_path, "rb") as file:
        yaml_config = yaml.full_load(file)
    return yaml_config


def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None):
    # Add the import_function constructor to the YAML loader
    yaml.add_constructor("!function", import_function)
    if yaml_config is None:
        with open(yaml_path, "rb") as file:
            yaml_config = yaml.full_load(file)

    if yaml_dir is None:
        yaml_dir = os.path.dirname(yaml_path)

    assert yaml_dir is not None

    if "include" in yaml_config:
        include_path = yaml_config["include"]
        del yaml_config["include"]

        if isinstance(include_path, str):
            include_path = [include_path]

        # Load from the last one first
        include_path.reverse()
        final_yaml_config = {}
        for path in include_path:
            # Assumes that path is a full path.
            # If not found, assume the included yaml
            # is in the same dir as the original yaml
            if not os.path.isfile(path):
                path = os.path.join(yaml_dir, path)

            included_yaml_config = load_yaml_config(path)
            final_yaml_config.update(included_yaml_config)

        final_yaml_config.update(yaml_config)
        return final_yaml_config
    return yaml_config
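
# Include semantics (illustrative file contents): given
#   base.yaml -> {"metric": "acc", "num_fewshot": 0}
#   task.yaml -> {"include": "base.yaml", "num_fewshot": 5}
# load_yaml_config("task.yaml") returns {"metric": "acc", "num_fewshot": 5}:
# included files load first, and the including file wins on conflicting keys.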


def regex_replace(string, pattern, repl, count: int = 0):
    """Implements the `re.sub` function as a custom Jinja filter."""
    return re.sub(pattern, repl, string, count=count)


env = Environment(loader=BaseLoader, undefined=StrictUndefined)
env.filters["regex_replace"] = regex_replace


def apply_template(template: str, doc: dict) -> str:
    rtemplate = env.from_string(template)
    return rtemplate.render(**doc)
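
# e.g. (illustrative) apply_template("Q: {{question}}\nA:", {"question": "2+2?"})
# -> "Q: 2+2?\nA:"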


def create_iterator(raw_iterator, rank, world_size, limit=None):
    """
    Method for creating a (potentially) sliced and limited
    iterator from a raw document iterator. Used for splitting data
    among ranks in multigpu setting or only pulling a sample of documents
    """
    return islice(raw_iterator, rank, limit, world_size)
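
# e.g. list(create_iterator(range(10), rank=1, world_size=4)) -> [1, 5, 9]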


def pad_and_concat(
    max_length: int,
    tensors: List[torch.Tensor],
    padding_side: Literal["right", "left"] = "right",
):
    """
    Method for padding a list of tensors given the maximum tensor
    length in the batch. Used for batching inputs and continuations in
    seq2seq models.
    """
    assert (
        padding_side == "left" or padding_side == "right"
    ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"

    for i, tensor in enumerate(tensors):
        if len(tensor.shape) == 2:
            tensor = tensor.squeeze(0)  # squeeze, in case passed [1, seq] size
        tensor_len = tensor.shape[0]
        if tensor_len < max_length:
            if padding_side == "right":
                # right-pad
                tensors[i] = torch.cat(
                    [
                        tensor,  # [seq]
                        torch.zeros(
                            max_length - tensor_len,
                            dtype=torch.long,
                            device=tensor.device,
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                ).unsqueeze(0)
            else:
                # left-pad
                tensors[i] = torch.cat(
                    [
                        torch.zeros(
                            max_length - tensor_len,
                            dtype=torch.long,
                            device=tensor.device,
                        ),  # [padding_length - seq]
                        tensor,  # [seq]
                    ],
                    dim=0,
                ).unsqueeze(0)
        else:
            tensors[i] = tensor.unsqueeze(0)

    return torch.cat(tensors, dim=0)
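
# Shape sketch (illustrative): two 1-D tensors of lengths 3 and 5 padded with
# max_length=5 come back as a single [2, 5] batch; with padding_side="left"
# the zeros precede the shorter sequence instead of trailing it.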


def clear_torch_cache() -> None:
    gc.collect()
    torch.cuda.empty_cache()


def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
    if isinstance(dtype, str) and dtype != "auto":
        # Convert `str` args torch dtype: `float16` -> `torch.float16`
        _torch_dtype = getattr(torch, dtype)
    else:
        _torch_dtype = dtype
    return _torch_dtype
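
# e.g. get_dtype("bfloat16") -> torch.bfloat16, while get_dtype("auto") and
# already-constructed torch.dtype values pass through unchanged.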


# Multi-token stopping criteria
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Criteria to stop on the specified multi-token sequence."""

    def __init__(
        self,
        sequence: str,
        tokenizer: transformers.PreTrainedTokenizer,
        initial_decoder_input_length: int,
        batch_size: int,
    ) -> None:
        self.initial_decoder_input_length = initial_decoder_input_length
        self.done_tracker = [False] * batch_size
        self.sequence = sequence
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        # we look back for 2 more tokens than it takes to encode our stop sequence
        # because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']`
        # and we don't want to mistakenly not stop a generation because our
        # (string) stop sequence was output in a different tokenization

        # NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model,
        # and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized
        # Additionally, in lookback_ids_batch we should prevent ever looking back into the inputs as described.
        self.sequence_id_len = len(self.sequence_ids) + 2
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :]

        lookback_ids_batch = lookback_ids_batch[:, -self.sequence_id_len :]

        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
        for i, done in enumerate(self.done_tracker):
            if not done:
                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
        return False not in self.done_tracker


def stop_sequences_criteria(
    tokenizer: transformers.PreTrainedTokenizer,
    stop_sequences: List[str],
    initial_decoder_input_length: int,
    batch_size: int,
) -> transformers.StoppingCriteriaList:
    return transformers.StoppingCriteriaList(
        [
            MultiTokenEOSCriteria(
                sequence, tokenizer, initial_decoder_input_length, batch_size
            )
            for sequence in stop_sequences
        ]
    )
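
# Usage sketch (illustrative; `model` and `context` are hypothetical):
#   criteria = stop_sequences_criteria(
#       tokenizer, ["\n\n"], context.shape[1], context.shape[0]
#   )
#   model.generate(context, stopping_criteria=criteria, max_new_tokens=64)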


# from more_itertools
def divide(iterable, n) -> List[Iterator]:
    """Divide the elements from *iterable* into *n* parts, maintaining
    order.

        >>> group_1, group_2 = divide([1, 2, 3, 4, 5, 6], 2)
        >>> list(group_1)
        [1, 2, 3]
        >>> list(group_2)
        [4, 5, 6]

    If the length of *iterable* is not evenly divisible by *n*, then the
    length of the returned iterables will not be identical:

        >>> children = divide([1, 2, 3, 4, 5, 6, 7], 3)
        >>> [list(c) for c in children]
        [[1, 2, 3], [4, 5], [6, 7]]

    If the length of the iterable is smaller than n, then the last returned
    iterables will be empty:

        >>> children = divide([1, 2, 3], 5)
        >>> [list(c) for c in children]
        [[1], [2], [3], [], []]

    This function will exhaust the iterable before returning and may require
    significant storage. If order is not important, see :func:`distribute`,
    which does not first pull the iterable into memory.

    """
    if n < 1:
        raise ValueError("n must be at least 1")

    try:
        iterable[:0]
    except TypeError:
        seq = tuple(iterable)
    else:
        seq = iterable

    q, r = divmod(len(seq), n)

    ret = []
    stop = 0
    for i in range(1, n + 1):
        start = stop
        stop += q + 1 if i <= r else q
        ret.append(iter(seq[start:stop]))

    return ret


def retry_on_specific_exceptions(
    on_exceptions: List[Type[Exception]],
    max_retries: Optional[int] = None,
    backoff_time: float = 3.0,
    backoff_multiplier: float = 1.5,
    on_exception_callback: Optional[Callable[[Exception, float], Any]] = None,
):
    """Retry on an LLM Provider's rate limit error with exponential backoff
    For example, to use for OpenAI, do the following:
    ```
    from openai import RateLimitError

    # Recommend specifying max_retries to avoid infinite loops!
    @retry_on_specific_exceptions([RateLimitError], max_retries=3)
    def completion(...):
        # Wrap OpenAI completion function here
        ...
    ```
    """

    def decorator(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            sleep_time = backoff_time
            attempt = 0
            while max_retries is None or attempt < max_retries:
                try:
                    return func(*args, **kwargs)
                except tuple(on_exceptions) as e:
                    attempt += 1
                    if max_retries is not None and attempt >= max_retries:
                        # out of retries: re-raise instead of silently returning None
                        raise
                    if on_exception_callback is not None:
                        on_exception_callback(e, sleep_time)
                    time.sleep(sleep_time)
                    sleep_time *= backoff_multiplier

        return wrapper

    return decorator


class Collator:
    """
    A class for reordering and batching elements of an array.

    This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data.
    """

    def __init__(
        self,
        arr: List,
        sort_fn: Callable,
        group_fn: Callable = lambda x: x[1],
        grouping: bool = False,
    ) -> None:
        self.grouping = grouping
        self.fn = sort_fn
        self.group_fn = lambda x: group_fn(x[1])  # first index are enumerated indices
        self.reorder_indices: List = []
        self.size = len(arr)
        self.arr_with_indices: Iterable[Any] = tuple(enumerate(arr))  # [indices, (arr)]
        if self.grouping is True:
            self.group_by_index()

    def group_by_index(self) -> None:
        self.arr_with_indices = self.group(
            self.arr_with_indices, fn=self.group_fn, values=False
        )

    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
        """
        Generates and yields batches from the reordered array.

        Parameters:
        - n (int): The size of each batch. Defaults to 1.
        - batch_fn (Optional[Callable[[int, Iterable], int]]): A function to determine the size of each batch. Defaults to None.

        Yields:
        Iterator: An iterator over batches of reordered elements.
        """
        if self.grouping:
            for (
                key,
                values,
            ) in self.arr_with_indices.items():  # type: ignore
                values = self._reorder(values)
                batch = self.get_chunks(values, n=n, fn=batch_fn)
                yield from batch
        else:
            values = self._reorder(self.arr_with_indices)  # type: ignore
            batch = self.get_chunks(values, n=n, fn=batch_fn)
            yield from batch

    def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> List:
        """
        Reorders the elements in the array based on the sorting function.

        Parameters:
        - arr (Union[List, Tuple[Tuple[int, Any], ...]]): The array or iterable to be reordered.

        Yields:
        List: Yields reordered elements one by one.
        """
        arr = sorted(arr, key=lambda x: self.fn(x[1]))
        self.reorder_indices.extend([x[0] for x in arr])
        yield from [x[1] for x in arr]

    def get_original(self, newarr: List) -> List:
        """
        Restores the original order of elements from the reordered list.

        Parameters:
        - newarr (List): The reordered array.

        Returns:
        List: The array with elements restored to their original order.
        """
        res = [None] * self.size
        cov = [False] * self.size

        for ind, v in zip(self.reorder_indices, newarr):
            res[ind] = v
            cov[ind] = True

        assert all(cov)

        return res

    def __len__(self):
        return self.size

    @staticmethod
    def group(arr: Iterable, fn: Callable, values: bool = False) -> Iterable:
        """
        Groups elements of an iterable based on a provided function.

        Parameters:
        - arr (Iterable): The iterable to be grouped.
        - fn (Callable): The function to determine the grouping.
        - values (bool): If True, returns the values of the group. Defaults to False.

        Returns:
        Iterable: An iterable of grouped elements.
        """
        res = collections.defaultdict(list)
        for ob in arr:
            try:
                hashable_dict = tuple(
                    (
                        key,
                        tuple(value)
                        if isinstance(value, collections.abc.Iterable)
                        else value,
                    )
                    for key, value in sorted(fn(ob).items())
                )
                res[hashable_dict].append(ob)
            except TypeError:
                res[fn(ob)].append(ob)
        if not values:
            return res
        return res.values()

    @staticmethod
    def get_chunks(_iter, n: int = 0, fn=None):
        """
        Divides an iterable into chunks of specified size or based on a given function.
        Useful for batching

        Parameters:
        - iter: The input iterable to be divided into chunks.
        - n: An integer representing the size of each chunk. Default is 0.
        - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.

        Returns:
        An iterator that yields chunks of the input iterable.

        Example usage:
        ```
        data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        for chunk in Collator.get_chunks(data, 3):
            print(chunk)
        ```
        Output:
        ```
        [1, 2, 3]
        [4, 5, 6]
        [7, 8, 9]
        [10]
        ```
        """
        arr = []
        _iter = tuple(_iter)
        for i, x in enumerate(_iter):
            arr.append(x)
            if len(arr) == (fn(i, _iter) if fn else n):
                yield arr
                arr = []

        if arr:
            yield arr
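
# Round-trip sketch (illustrative; `process` is a hypothetical batch callable):
#   c = Collator(reqs, sort_fn=lambda r: -len(r))
#   outs = []
#   for batch in c.get_batched(n=8):
#       outs.extend(process(batch))
#   outs = c.get_original(outs)  # results restored to the input order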