import collections
import fnmatch
import functools
import gc
import importlib.util
import inspect
import logging
import os
import pathlib
import re
import subprocess
import sys
import time
from functools import wraps
from itertools import islice
from typing import (
    Any,
    Callable,
    Iterable,
    Iterator,
    List,
    Literal,
    Optional,
    Tuple,
    Type,
    Union,
)

import torch
import transformers
import yaml
from jinja2 import BaseLoader, Environment, StrictUndefined


logging.basicConfig(
    format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)
eval_logger = logging.getLogger("lm-eval")

SPACING = " " * 47


def escaped_split(text, sep_char, maxsplit=-1):
    """Split text into a list on occurrences of the given separation
    character `sep_char`. The separation character may be escaped by a
    backslash to avoid splitting at that location.

    The separation character must be a string of size 1.

    If `maxsplit` is given, at most `maxsplit` splits are done (thus,
    the list will have at most `maxsplit + 1` elements). If `maxsplit`
    is not specified or less than 0, then there is no limit on the
    number of splits (all possible splits are made).
    """
    assert (
        len(sep_char) == 1
    ), "separation string must be a single character for escaped splitting"

    if maxsplit == 0:
        return [text]
    maxsplit = max(0, maxsplit)

    return re.split(r"(?<!\\)" + re.escape(sep_char), text, maxsplit=maxsplit)


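# Example (illustrative): split on unescaped commas only.
#     >>> escaped_split(r"a,b\,c", ",")
#     ['a', 'b\\,c']

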
def handle_arg_string(arg):
    if arg.lower() == "true":
        return True
    elif arg.lower() == "false":
        return False
    elif arg.isnumeric():
        return int(arg)
    try:
        return float(arg)
    except ValueError:
        return arg


def simple_parse_args_string(args_string):
    """
    Parses a string like
        arg1=val1,arg2=val2
    into a dictionary.
    """
    args_string = args_string.strip()
    if not args_string:
        return {}
    arg_list = [arg for arg in args_string.split(",") if arg]
    args_dict = {
        k: handle_arg_string(v) for k, v in [arg.split("=") for arg in arg_list]
    }
    return args_dict
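

# Example (illustrative): parsing CLI-style model arguments.
#     >>> simple_parse_args_string("device=cuda:0,batch_size=8")
#     {'device': 'cuda:0', 'batch_size': 8}

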
def join_iters(iters):
    for it in iters:
        yield from it


def chunks(iter, n: int = 0, fn=None):
    """
    Divides an iterable into chunks of specified size or based on a given function.
    Useful for batching

    Parameters:
    - iter: The input iterable to be divided into chunks.
    - n: An integer representing the size of each chunk. Default is 0.
    - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.

    Returns:
    An iterator that yields chunks of the input iterable.

    Example usage:
    ```
    data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    for chunk in chunks(data, 3):
        print(chunk)
    ```
    Output:
    ```
    [1, 2, 3]
    [4, 5, 6]
    [7, 8, 9]
    [10]
    ```
    """
    arr = []
    for i, x in enumerate(iter):
        arr.append(x)
        if len(arr) == (fn(i, iter) if fn else n):
            yield arr
            arr = []

    if arr:
        yield arr


def group(arr, fn):
    res = collections.defaultdict(list)

    for ob in arr:
        res[fn(ob)].append(ob)

    return list(res.values())
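

# Example (illustrative): group strings by length.
#     >>> group(["a", "bb", "c"], len)
#     [['a', 'c'], ['bb']]
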

class MultiChoice:
    def __init__(self, choices) -> None:
        self.choices = choices

    # Simple wildcard support (linux filename patterns)
    def __contains__(self, values) -> bool:
        for value in values.split(","):
            if len(fnmatch.filter(self.choices, value)) == 0:
                eval_logger.info("Available tasks to choose:")
                for choice in self.choices:
                    eval_logger.info(f"  - {choice}")
                raise ValueError("'{}' is not in task list".format(value))
        return True

    def __iter__(self) -> Iterator:
        for choice in self.choices:
            yield choice


# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
    if isinstance(patterns, str):
        patterns = [patterns]

    task_names = set()
    for pattern in patterns:
        for matching in fnmatch.filter(source_list, pattern):
            task_names.add(matching)
    return sorted(list(task_names))
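

# Example (illustrative; task names are hypothetical):
#     >>> pattern_match(["arc_*"], ["arc_easy", "arc_challenge", "hellaswag"])
#     ['arc_challenge', 'arc_easy']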


def general_detokenize(string):
    string = string.replace(" n't", "n't")
    string = string.replace(" )", ")")
    string = string.replace("( ", "(")
    string = string.replace('" ', '"')
    string = string.replace(' "', '"')
    string = re.sub(r" (['.,])", r"\1", string)
    return string
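

# Example (illustrative): undoing PTB-style tokenization artifacts.
#     >>> general_detokenize("It 's a test ( really ) .")
#     "It's a test (really)."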


def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
    """
    - context_len allows for a rolling window context, allowing each prediction window to potentially
      condition on some context

    :param token_list: list
        List of tokens to be PREDICTED
    :param max_seq_len: int
        max_seq_len of model (or max_seq_len we want to use)
    :param context_len: int
        Amount of desired token context for prediction. Needs to be at least 1.
    :param prefix_token: token
        Dummy token like <eos> so the first token has something to condition on
    :return: generator
        Generator of tuples
            (input_tokens, pred_tokens)
        Note: Score only the last len(pred_tokens) logits of the LM
    """
    assert 1 <= context_len <= max_seq_len
    if not token_list:
        return
    # +1 offset, going from input->preds
    pred_len = max_seq_len - context_len + 1
    predicted = 0

    # Special handling for first window: predict all tokens
    first_seq_len = min(max_seq_len, len(token_list))
    yield ([prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len])
    predicted += first_seq_len

    while predicted < len(token_list):
        window_pred_len = min(len(token_list) - predicted, pred_len)
        window_end = predicted + window_pred_len

        yield (
            token_list[window_end - max_seq_len - 1 : window_end - 1],
            token_list[window_end - window_pred_len : window_end],
        )
        predicted += window_pred_len
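

# Example (illustrative): six tokens, max_seq_len=4, context_len=2. The first
# window predicts every token; the second slides forward and predicts only the
# remaining tokens, conditioning on the rest of the window as context.
#     >>> list(get_rolling_token_windows([1, 2, 3, 4, 5, 6], 0, 4, 2))
#     [([0, 1, 2, 3], [1, 2, 3, 4]), ([2, 3, 4, 5], [5, 6])]
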

def make_disjoint_window(pair):
    """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
    a, b = pair
    return a[: len(a) - (len(b) - 1)], b
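

# Example (illustrative): trim the context so it ends where the continuation
# begins.
#     >>> make_disjoint_window(([2, 3, 4, 5], [5, 6]))
#     ([2, 3, 4], [5, 6])

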
class Reorderer:
    def __init__(self, arr: List[Any], fn: Callable) -> None:
        """Reorder an array according to some function

        Args:
            arr (List[Any]): The initial array
            fn (Callable[[Any], Any]): A function to determine the priority of elements
        """
        self.size = len(arr)
        arr = list(enumerate(arr))
        arr = group(arr, lambda x: fn(x[1]))
        # arr = [([y[0] for y in x], x[0][1]) for x in arr]
        # TODO: overhaul reorderer. It currently grouped requests by content but we don't want this
        arr = [([y[0]], x[0][1]) for x in arr for y in x]
        arr.sort(key=lambda x: fn(x[1]))

        self.arr = arr

    def get_reordered(self):
        """Gets the reordered array

        Returns:
            List[Any]: The reordered array
        """
        return [x[1] for x in self.arr]

    def get_original(self, newarr):
        """Restores the original order of a new array based on the old array's order

        Args:
            newarr (List[Any]): The array to be restored

        Returns:
            List[Any]: The array restored to the original order
        """
        res = [None] * self.size
        cov = [False] * self.size

        for (inds, _), v in zip(self.arr, newarr):
            for ind in inds:
                res[ind] = v
                cov[ind] = True

        assert all(cov)

        return res
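

# Example (illustrative): sort by length, then restore the original order.
#     >>> ro = Reorderer(["bb", "a", "ccc"], len)
#     >>> ro.get_reordered()
#     ['a', 'bb', 'ccc']
#     >>> ro.get_original(["A", "BB", "CCC"])
#     ['BB', 'A', 'CCC']
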

class Grouper:
    """
    takes an array `arr` and function `fn` and returns a dictionary
    with keys fn(ob) for each ob in `arr` and with values `self.arr[key]` a list of all
    objects in `arr` satisfying `key == fn(ob)`.
    """

    def __init__(self, arr, fn) -> None:
        # self.orig_arr = arr
        self.size = len(arr)
        arr = list(enumerate(arr))

        def group_return_dict(arr, fn):
            res = collections.defaultdict(list)

            for ob in arr:
                res[fn(ob)].append(ob)
            return res

        arr = group_return_dict(arr, lambda x: fn(x[1]))

        # self.arr has format Dict[Tuple[int, <entry from orig. arr>]]
        self.arr = arr
        self._grouped = None

    def get_grouped(self):
        # return the contents but not indices for our grouped dict.
        if self._grouped:
            return self._grouped
        grouped = {}
        for key in self.arr.keys():
            # drop the index from each element of self.arr
            grouped[key] = [y[1] for y in self.arr[key]]
        self._grouped = grouped
        return grouped

    def get_original(self, grouped_dict):
        # take in a grouped dictionary with e.g. results for each key listed
        # in the same order as the instances in `self.arr`, and
        # return the results in the same (single list) order as `self.orig_arr`.
        res = [None] * self.size
        cov = [False] * self.size
        # orig = [None] * self.size

        assert grouped_dict.keys() == self.arr.keys()

        for key in grouped_dict.keys():
            for (ind, _), v in zip(self.arr[key], grouped_dict[key]):
                res[ind] = v
                cov[ind] = True
                # orig[ind] = _

        assert all(cov)
        # assert orig == self.orig_arr

        return res
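

# Example (illustrative): group by parity, then restore the flat order.
#     >>> g = Grouper([1, 2, 3, 4], lambda x: x % 2)
#     >>> g.get_grouped()
#     {1: [1, 3], 0: [2, 4]}
#     >>> g.get_original({1: ["a", "c"], 0: ["b", "d"]})
#     ['a', 'b', 'c', 'd']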


def make_table(result_dict, column: str = "results"):
    """Generate table of results."""
    from pytablewriter import LatexTableWriter, MarkdownTableWriter

    if column == "results":
        column_name = "Tasks"
    elif column == "groups":
        column_name = "Groups"

    all_headers = [
        column_name,
        "Version",
        "Filter",
        "n-shot",
        "Metric",
        "Value",
        "",
        "Stderr",
    ]

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = all_headers
    latex_writer.headers = all_headers

    values = []

    for k, dic in result_dict[column].items():
        version = result_dict["versions"][k]
        n = str(result_dict["n-shot"][k])

        if "alias" in dic:
            k = dic.pop("alias")

        for mf, v in dic.items():
            m, _, f = mf.partition(",")
            if m.endswith("_stderr"):
                continue

            if m + "_stderr" + "," + f in dic:
                se = dic[m + "_stderr" + "," + f]
                if se != "N/A":
                    se = "%.4f" % se
                values.append([k, version, f, n, m, "%.4f" % v, "±", se])
            else:
                values.append([k, version, f, n, m, "%.4f" % v, "", ""])
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()


def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`) to the
    wrapped function, `fn`.
    """

    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        if len(args) != (1 if inspect.ismethod(fn) else 0):
            print(
                f"WARNING: using {fn.__name__} with positional arguments is "
                "deprecated and will be disallowed in a future version of "
                "lm-evaluation-harness!"
            )
        return fn(*args, **kwargs)

    return _wrapper


@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
    """
    Search upward in the directory tree to a maximum of three layers
    to find and return the package root (containing the 'tests' folder)
    """
    cur_path = start_path.resolve()
    max_layers = 3
    for _ in range(max_layers):
        if (cur_path / "tests" / "test_version_stable.py").exists():
            return cur_path
        else:
            cur_path = cur_path.parent.resolve()
    raise FileNotFoundError(
        f"Unable to find package root within {max_layers} levels upward of {start_path}"
    )


@positional_deprecated
def run_task_tests(task_list: List[str]):
    """
    Find the package root and run the tests for the given tasks
    """
    import pytest

    package_root = find_test_root(start_path=pathlib.Path(__file__))
    task_string = " or ".join(task_list)
    args = [
        f"{package_root}/tests/test_version_stable.py",
        f"--rootdir={package_root}",
        "-k",
        f"{task_string}",
    ]
    sys.path.append(str(package_root))
    pytest_return_val = pytest.main(args)
    if pytest_return_val:
        raise ValueError(
            f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
        )


def get_git_commit_hash():
    """
    Gets the git commit hash of your current repo (if it exists).
    Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42
    """
    try:
        git_hash = subprocess.check_output(["git", "describe", "--always"]).strip()
        git_hash = git_hash.decode()
    except (subprocess.CalledProcessError, FileNotFoundError):
        # FileNotFoundError occurs when git not installed on system
        git_hash = None
    return git_hash


def import_function(loader, node):
    function_name = loader.construct_scalar(node)
    yaml_path = os.path.dirname(loader.name)

    *module_name, function_name = function_name.split(".")
    module_name = ".".join(module_name)
    module_path = os.path.normpath(os.path.join(yaml_path, "{}.py".format(module_name)))

    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    function = getattr(module, function_name)
    return function


# Add the import_function constructor to the YAML loader
yaml.add_constructor("!function", import_function)
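
# Example (illustrative; file names are hypothetical): a task YAML can now
# reference a Python function defined in a module next to it, e.g.
#
#     process_docs: !function utils.process_docs
#
# which loads `utils.py` from the YAML's directory and returns its
# `process_docs` attribute.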


def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None):
    if yaml_config is None:
        with open(yaml_path, "rb") as file:
            yaml_config = yaml.full_load(file)

    if yaml_dir is None:
        yaml_dir = os.path.dirname(yaml_path)

    assert yaml_dir is not None

    if "include" in yaml_config:
        include_path = yaml_config["include"]
        del yaml_config["include"]

        if isinstance(include_path, str):
            include_path = [include_path]

        # Load from the last one first
        include_path.reverse()
        final_yaml_config = {}
        for path in include_path:
            # Assumes that path is a full path.
            # If not found, assume the included yaml
            # is in the same dir as the original yaml
            if not os.path.isfile(path):
                path = os.path.join(yaml_dir, path)

            included_yaml_config = load_yaml_config(path)
            final_yaml_config.update(included_yaml_config)

        final_yaml_config.update(yaml_config)
        return final_yaml_config
    return yaml_config
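

# Example (illustrative; file names are hypothetical): a config can inherit
# defaults from another YAML via an `include` key,
#
#     include: base.yaml
#     metric_list:
#       - metric: acc
#
# with keys in the including file overriding those loaded from `base.yaml`.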


def regex_replace(string, pattern, repl, count: int = 0):
    """Implements the `re.sub` function as a custom Jinja filter."""
    return re.sub(pattern, repl, string, count=count)


env = Environment(loader=BaseLoader, undefined=StrictUndefined)
env.filters["regex_replace"] = regex_replace


def apply_template(template: str, doc: dict) -> str:
    rtemplate = env.from_string(template)
    return rtemplate.render(**doc)
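

# Example (illustrative): render a doc through a Jinja2 prompt template.
#     >>> apply_template("Q: {{question}}\nA:", {"question": "2+2?"})
#     'Q: 2+2?\nA:'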


def create_iterator(raw_iterator, rank, world_size, limit=None):
    """
    Method for creating a (potentially) sliced and limited
    iterator from a raw document iterator. Used for splitting data
    among ranks in multigpu setting or only pulling a sample of documents
    """
    return islice(raw_iterator, rank, limit, world_size)
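

# Example (illustrative): rank 1 of 2 takes every other document.
#     >>> list(create_iterator(iter(range(10)), rank=1, world_size=2))
#     [1, 3, 5, 7, 9]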


def pad_and_concat(
    max_length: int,
    tensors: List[torch.Tensor],
    padding_side: Literal["right", "left"] = "right",
):
    """
    Method for padding a list of tensors given the maximum tensor
    length in the batch. Used for batching inputs and continuations in
    seq2seq models.
    """
    assert (
        padding_side == "left" or padding_side == "right"
    ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"

    for i, tensor in enumerate(tensors):
        if len(tensor.shape) == 2:
            tensor = tensor.squeeze(0)  # squeeze, in case passed [1, seq] size
        tensor_len = tensor.shape[0]
        if tensor_len < max_length:
            if padding_side == "right":
                # right-pad
                tensors[i] = torch.cat(
                    [
                        tensor,  # [seq]
                        torch.zeros(
                            max_length - tensor_len,
                            dtype=torch.long,
                            device=tensor.device,
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                ).unsqueeze(0)
            else:
                # left-pad
                tensors[i] = torch.cat(
                    [
                        torch.zeros(
                            max_length - tensor_len,
                            dtype=torch.long,
                            device=tensor.device,
                        ),  # [padding_length - seq]
                        tensor,  # [seq]
                    ],
                    dim=0,
                ).unsqueeze(0)
        else:
            tensors[i] = tensor.unsqueeze(0)

    return torch.cat(tensors, dim=0)
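

# Example (illustrative): right-pad two sequences to a common length of 4.
#     >>> out = pad_and_concat(4, [torch.tensor([1, 2]), torch.tensor([1, 2, 3, 4])])
#     >>> out.tolist()
#     [[1, 2, 0, 0], [1, 2, 3, 4]]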


def clear_torch_cache() -> None:
    gc.collect()
    torch.cuda.empty_cache()


def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
    if isinstance(dtype, str) and dtype != "auto":
        # Convert `str` args torch dtype: `float16` -> `torch.float16`
        _torch_dtype = getattr(torch, dtype)
    else:
        _torch_dtype = dtype
    return _torch_dtype
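

# Example (illustrative):
#     >>> get_dtype("bfloat16")
#     torch.bfloat16
#     >>> get_dtype("auto")
#     'auto'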


# Multi-token stopping criteria
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Criteria to stop on the specified multi-token sequence."""

    def __init__(
        self,
        sequence: str,
        tokenizer: transformers.PreTrainedTokenizer,
        initial_decoder_input_length: int,
        batch_size: int,
    ) -> None:
        self.initial_decoder_input_length = initial_decoder_input_length
        self.done_tracker = [False] * batch_size
        self.sequence = sequence
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        # we look back for 2 more tokens than it takes to encode our stop sequence
        # because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']`
        # and we don't want to mistakenly not stop a generation because our
        # (string) stop sequence was output in a different tokenization

        # NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model,
        # and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized
        self.sequence_id_len = len(self.sequence_ids) + 2
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
            :, -self.sequence_id_len :
        ]

        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
        for i, done in enumerate(self.done_tracker):
            if not done:
                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
        return False not in self.done_tracker


def stop_sequences_criteria(
    tokenizer: transformers.PreTrainedTokenizer,
    stop_sequences: List[str],
    initial_decoder_input_length: int,
    batch_size: int,
) -> transformers.StoppingCriteriaList:
    return transformers.StoppingCriteriaList(
        [
            *[
                MultiTokenEOSCriteria(
                    sequence, tokenizer, initial_decoder_input_length, batch_size
                )
                for sequence in stop_sequences
            ],
        ]
    )
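

# Illustrative usage (assumes an instantiated HF `model` and `tokenizer`;
# names are hypothetical):
#
#     stopping_criteria = stop_sequences_criteria(
#         tokenizer, ["\n\n"], context.shape[1], batch_size=1
#     )
#     model.generate(context, stopping_criteria=stopping_criteria)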


# from more_itertools
def divide(iterable, n) -> List[Iterator]:
    """Divide the elements from *iterable* into *n* parts, maintaining
    order.

        >>> group_1, group_2 = divide([1, 2, 3, 4, 5, 6], 2)
        >>> list(group_1)
        [1, 2, 3]
        >>> list(group_2)
        [4, 5, 6]

    If the length of *iterable* is not evenly divisible by *n*, then the
    length of the returned iterables will not be identical:

        >>> children = divide([1, 2, 3, 4, 5, 6, 7], 3)
        >>> [list(c) for c in children]
        [[1, 2, 3], [4, 5], [6, 7]]

    If the length of the iterable is smaller than n, then the last returned
    iterables will be empty:

        >>> children = divide([1, 2, 3], 5)
        >>> [list(c) for c in children]
        [[1], [2], [3], [], []]

    This function will exhaust the iterable before returning and may require
    significant storage. If order is not important, see :func:`distribute`,
    which does not first pull the iterable into memory.

    """
    if n < 1:
        raise ValueError("n must be at least 1")

    try:
        iterable[:0]
    except TypeError:
        seq = tuple(iterable)
    else:
        seq = iterable

    q, r = divmod(len(seq), n)

    ret = []
    stop = 0
    for i in range(1, n + 1):
        start = stop
        stop += q + 1 if i <= r else q
        ret.append(iter(seq[start:stop]))

    return ret


def retry_on_specific_exceptions(
    on_exceptions: List[Type[Exception]],
    max_retries: Optional[int] = None,
    backoff_time: float = 3.0,
    backoff_multiplier: float = 1.5,
    on_exception_callback: Optional[Callable[[Exception, float], Any]] = None,
):
    """Retry on an LLM Provider's rate limit error with exponential backoff
    For example, to use for OpenAI, do the following:
    ```
    from openai import RateLimitError

    # Recommend specifying max_retries to avoid infinite loops!
    @retry_on_specific_exceptions([RateLimitError], max_retries=3)
    def completion(...):
        # Wrap OpenAI completion function here
        ...
    ```
    """

    def decorator(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            sleep_time = backoff_time
            attempt = 0
            while max_retries is None or attempt < max_retries:
                try:
                    return func(*args, **kwargs)
                except tuple(on_exceptions) as e:
                    if on_exception_callback is not None:
                        on_exception_callback(e, sleep_time)
                    time.sleep(sleep_time)
                    sleep_time *= backoff_multiplier
                    attempt += 1

        return wrapper

    return decorator


class Collator:
    """
    A class for reordering and batching elements of an array.

    This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data.
    """

    def __init__(
        self,
        arr: List,
        sort_fn: Callable,
        group_fn: Callable = lambda x: x[1],
        grouping: bool = False,
    ) -> None:
        self.grouping = grouping
        self.fn = sort_fn
        self.group_fn = lambda x: group_fn(x[1])  # first index are enumerated indices
        self.reorder_indices: List = []
        self.size = len(arr)
        self.arr_with_indices: Iterable[Any] = tuple(enumerate(arr))  # [indices, (arr)]
        if self.grouping is True:
            self.group_by_index()

    def group_by_index(self) -> None:
        self.arr_with_indices = self.group(
            self.arr_with_indices, fn=self.group_fn, values=False
        )

    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
        """
        Generates and yields batches from the reordered array.

        Parameters:
        - n (int): The size of each batch. Defaults to 1.
        - batch_fn (Optional[Callable[[int, Iterable], int]]): A function to determine the size of each batch. Defaults to None.

        Yields:
        Iterator: An iterator over batches of reordered elements.
        """
        if self.grouping:
            for (
                key,
                values,
            ) in self.arr_with_indices.items():  # type: ignore
                values = self._reorder(values)
                batch = self.get_chunks(values, n=n, fn=batch_fn)
                yield from batch
        else:
            values = self._reorder(self.arr_with_indices)  # type: ignore
            batch = self.get_chunks(values, n=n, fn=batch_fn)
            yield from batch

    def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> List:
        """
        Reorders the elements in the array based on the sorting function.

        Parameters:
        - arr (Union[List, Tuple[Tuple[int, Any], ...]]): The array or iterable to be reordered.

        Yields:
        List: Yields reordered elements one by one.
        """
        arr = sorted(arr, key=lambda x: self.fn(x[1]))
        self.reorder_indices.extend([x[0] for x in arr])
        yield from [x[1] for x in arr]

    def get_original(self, newarr: List) -> List:
        """
        Restores the original order of elements from the reordered list.

        Parameters:
        - newarr (List): The reordered array.

        Returns:
        List: The array with elements restored to their original order.
        """
        res = [None] * self.size
        cov = [False] * self.size

        for ind, v in zip(self.reorder_indices, newarr):
            res[ind] = v
            cov[ind] = True

        assert all(cov)

        return res

    def __len__(self):
        return self.size

    @staticmethod
    def group(arr: Iterable, fn: Callable, values: bool = False) -> Iterable:
        """
        Groups elements of an iterable based on a provided function.

        Parameters:
        - arr (Iterable): The iterable to be grouped.
        - fn (Callable): The function to determine the grouping.
        - values (bool): If True, returns the values of the group. Defaults to False.

        Returns:
        Iterable: An iterable of grouped elements.
        """
        res = collections.defaultdict(list)
        for ob in arr:
            try:
                hashable_dict = tuple(
                    (
                        key,
                        tuple(value)
                        if isinstance(value, collections.abc.Iterable)
                        else value,
                    )
                    for key, value in sorted(fn(ob).items())
                )
                res[hashable_dict].append(ob)
            except TypeError:
                res[fn(ob)].append(ob)
        if not values:
            return res
        return res.values()

    @staticmethod
    def get_chunks(_iter, n: int = 0, fn=None):
        """
        Divides an iterable into chunks of specified size or based on a given function.
        Useful for batching

        Parameters:
        - iter: The input iterable to be divided into chunks.
        - n: An integer representing the size of each chunk. Default is 0.
        - fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.

        Returns:
        An iterator that yields chunks of the input iterable.

        Example usage:
        ```
        data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        for chunk in Collator.get_chunks(data, 3):
            print(chunk)
        ```
        Output:
        ```
        [1, 2, 3]
        [4, 5, 6]
        [7, 8, 9]
        [10]
        ```
        """
        arr = []
        _iter = tuple(_iter)
        for i, x in enumerate(_iter):
            arr.append(x)
            if len(arr) == (fn(i, _iter) if fn else n):
                yield arr
                arr = []

        if arr:
            yield arr
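

# Example (illustrative): sort for batching, then restore the original order.
#     >>> c = Collator(["bb", "a", "ccc", "d"], sort_fn=len)
#     >>> batches = list(c.get_batched(n=2))
#     >>> batches
#     [['a', 'd'], ['bb', 'ccc']]
#     >>> c.get_original([x.upper() for b in batches for x in b])
#     ['BB', 'A', 'CCC', 'D']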