from __future__ import annotations

import collections
import fnmatch
import hashlib
import importlib.util
import inspect
import json
import logging
import os
import re
from collections.abc import Generator
from dataclasses import asdict, is_dataclass
from functools import lru_cache, partial, wraps
from itertools import islice
from pathlib import Path
from typing import Any, Callable

import numpy as np
import yaml
from jinja2 import BaseLoader, Environment, StrictUndefined, Template


SPACING = " " * 47

HIGHER_IS_BETTER_SYMBOLS = {
    True: "↑",
    False: "↓",
}


def wrap_text(string: str, width: int = 140, **kwargs) -> str:
    """
    Wraps the given string to the specified width.
    """
    import textwrap

    return textwrap.fill(
        inspect.cleandoc(string),
        width=width,
        initial_indent="",
        subsequent_indent=" " * 8,
        break_long_words=False,
        break_on_hyphens=False,
        **kwargs,
    )


def get_logger(level: str | None = None) -> logging.Logger:
    """
    Get a logger with a stream handler that captures all lm_eval logs.

    Args:
        level (Optional[str]): The logging level.
    Example:
        >>> logger = get_logger("INFO")
        >>> logger.info("Log this!")
        INFO:lm_eval:Log this!

    Returns:
        logging.Logger: The logger.
    """
    logger = logging.getLogger("lm_eval")
    if not logger.hasHandlers():
        logger.addHandler(logging.StreamHandler())
        logger.setLevel(logging.INFO)
    if level is not None:
        level = getattr(logging, level.upper())
        logger.setLevel(level)
    return logger


def setup_logging(verbosity=logging.INFO, suppress_third_party=True):
    """
    Configure logging for the lm_eval CLI application.

    WARNING: This function is intended for CLI use only. Library users should
    use get_logger() instead to avoid interfering with their application's
    logging configuration.

    Args:
        verbosity: Log level (int) or string name. Can be overridden by LOGLEVEL env var.
        suppress_third_party: Whether to suppress verbose third-party library logs.

    Returns:
        logging.Logger: The configured lm_eval logger instance.
    """
    # Validate verbosity parameter
    if isinstance(verbosity, str):
        level_map = {
            "DEBUG": logging.DEBUG,
            "INFO": logging.INFO,
            "WARNING": logging.WARNING,
            "ERROR": logging.ERROR,
            "CRITICAL": logging.CRITICAL,
        }
        verbosity = level_map.get(verbosity.upper(), logging.INFO)
    elif not isinstance(verbosity, int):
        verbosity = logging.INFO

    # Get log level from environment or use default
    if log_level_env := os.environ.get("LOGLEVEL", None):
        level_map = {
            "DEBUG": logging.DEBUG,
            "INFO": logging.INFO,
            "WARNING": logging.WARNING,
            "ERROR": logging.ERROR,
            "CRITICAL": logging.CRITICAL,
        }
        log_level = level_map.get(log_level_env.upper(), verbosity)
    else:
        log_level = verbosity

    # Get the lm_eval logger directly
    logger = logging.getLogger("lm_eval")

    # Configure custom formatter
    class CustomFormatter(logging.Formatter):
        def format(self, record):
            record.name = record.name.removeprefix("lm_eval.")
            return super().format(record)

    formatter = CustomFormatter(
        "%(asctime)s %(levelname)-8s [%(name)s:%(lineno)d] %(message)s",
        datefmt="%Y-%m-%d:%H:%M:%S",
    )

    # Check if handler already exists to prevent duplicates
    has_stream_handler = any(
        isinstance(h, logging.StreamHandler) for h in logger.handlers
    )
    if not has_stream_handler:
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        # For CLI use, we disable propagation to avoid duplicate messages
        logger.propagate = False

    # Set the logger level
    logger.setLevel(log_level)

    # Optionally suppress verbose third-party library logs
    if suppress_third_party and log_level == logging.DEBUG:
        third_party_loggers = ["urllib3", "filelock", "fsspec"]
        for logger_name in third_party_loggers:
            logging.getLogger(logger_name).setLevel(logging.INFO)

    return logger
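
# Illustrative CLI-side usage (a sketch; library code should prefer get_logger()
# so it does not install handlers into the host application):
#   >>> logger = setup_logging(verbosity="DEBUG")
#   >>> logger.debug("parsed CLI arguments")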


def hash_string(string: str) -> str:
    return hashlib.sha256(string.encode("utf-8")).hexdigest()


def escaped_split(text, sep_char, maxsplit=-1):
    """Split text into a list on occurrences of the given separation
    character `sep_char`. The separation character may be escaped by a
    backslash to avoid splitting at that location.

    The separation character must be a string of size 1.

    If `maxsplit` is given, at most `maxsplit` splits are done (thus,
    the list will have at most `maxsplit + 1` elements). If `maxsplit`
    is not specified or less than 0, then there is no limit on the
    number of splits (all possible splits are made).
    """
    assert len(sep_char) == 1, (
        "separation string must be a single character for escaped splitting"
    )

    if maxsplit == 0:
        return text
    maxsplit = max(0, maxsplit)

    return re.split(r"(?<!\\)" + sep_char, text, maxsplit=maxsplit)
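
# Illustrative behavior (a sketch; assumes a separator that is not a regex
# metacharacter, since the pattern is built by plain concatenation):
#   >>> escaped_split(r"a,b\,c,d", ",")
#   ['a', 'b\\,c', 'd']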


def handle_arg_string(arg):
    if arg.lower() == "true":
        return True
    elif arg.lower() == "false":
        return False
    elif arg.isnumeric():
        return int(arg)
    try:
        return float(arg)
    except ValueError:
        return arg


def handle_non_serializable(o):
    if isinstance(o, np.integer):
        return int(o)
    elif isinstance(o, set):
        return list(o)
    else:
        return str(o)


def sanitize_list(sub):
    """
    Takes a possibly nested list and recursively converts all inner components to strings
    """
    if isinstance(sub, list):
        return [sanitize_list(item) for item in sub]
    if isinstance(sub, tuple):
        return tuple(sanitize_list(item) for item in sub)
    else:
        return str(sub)


def simple_parse_args_string(args_string: str | None) -> dict:
    """
    Parses a string like
        arg1=val1,arg2=val2
    into a dictionary.
    """
    if args_string is None:
        return {}
    args_string = args_string.strip()
    if not args_string:
        return {}
    arg_list = [arg for arg in args_string.split(",") if arg]
    args_dict = {
        kv[0]: handle_arg_string("=".join(kv[1:]))
        for kv in [arg.split("=") for arg in arg_list]
    }
    return args_dict
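
# Illustrative parse (a sketch; argument names are hypothetical):
#   >>> simple_parse_args_string("pretrained=my-model,batch_size=8,trust_remote_code=True")
#   {'pretrained': 'my-model', 'batch_size': 8, 'trust_remote_code': True}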


def join_iters(iters):
    for it in iters:
        yield from it


def group(arr, fn):
    res = collections.defaultdict(list)

    for ob in arr:
        res[fn(ob)].append(ob)

    return list(res.values())


# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns: list[str], source_list: list[str]) -> list[str]:
    if isinstance(patterns, str):
        patterns = [patterns]

    task_names = set()
    for pattern in patterns:
        for matching in fnmatch.filter(source_list, pattern):
            task_names.add(matching)
    return sorted(list(task_names))
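
# Illustrative matching (a sketch; the task names below are hypothetical):
#   >>> pattern_match(["mmlu_*"], ["mmlu_anatomy", "mmlu_law", "arc_easy"])
#   ['mmlu_anatomy', 'mmlu_law']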


def softmax(x) -> np.ndarray:
    """Compute softmax values for each sets of scores in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()
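
# Quick sanity check (a sketch):
#   >>> softmax([0.0, 0.0])
#   array([0.5, 0.5])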


def general_detokenize(string: str) -> str:
    string = string.replace(" n't", "n't")
    string = string.replace(" )", ")")
    string = string.replace("( ", "(")
    string = string.replace('" ', '"')
    string = string.replace(' "', '"')
    string = re.sub(r" (['.,])", r"\1", string)
    return string
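
# Illustrative cleanup of tokenizer-style spacing (a sketch):
#   >>> general_detokenize("It 's a test ( really ) .")
#   "It's a test (really)."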


def get_file_task_name(filename: str) -> str:
    """
    Given the sample results filenames, extracts and returns the task name.
    """
    return filename[filename.find("_") + 1 : filename.rfind("_")]


def get_file_datetime(filename: str) -> str:
    """
    Given the results and sample results filenames, extracts and returns the datetime.
    """
    return filename[filename.rfind("_") + 1 :].replace(".jsonl", "")
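
# Illustrative filename handling (a sketch; the filename below is hypothetical but
# follows the "samples_<task>_<datetime>.jsonl" shape these helpers assume):
#   >>> get_file_task_name("samples_arc_easy_2024-01-01T00-00-00.000000.jsonl")
#   'arc_easy'
#   >>> get_file_datetime("samples_arc_easy_2024-01-01T00-00-00.000000.jsonl")
#   '2024-01-01T00-00-00.000000'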


def sanitize_model_name(model_name: str) -> str:
    """
    Given the model name, returns a sanitized version of it.
    """
    return re.sub(r"[\"<>:/|\\?*\[\]]+", "__", model_name)


def sanitize_task_name(task_name: str) -> str:
    """
    Given the task name, returns a sanitized version of it.
    """
    return re.sub(r"\W", "_", task_name)


def get_latest_filename(filenames: list[str]) -> str:
    """
    Given a list of filenames, returns the filename with the latest datetime.
    """
    return max(filenames, key=lambda f: get_file_datetime(f))


def get_results_filenames(filenames: list[str]) -> list[str]:
    """
    Extracts filenames that correspond to aggregated results.
    """
    return [f for f in filenames if "/results_" in f and ".json" in f]


def get_sample_results_filenames(filenames: list[str]) -> list[str]:
    """
    Extracts filenames that correspond to sample results.
    """
    return [f for f in filenames if "/samples_" in f and ".json" in f]


def get_rolling_token_windows(
    token_list: list[int], prefix_token: int, max_seq_len: int, context_len: int
) -> Generator[tuple[list[int], list[int]], None, None]:
    """
    - context_len allows for a rolling window context, allowing each prediction window to potentially
      condition on some context

    :param token_list: list
        List of tokens to be PREDICTED
    :param max_seq_len: int
        max_seq_len of model (or max_seq_len we want to use)
    :param context_len: int
        Amount of desired token context for prediction. Needs to be at least 1.
    :param prefix_token: token
        Dummy token like <eos> so the first token has something to condition on
    :return: generator
        Generator of tuples
            (input_tokens, pred_tokens)
        Note: Score only the last len(pred_tokens) logits of the LM
    """
    assert 1 <= context_len <= max_seq_len
    if not token_list:
        return
    # +1 offset, going from input->preds
    pred_len = max_seq_len - context_len + 1
    predicted = 0

    # Special handling for first window: predict all tokens
    first_seq_len = min(max_seq_len, len(token_list))
    yield [prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len]
    predicted += first_seq_len

    while predicted < len(token_list):
        window_pred_len = min(len(token_list) - predicted, pred_len)
        window_end = predicted + window_pred_len

        yield (
            token_list[window_end - max_seq_len - 1 : window_end - 1],
            token_list[window_end - window_pred_len : window_end],
        )
        predicted += window_pred_len
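
# Worked example (a sketch): six tokens, a model context of 4, and a 1-token
# rolling context. The first window predicts everything it can; the second
# re-uses earlier tokens as context and predicts only the remainder.
#   >>> list(get_rolling_token_windows([1, 2, 3, 4, 5, 6], prefix_token=0,
#   ...                                max_seq_len=4, context_len=1))
#   [([0, 1, 2, 3], [1, 2, 3, 4]), ([2, 3, 4, 5], [5, 6])]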


def make_disjoint_window(
    pair: tuple[list[int], list[int]],
) -> tuple[list[int], list[int]]:
    """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
    a, b = pair
    return a[: len(a) - (len(b) - 1)], b
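
# Continuing the rolling-window sketch above: the overlapping context tokens are
# trimmed so context and continuation no longer share positions.
#   >>> make_disjoint_window(([2, 3, 4, 5], [5, 6]))
#   ([2, 3, 4], [5, 6])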


class EnhancedJSONEncoder(json.JSONEncoder):
    """
    Provides a proper json encoding for the loggers and trackers json dumps.
    Notably manages the json encoding of dataclasses.
    """

    def default(self, o):
        if is_dataclass(o):
            return asdict(o)
        return super().default(o)


class Reorderer:
    def __init__(self, arr: list[Any], fn: Callable) -> None:
        """Reorder an array according to some function

        Args:
            arr (List[Any]): The initial array
            fn (Callable[[Any], Any]): A function to determine the priority of elements
        """
        self.size = len(arr)
        arr = list(enumerate(arr))
        arr = group(arr, lambda x: fn(x[1]))
        # arr = [([y[0] for y in x], x[0][1]) for x in arr]
        # TODO: overhaul reorderer. It currently groups requests by content, but we don't want this
        arr = [([y[0]], x[0][1]) for x in arr for y in x]
        arr.sort(key=lambda x: fn(x[1]))

        self.arr = arr

    def get_reordered(self):
        """Gets the reordered array

        Returns:
            List[Any]: The reordered array
        """
        return [x[1] for x in self.arr]

    def get_original(self, newarr):
        """Restores the original order of a new array based on the old array's order

        Args:
            newarr (List[Any]): The array to be restored

        Returns:
            List[Any]: The array restored to the original order
        """
        res = [None] * self.size
        cov = [False] * self.size

        for (inds, _), v in zip(self.arr, newarr):
            for ind in inds:
                res[ind] = v
                cov[ind] = True

        assert all(cov)

        return res
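
# Illustrative round trip (a sketch): sort by length, then restore the original order.
#   >>> ro = Reorderer(["bbb", "a", "cc"], len)
#   >>> ro.get_reordered()
#   ['a', 'cc', 'bbb']
#   >>> ro.get_original([1, 2, 3])
#   [3, 1, 2]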


def make_table(result_dict, column: str = "results", sort_results: bool = False):
    """Generate table of results."""
    from pytablewriter import LatexTableWriter, MarkdownTableWriter

    if column == "results":
        column_name = "Tasks"
    elif column == "groups":
        column_name = "Groups"

    all_headers = [
        column_name,
        "Version",
        "Filter",
        "n-shot",
        "Metric",
        "",
        "Value",
        "",
        "Stderr",
    ]

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = all_headers
    latex_writer.headers = all_headers

    values = []

    keys = result_dict[column].keys()
    if sort_results:
        # sort entries alphabetically by task or group name.
        # NOTE: we default here to false, because order matters for multi-level table printing a la mmlu.
        # sorting here would mess that up
        keys = sorted(keys)
    for k in keys:
        dic = result_dict[column][k]
        version = result_dict["versions"].get(k, "    N/A")
        n = str(result_dict.get("n-shot", {}).get(k, " "))
        # TODO: fix this
        # higher_is_better = result_dict.get("higher_is_better", {}).get(k, {})

        if "alias" in dic:
            k = dic.pop("alias")

        metric_items = dic.items()
        metric_items = sorted(metric_items)

        for mf, v in metric_items:
            m, _, f = mf.partition(",")
            if m.endswith("_stderr"):
                continue

            # hib = HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(m), "")
            # TODO: fix
            hib = "↑"

            v = f"{v:.4f}" if isinstance(v, float) else v

            if m + "_stderr" + "," + f in dic:
                se = dic[m + "_stderr" + "," + f]
                se = "   N/A" if se == "N/A" else f"{se:.4f}"
                values.append([k, version, f, n, m, hib, v, "±", se])
            else:
                values.append([k, version, f, n, m, hib, v, "", ""])
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()
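
# Sketch of the `result_dict` shape this function reads (metric names and values
# here are hypothetical):
#   result_dict = {
#       "results": {"my_task": {"alias": "my_task", "acc,none": 0.5, "acc_stderr,none": 0.01}},
#       "versions": {"my_task": 1.0},
#       "n-shot": {"my_task": 0},
#   }
#   print(make_table(result_dict))  # renders a Markdown table of Tasks/Metric/Value/Stderr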


def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`) to the
    wrapped function, `fn`.
    """

    @wraps(fn)
    def _wrapper(*args, **kwargs):
        if len(args) != (1 if inspect.ismethod(fn) else 0):
            print(
                f"WARNING: using {fn.__name__} with positional arguments is "
                "deprecated and will be disallowed in a future version of "
                "lm-evaluation-harness!"
            )
        return fn(*args, **kwargs)

    return _wrapper


def ignore_constructor(loader, node):
    return node


def import_function(loader: yaml.Loader, node, yaml_path: Path):
    function_name = loader.construct_scalar(node)

    *module_name, function_name = function_name.split(".")
    if isinstance(module_name, list):
        module_name = ".".join(module_name)
    module_path = yaml_path.parent / f"{module_name}.py"

    spec = importlib.util.spec_from_file_location(module_name, module_path.as_posix())

    if spec is None:
        raise ImportError(f"Could not import module {module_name} from {module_path}.")
    module = importlib.util.module_from_spec(spec)

    if spec.loader is None:
        raise ImportError(f"Module loader is None, {module_name} from {module_path}.")
    spec.loader.exec_module(module)

    function = getattr(module, function_name)
    return function


def load_yaml_config(
    yaml_path: str | None = None, yaml_config=None, yaml_dir=None, mode="full"
):
    if mode == "simple":
        constructor_fn = ignore_constructor
    elif mode == "full":
        if yaml_path is None:
            raise ValueError("yaml_path must be provided if mode is 'full'.")
        # Attach yaml_path to the import function so that it can be used later
        constructor_fn = partial(import_function, yaml_path=Path(yaml_path))

    loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader
    # Add the import_function constructor to the YAML loader
    yaml.add_constructor("!function", constructor_fn, Loader=loader)
    if yaml_config is None:
        with open(yaml_path, "rb") as file:
            yaml_config = yaml.load(file, Loader=loader)

    if yaml_dir is None:
        yaml_dir = os.path.dirname(yaml_path)

    assert yaml_dir is not None

    if "include" in yaml_config:
        include_path = yaml_config["include"]
        del yaml_config["include"]

        if isinstance(include_path, str):
            include_path = [include_path]

        # Load from the last one first
        include_path.reverse()
        final_yaml_config = {}
        for path in include_path:
            # Assumes that path is a full path.
            # If not found, assume the included yaml
            # is in the same dir as the original yaml
            if not os.path.isfile(path):
                path = os.path.join(yaml_dir, path)

            try:
                included_yaml_config = load_yaml_config(yaml_path=path, mode=mode)
                final_yaml_config.update(included_yaml_config)
            except Exception as ex:
                # Re-raise so the user sees which include failed to load
                raise ex

        final_yaml_config.update(yaml_config)
        return final_yaml_config
    return yaml_config
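
# Sketch of a task YAML this loader understands (file names and fields below are
# hypothetical; "include" pulls in another YAML, "!function" resolves to a Python
# callable defined next to the YAML file):
#   # my_task.yaml
#   include: base_template.yaml
#   process_docs: !function utils.process_docs
# Loading it:
#   config = load_yaml_config(yaml_path="/path/to/my_task.yaml")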


def regex_replace(string, pattern, repl, count: int = 0):
    """Implements the `re.sub` function as a custom Jinja filter."""
    return re.sub(pattern, repl, string, count=count)


env = Environment(
    loader=BaseLoader(), undefined=StrictUndefined, keep_trailing_newline=True
)
env.filters["regex_replace"] = regex_replace


@lru_cache(maxsize=128)
def _compile(raw: str) -> Template:
    return env.from_string(raw)


def apply_template(template: str, doc: dict) -> str:
    rtemplate = _compile(template)
    return rtemplate.render(**doc)
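
# Illustrative prompt rendering (a sketch; the template and doc field are hypothetical):
#   >>> apply_template("Q: {{question}}\nA:", {"question": "What is 2+2?"})
#   'Q: What is 2+2?\nA:'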


def create_iterator(
    raw_iterator: collections.abc.Iterator,
    *,
    rank: int = 0,
    world_size: int = 1,
    limit: int | None = None,
) -> islice:
    """
    Method for creating a (potentially) sliced and limited
    iterator from a raw document iterator. Used for splitting data
    among ranks in multigpu setting or only pulling a sample of documents
    """
    return islice(raw_iterator, rank, limit, world_size)
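
# Illustrative sharding (a sketch): rank 0 of 2 takes the even-indexed docs,
# stopping once the (optional) limit index is reached.
#   >>> list(create_iterator(iter(range(10)), rank=0, world_size=2, limit=6))
#   [0, 2, 4]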


# TODO: why func for metric calc is here in eval utils?
def weighted_f1_score(items):
    from sklearn.metrics import f1_score

    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    fscore = f1_score(golds, preds, average="weighted")
    return fscore
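
# Illustrative call (a sketch): `items` is an iterable of (gold, prediction) pairs.
#   >>> round(weighted_f1_score([(1, 1), (0, 1), (0, 0), (1, 1)]), 3)
#   0.733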


def convert_pil_to_hash(value):
    from io import BytesIO

    img_bytes = BytesIO()
    value.save(img_bytes, format="PNG")
    # Hash the encoded PNG bytes (hashing str(img_bytes) would hash the BytesIO repr, not the image)
    return hashlib.sha256(img_bytes.getvalue()).hexdigest()


def convert_bytes_to_hash(value):
    return hashlib.sha256(str(value).encode()).hexdigest()


def hash_dict_images(data_dict):
    """
    Create a deep copy of `data_dict` where all bytes and PIL.Image.Image values
    are replaced by their respective hashes using the provided converter functions.

    Parameters:
        data_dict (dict): The input dictionary with arbitrary nesting of dicts and lists.

    Returns:
        dict: A new dictionary with the same structure as `data_dict`, but with all
              bytes and PIL.Image.Image objects replaced by their hashes.
    """

    def _process_value(value):
        from PIL import Image

        # Bytes -> hash
        if isinstance(value, (bytes, bytearray)):
            return convert_bytes_to_hash(value)
        # PIL Image -> hash
        if isinstance(value, Image.Image):
            return convert_pil_to_hash(value)
        # Nested dictionary -> recurse
        if isinstance(value, dict):
            return {k: _process_value(v) for k, v in value.items()}
        # List or tuple -> recurse, preserving type
        if isinstance(value, list):
            return [_process_value(v) for v in value]
        if isinstance(value, tuple):
            return tuple(_process_value(v) for v in value)
        # Other types remain unchanged
        return value

    # Ensure the top-level is a dict
    if not isinstance(data_dict, dict):
        raise TypeError("Input must be a dictionary")

    return (
        {key: _process_value(val) for key, val in data_dict.items()}
        if importlib.util.find_spec("PIL")
        else data_dict
    )
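
# Illustrative use when logging samples that may contain raw bytes or PIL images
# (a sketch; field names are hypothetical). If Pillow is not installed, the dict
# is returned unchanged.
#   >>> doc = {"id": 7, "image": b"...png bytes...", "meta": {"source": "example"}}
#   >>> hashed = hash_dict_images(doc)  # bytes/images replaced by sha256 hex digests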