utils.py 25.4 KB
Newer Older
1
2
3
import collections
import fnmatch
import functools
4
import hashlib
5
import importlib.util
6
import inspect
7
import json
8
9
10
import logging
import os
import re
11
import threading
12
from dataclasses import asdict, is_dataclass
13
from itertools import islice
14
from pathlib import Path
Baber Abbasi's avatar
Baber Abbasi committed
15
from typing import Any, Callable, Generator, List, Optional, Tuple
16

Lintang Sutawika's avatar
Lintang Sutawika committed
17
import numpy as np
18
import requests
19
import yaml
20
from jinja2 import BaseLoader, Environment, StrictUndefined
sdtblck's avatar
sdtblck committed
21

lintangsutawika's avatar
lintangsutawika committed
22

23
# 47-space pad string; no uses visible in this file — presumably for aligning
# CLI/log output columns (confirm at call sites before changing).
SPACING = " " * 47
sdtblck's avatar
sdtblck committed
24

25
26
27
28
29
# Arrow glyphs shown next to each metric in the results table (see make_table)
# indicating whether a higher (↑) or lower (↓) value is better.
HIGHER_IS_BETTER_SYMBOLS = {
    True: "↑",
    False: "↓",
}

sdtblck's avatar
sdtblck committed
30

31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def wrap_text(string: str, width: int = 140, **kwargs) -> str:
    """
    Dedent ``string`` (via ``inspect.cleandoc``) and wrap it to ``width`` columns.

    Continuation lines are indented by 8 spaces; long and hyphenated words are
    never broken. Extra ``kwargs`` are forwarded to ``textwrap.fill``.

    Note: the return annotation was ``Optional[str]``, but ``textwrap.fill``
    always returns a ``str`` — corrected here.
    """
    import textwrap

    return textwrap.fill(
        inspect.cleandoc(string),
        width=width,
        initial_indent="",
        subsequent_indent=" " * 8,
        break_long_words=False,
        break_on_hyphens=False,
        **kwargs,
    )


Lintang Sutawika's avatar
Lintang Sutawika committed
48
49
def setup_logging(verbosity=logging.INFO):
    """
    Configure the root logger for lm-eval.

    :param verbosity: default log level — either a ``logging`` integer constant
        (e.g. ``logging.DEBUG``) or a level name string ("DEBUG", "INFO", ...).
        The ``LOGLEVEL`` environment variable, when set, takes precedence.
    """

    class CustomFormatter(logging.Formatter):
        # Strip the "lm_eval." prefix from logger names for more compact output.
        def format(self, record):
            if record.name.startswith("lm_eval."):
                record.name = record.name[len("lm_eval.") :]
            return super().format(record)

    formatter = CustomFormatter(
        "%(asctime)s %(levelname)-8s [%(name)s:%(lineno)d] %(message)s",
        datefmt="%Y-%m-%d:%H:%M:%S",
    )

    # Environment variable overrides the argument; fall back to the argument
    # when the variable is set but empty.
    log_level = os.environ.get("LOGLEVEL", verbosity) or verbosity

    if isinstance(log_level, int):
        # Already a numeric logging level (e.g. logging.DEBUG). The previous
        # implementation stringified it ("10"), missed the level map, and
        # silently fell back to INFO.
        pass
    else:
        level_map = {
            "DEBUG": logging.DEBUG,
            "INFO": logging.INFO,
            "WARNING": logging.WARNING,
            "ERROR": logging.ERROR,
            "CRITICAL": logging.CRITICAL,
        }
        log_level = level_map.get(str(log_level).upper(), logging.INFO)

    if not logging.root.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)

        root_logger = logging.getLogger()
        root_logger.addHandler(handler)
        root_logger.setLevel(log_level)

        # Quiet noisy third-party loggers when running at DEBUG.
        if log_level == logging.DEBUG:
            third_party_loggers = ["urllib3", "filelock", "fsspec"]
            for logger_name in third_party_loggers:
                logging.getLogger(logger_name).setLevel(logging.INFO)
    else:
        # Logging already configured elsewhere; only adjust the level.
        logging.getLogger().setLevel(log_level)


89
90
91
92
def hash_string(string: str) -> str:
    """Return the hex-encoded SHA-256 digest of ``string`` (UTF-8 encoded)."""
    digest = hashlib.sha256()
    digest.update(string.encode("utf-8"))
    return digest.hexdigest()


93
94
95
96
97
98
99
100
101
102
103
104
def escaped_split(text, sep_char, maxsplit=-1):
    """Split text into a list on occurrences of the given separation
    character `sep_char`. The separation character may be escaped by a
    backslash to avoid splitting at that location.

    The separation character must be a string of size 1.

    If `maxsplit` is given, at most `maxsplit` splits are done (thus,
    the list will have at most `maxsplit + 1` elements). If `maxsplit`
    is not specified or less than 0, then there is no limit on the
    number of splits (all possible splits are made).
    """
    assert len(sep_char) == 1, (
        "separation string must be a single character for escaped splitting"
    )

    if maxsplit == 0:
        return text
    # re.split treats maxsplit=0 as "no limit"; clamp negative values to it.
    maxsplit = max(0, maxsplit)

    # re.escape protects regex metacharacters ("." "|" "*" "[" ...), which the
    # previous implementation interpolated verbatim into the pattern — making
    # e.g. sep_char="." split on every character. maxsplit is passed by
    # keyword (positional form is deprecated since Python 3.13).
    return re.split(r"(?<!\\)" + re.escape(sep_char), text, maxsplit=maxsplit)


haileyschoelkopf's avatar
haileyschoelkopf committed
116
117
118
119
120
def handle_arg_string(arg):
    """
    Coerce a CLI argument string to bool, int, or float, else leave it as str.

    "true"/"false" (any case) become booleans; numeric digit strings become
    ints; anything ``float()`` accepts becomes a float; everything else is
    returned unchanged.
    """
    if arg.lower() == "true":
        return True
    elif arg.lower() == "false":
        return False
    elif arg.isnumeric():
        # str.isnumeric also accepts characters int() rejects (e.g. "²", "½"),
        # which previously raised an uncaught ValueError here; fall through to
        # the float/str handling instead.
        try:
            return int(arg)
        except ValueError:
            pass
    try:
        return float(arg)
    except ValueError:
        return arg
haileyschoelkopf's avatar
haileyschoelkopf committed
127
128


129
130
131
132
133
134
135
136
137
def handle_non_serializable(o):
    """
    Fallback serializer for ``json.dumps(default=...)``.

    NumPy integers become Python ints (generalized from the previous explicit
    int32/int64 checks to any ``np.integer``, covering int8/int16/uint*),
    sets become lists, and anything else falls back to its ``str()`` form.
    """
    if isinstance(o, np.integer):
        return int(o)
    elif isinstance(o, set):
        return list(o)
    else:
        return str(o)


138
139
140
141
142
143
144
145
146
147
148
149
def sanitize_list(sub):
    """
    Recursively convert every leaf of a possibly nested list/tuple structure
    to its string form, preserving the container types along the way.
    """
    if isinstance(sub, (list, tuple)):
        converted = (sanitize_list(element) for element in sub)
        return list(converted) if isinstance(sub, list) else tuple(converted)
    return str(sub)


Baber Abbasi's avatar
Baber Abbasi committed
150
def simple_parse_args_string(args_string: Optional[str]) -> dict:
    """
    Parses something like
        args1=val1,arg2=val2
    Into a dictionary
    """
    if args_string is None:
        return {}
    args_string = args_string.strip()
    if not args_string:
        return {}

    parsed = {}
    for item in args_string.split(","):
        if not item:
            # skip empty segments produced by stray commas
            continue
        key, *value_parts = item.split("=")
        # re-join so values containing "=" survive intact
        parsed[key] = handle_arg_string("=".join(value_parts))
    return parsed
Leo Gao's avatar
Leo Gao committed
167

Fabrizio Milo's avatar
Fabrizio Milo committed
168

Leo Gao's avatar
Leo Gao committed
169
170
def join_iters(iters):
    """Lazily chain the sub-iterables of ``iters`` into one flat stream."""
    for chunk in iters:
        for element in chunk:
            yield element
Leo Gao's avatar
Leo Gao committed
172
173


174
175
176
177
178
def group(arr, fn):
    """
    Partition ``arr`` into lists of items sharing the same ``fn(item)`` key,
    returned in first-seen key order.
    """
    buckets = collections.defaultdict(list)
    for item in arr:
        buckets[fn(item)].append(item)
    return list(buckets.values())

Fabrizio Milo's avatar
Fabrizio Milo committed
182

gakada's avatar
gakada committed
183
184
185
def pattern_match(patterns, source_list):
    """
    Return the sorted, de-duplicated entries of ``source_list`` that match at
    least one shell-style wildcard pattern in ``patterns`` (a str or list).
    """
    if isinstance(patterns, str):
        patterns = [patterns]

    matched = {
        candidate
        for pattern in patterns
        for candidate in fnmatch.filter(source_list, pattern)
    }
    return sorted(matched)


Baber Abbasi's avatar
Baber Abbasi committed
196
def softmax(x) -> np.ndarray:
    """Compute softmax values for each sets of scores in x,
    numerically stabilized by subtracting the maximum first."""
    exps = np.exp(x - np.max(x))
    total = exps.sum()
    return exps / total


Baber Abbasi's avatar
Baber Abbasi committed
202
def general_detokenize(string) -> str:
    """
    Undo common whitespace artifacts of naive tokenization: re-attach "n't",
    closing/opening parentheses and quotes, then drop the space before
    apostrophes, periods and commas.
    """
    # Literal replacements, applied in the original order (order matters:
    # quote-after-space runs before space-before-quote).
    for needle, replacement in (
        (" n't", "n't"),
        (" )", ")"),
        ("( ", "("),
        ('" ', '"'),
        (' "', '"'),
    ):
        string = string.replace(needle, replacement)
    return re.sub(r" (['.,])", r"\1", string)


212
213
214
215
216
217
218
219
220
221
222
def get_file_task_name(filename: str) -> str:
    """
    Given the sample results filenames, extracts and returns the task name
    (the text between the first and the last underscore).
    """
    start = filename.find("_") + 1
    end = filename.rfind("_")
    return filename[start:end]


def get_file_datetime(filename: str) -> str:
    """
    Given the results and sample results filenames, extracts and returns the
    datetime (the text after the last underscore, minus any ``.jsonl`` suffix).
    """
    suffix = filename[filename.rfind("_") + 1 :]
    return suffix.replace(".jsonl", "")
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260


def sanitize_model_name(model_name: str) -> str:
    """
    Given the model name, returns a sanitized version of it: runs of
    filesystem-unsafe characters are collapsed into ``__``.
    """
    unsafe = re.compile(r"[\"<>:/\|\\?\*\[\]]+")
    return unsafe.sub("__", model_name)


def sanitize_task_name(task_name: str) -> str:
    """
    Given the task name, returns a sanitized version of it with every
    non-word character replaced by an underscore.
    """
    return re.compile(r"\W").sub("_", task_name)


def get_latest_filename(filenames: List[str]) -> str:
    """
    Given a list of filenames, returns the one whose embedded datetime
    string (per ``get_file_datetime``) is lexicographically latest.
    """
    return max(filenames, key=get_file_datetime)


def get_results_filenames(filenames: List[str]) -> List[str]:
    """
    Extracts filenames that correspond to aggregated results
    (paths containing "/results_" and ".json").
    """
    return [name for name in filenames if "/results_" in name and ".json" in name]


def get_sample_results_filenames(filenames: List[str]) -> List[str]:
    """
    Extracts filenames that correspond to per-sample results
    (paths containing "/samples_" and ".json").
    """
    return [name for name in filenames if "/samples_" in name and ".json" in name]


261
262
263
def get_rolling_token_windows(
    token_list: List[int], prefix_token: int, max_seq_len: int, context_len: int
) -> Generator[Tuple[List[int], List[int]], None, None]:
    """
    - context_len allows for a rolling window context, allowing each prediction window to potentially
      condition on some context

    :param token_list: list
        List of tokens to be PREDICTED
    :param max_seq_len: int
        max_seq_len of model (or max_seq_len we want to use)
    :param context_len: int
        Amount of desired token context for prediction. Needs to be at least 1.
    :param prefix_token: token
        Dummy token like <eos> so the first token has something to condition on
    :return: generator
        Generator of tuples
            (input_tokens, pred_tokens)
        Note: Score only the last len(pred_tokens) logits of the LM
    """
    assert 1 <= context_len <= max_seq_len
    # Empty input yields no windows at all.
    if not token_list:
        return
    # +1 offset, going from input->preds
    pred_len = max_seq_len - context_len + 1
    predicted = 0

    # Special handling for first window: predict all tokens
    # (input is the prefix token followed by all but the last target token).
    first_seq_len = min(max_seq_len, len(token_list))
    yield [prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len]
    predicted += first_seq_len

    # Subsequent windows each predict up to pred_len new tokens, with the
    # input shifted one position left of the prediction span (next-token LM).
    while predicted < len(token_list):
        window_pred_len = min(len(token_list) - predicted, pred_len)
        window_end = predicted + window_pred_len

        yield (
            # input: the max_seq_len tokens ending one before window_end
            token_list[window_end - max_seq_len - 1 : window_end - 1],
            # targets: the final window_pred_len tokens of the window
            token_list[window_end - window_pred_len : window_end],
        )
        predicted += window_pred_len

Fabrizio Milo's avatar
Fabrizio Milo committed
303

304
305
306
def make_disjoint_window(
    pair: Tuple[List[int], List[int]],
) -> Tuple[List[int], List[int]]:
    """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
    context, continuation = pair
    # keep only the part of the context that precedes the continuation span
    keep = len(context) - (len(continuation) - 1)
    return context[:keep], continuation
Fabrizio Milo's avatar
Fabrizio Milo committed
310

Jason Phang's avatar
Jason Phang committed
311

312
313
314
315
316
317
318
319
320
321
322
323
class EnhancedJSONEncoder(json.JSONEncoder):
    """
    JSON encoder used by the loggers and trackers. Serializes dataclass
    instances via ``dataclasses.asdict``; everything else is deferred to the
    base encoder (which raises ``TypeError`` for unsupported types).
    """

    def default(self, o):
        if not is_dataclass(o):
            return super().default(o)
        return asdict(o)


324
class Reorderer:
    """Reorders an array by a priority function while remembering each
    element's original position, so results computed on the reordered array
    can be restored to input order via ``get_original``."""

    def __init__(self, arr: List[Any], fn: Callable) -> None:
        """Reorder an array according to some function

        Args:
            arr (List[Any]): The initial array
            fn (Callable[[Any], Any]): A function to determine the priority of elements
        """
        self.size = len(arr)
        # Pair each element with its original index so order can be restored.
        arr = list(enumerate(arr))
        arr = group(arr, lambda x: fn(x[1]))
        # arr = [([y[0] for y in x], x[0][1]) for x in arr]
        # TODO: overhaul reorderer. It currently grouped requests by content but we don't want this
        # One entry per element: ([original_index], representative element of
        # its group). NOTE(review): the representative is x[0][1] (the group's
        # first element), not y[1] — intentional only if grouped elements are
        # interchangeable; flagged by the TODO above.
        arr = [([y[0]], x[0][1]) for x in arr for y in x]
        arr.sort(key=lambda x: fn(x[1]))

        # self.arr entries: ([original_index], element), sorted by fn(element)
        self.arr = arr

    def get_reordered(self):
        """Gets the reordered array

        Returns:
            List[Any]: The reordered array
        """
        return [x[1] for x in self.arr]

    def get_original(self, newarr):
        """Restores the original order of a new array based on the old array's order

        Args:
            newarr (List[Any]): The array to be restored

        Returns:
            List[Any]: The array restored to the original order
        """
        res = [None] * self.size
        cov = [False] * self.size

        # Scatter each new value back to the original index/indices it covers.
        for (inds, _), v in zip(self.arr, newarr):
            for ind in inds:
                res[ind] = v
                cov[ind] = True

        # Every original position must have been written.
        assert all(cov)

        return res

Fabrizio Milo's avatar
Fabrizio Milo committed
371

Lintang Sutawika's avatar
Lintang Sutawika committed
372
def make_table(result_dict, column: str = "results", sort_results: bool = False):
    """Generate a markdown table of results.

    :param result_dict: evaluation output containing the ``column`` sub-dict
        plus "versions", and optionally "n-shot" and "higher_is_better",
        each keyed by task/group name.
    :param column: "results" (per-task rows) or "groups" (per-group rows).
    :param sort_results: sort rows alphabetically by task/group name.
        Defaults to False because row order matters for multi-level table
        printing a la mmlu; sorting would mess that up.
    :raises ValueError: if ``column`` is neither "results" nor "groups".
    """
    from pytablewriter import LatexTableWriter, MarkdownTableWriter

    if column == "results":
        column_name = "Tasks"
    elif column == "groups":
        column_name = "Groups"
    else:
        # Previously an unknown column fell through and raised NameError below.
        raise ValueError(f"Unknown column: {column!r} (expected 'results' or 'groups')")

    all_headers = [
        column_name,
        "Version",
        "Filter",
        "n-shot",
        "Metric",
        "",
        "Value",
        "",
        "Stderr",
    ]

    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = all_headers
    latex_writer.headers = all_headers

    values = []

    keys = result_dict[column].keys()
    if sort_results:
        # sort entries alphabetically by task or group name.
        keys = sorted(keys)
    for k in keys:
        dic = result_dict[column][k]
        version = result_dict["versions"].get(k, "    N/A")
        # Fix: the default must be a dict — the previous default " " (a str)
        # raised AttributeError on .get() whenever "n-shot" was missing.
        n = str(result_dict.get("n-shot", {}).get(k, " "))
        higher_is_better = result_dict.get("higher_is_better", {}).get(k, {})

        if "alias" in dic:
            k = dic.pop("alias")  # NOTE: mutates the caller's results dict

        metric_items = sorted(dic.items())

        for mf, v in metric_items:
            # metric keys are "metric,filter"
            m, _, f = mf.partition(",")

            # stderr entries are folded into their metric's row below
            if m.endswith("_stderr"):
                continue

            hib = HIGHER_IS_BETTER_SYMBOLS.get(higher_is_better.get(m), "")

            v = "%.4f" % v if isinstance(v, float) else v

            if m + "_stderr" + "," + f in dic:
                se = dic[m + "_stderr" + "," + f]
                se = "   N/A" if se == "N/A" else "%.4f" % se
                values.append([k, version, f, n, m, hib, v, "±", se])
            else:
                values.append([k, version, f, n, m, hib, v, "", ""])
            # only show the task name/version on the first row of each task
            k = ""
            version = ""
    md_writer.value_matrix = values
    latex_writer.value_matrix = values

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()


444
445
def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`) to the
    wrapped function, `fn`.
    """

    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        # Bound methods receive `self` positionally, so exactly one
        # positional argument is expected there; zero otherwise.
        allowed_positional = 1 if inspect.ismethod(fn) else 0
        if len(args) != allowed_positional:
            print(
                f"WARNING: using {fn.__name__} with positional arguments is "
                "deprecated and will be disallowed in a future version of "
                "lm-evaluation-harness!"
            )
        return fn(*args, **kwargs)

    return _wrapper
Stephen Hogg's avatar
Stephen Hogg committed
461

Fabrizio Milo's avatar
Fabrizio Milo committed
462

463
464
465
466
def ignore_constructor(loader, node):
    """'simple'-mode YAML constructor for the ``!function`` tag: returns the
    raw YAML node untouched instead of importing the referenced callable."""
    return node


467
def import_function(loader: yaml.Loader, node, yaml_path: Path):
    """
    'full'-mode YAML constructor for the ``!function`` tag: resolves a dotted
    name like ``utils.process_docs`` to a callable by importing
    ``<module>.py`` from the directory containing the YAML file.
    """
    qualified_name = loader.construct_scalar(node)

    # split "pkg.mod.func" into module path ("pkg.mod") and attribute ("func")
    parts = qualified_name.split(".")
    function_name = parts[-1]
    module_name = ".".join(parts[:-1])
    module_path = yaml_path.parent / f"{module_name}.py"

    spec = importlib.util.spec_from_file_location(module_name, module_path.as_posix())
    if spec is None:
        raise ImportError(f"Could not import module {module_name} from {module_path}.")
    module = importlib.util.module_from_spec(spec)

    if spec.loader is None:
        raise ImportError(f"Module loader is None, {module_name} from {module_path}.")
    spec.loader.exec_module(module)

    return getattr(module, function_name)

lintangsutawika's avatar
lintangsutawika committed
488

489
490
491
492
def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full"):
    """Load a task YAML config, resolving ``!function`` tags and ``include``s.

    :param yaml_path: path to the YAML file (required in "full" mode, and
        whenever ``yaml_config``/``yaml_dir`` are not both supplied).
    :param yaml_config: already-parsed config dict; skips reading the file.
    :param yaml_dir: directory used to resolve relative ``include`` paths;
        defaults to the directory of ``yaml_path``.
    :param mode: "full" imports ``!function`` targets; "simple" leaves them
        as raw YAML nodes. NOTE(review): any other value leaves
        ``constructor_fn`` unbound and raises ``NameError`` below.
    :return: the merged config dict (included files merged lowest-priority
        first, the file's own keys winning).
    """
    if mode == "simple":
        constructor_fn = ignore_constructor
    elif mode == "full":
        if yaml_path is None:
            raise ValueError("yaml_path must be provided if mode is 'full'.")
        # Attach yaml_path to the import function so that it can be used later
        constructor_fn = functools.partial(import_function, yaml_path=Path(yaml_path))

    # Prefer the C-accelerated loader when libyaml is available.
    loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader
    # Add the import_function constructor to the YAML loader
    yaml.add_constructor("!function", constructor_fn, Loader=loader)
    if yaml_config is None:
        with open(yaml_path, "rb") as file:
            yaml_config = yaml.load(file, Loader=loader)

    if yaml_dir is None:
        yaml_dir = os.path.dirname(yaml_path)

    assert yaml_dir is not None

    if "include" in yaml_config:
        include_path = yaml_config["include"]
        del yaml_config["include"]

        if isinstance(include_path, str):
            include_path = [include_path]

        # Load from the last one first
        include_path.reverse()
        final_yaml_config = {}
        for path in include_path:
            # Assumes that path is a full path.
            # If not found, assume the included yaml
            # is in the same dir as the original yaml
            if not os.path.isfile(path):
                path = os.path.join(yaml_dir, path)

            try:
                included_yaml_config = load_yaml_config(yaml_path=path, mode=mode)
                final_yaml_config.update(included_yaml_config)
            except Exception as ex:
                # NOTE(review): despite the historical intent to ignore
                # failed includes, the error is re-raised here.
                raise ex

        # the file's own keys override anything pulled in via include
        final_yaml_config.update(yaml_config)
        return final_yaml_config
    return yaml_config
lintangsutawika's avatar
lintangsutawika committed
537
538


Ethan Smith's avatar
Ethan Smith committed
539
def regex_replace(string, pattern, repl, count: int = 0):
    """Implements the `re.sub` function as a custom Jinja filter."""
    compiled = re.compile(pattern)
    return compiled.sub(repl, string, count=count)
lintangsutawika's avatar
lintangsutawika committed
542

lintangsutawika's avatar
lintangsutawika committed
543

544
545
546
env = Environment(
    loader=BaseLoader, undefined=StrictUndefined, keep_trailing_newline=True
)
547
env.filters["regex_replace"] = regex_replace
548
549


baberabb's avatar
baberabb committed
550
def apply_template(template: str, doc: dict) -> str:
    """Render Jinja2 ``template`` with the fields of ``doc`` exposed as
    top-level template variables (uses the module-level ``env``)."""
    rtemplate = env.from_string(template)
    return rtemplate.render(**doc)
553
554


555
def create_iterator(raw_iterator, *, rank=0, world_size=1, limit=None):
    """
    Method for creating a (potentially) sliced and limited
    iterator from a raw document iterator. Used for splitting data
    among ranks in multigpu setting or only pulling a sample of documents.

    Rank ``rank`` receives every ``world_size``-th element starting at index
    ``rank``; ``limit`` caps the index into the underlying stream (not the
    per-rank count).
    """
    sliced = islice(raw_iterator, rank, limit, world_size)
    return sliced
562
563
564
565
566
567
568
569
570
571


def weighted_f1_score(items):
    """Compute the weighted-average F1 score over ``(gold, pred)`` pairs."""
    from sklearn.metrics import f1_score

    golds, preds = zip(*items)
    return f1_score(golds, preds, average="weighted")
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600


def convert_pil_to_hash(value):
    """
    Return the SHA-256 hex digest of ``value`` (a PIL image), computed over
    its PNG-encoded bytes so that identical images hash identically.
    """
    from io import BytesIO

    img_bytes = BytesIO()
    value.save(img_bytes, format="PNG")
    # Hash the actual PNG payload. The previous ``str(img_bytes)`` hashed the
    # BytesIO object's repr — which embeds its memory address — making the
    # digest different on every run even for the same image.
    return hashlib.sha256(img_bytes.getvalue()).hexdigest()


def convert_bytes_to_hash(value):
    """Return the SHA-256 hex digest of ``str(value)``.

    NB: hashes the repr-style string of the bytes object (e.g. "b'abc'"),
    not the raw bytes themselves — kept as-is for stability of stored hashes.
    """
    text_form = str(value)
    return hashlib.sha256(text_form.encode()).hexdigest()


def hash_dict_images(data_dict):
    """
    Create a deep copy of `data_dict` where all bytes and PIL.Image.Image values
    are replaced by their respective hashes using the provided converter functions.

    Parameters:
        data_dict (dict): The input dictionary with arbitrary nesting of dicts and lists.

    Returns:
        dict: A new dictionary with the same structure as `data_dict`, but with all
              bytes and PIL.Image.Image objects replaced by their hashes.
    """

    def _process_value(value):
        # Imported lazily and per-call; only reached when PIL is installed
        # (guarded by the find_spec check at the bottom).
        from PIL import Image

        # Bytes -> hash
        if isinstance(value, (bytes, bytearray)):
            return convert_bytes_to_hash(value)
        # PIL Image -> hash
        if isinstance(value, Image.Image):
            return convert_pil_to_hash(value)
        # Nested dictionary -> recurse
        if isinstance(value, dict):
            return {k: _process_value(v) for k, v in value.items()}
        # List or tuple -> recurse, preserving type
        if isinstance(value, list):
            return [_process_value(v) for v in value]
        if isinstance(value, tuple):
            return tuple(_process_value(v) for v in value)
        # Other types remain unchanged
        return value

    # Ensure the top-level is a dict
    if not isinstance(data_dict, dict):
        raise TypeError("Input must be a dictionary")

    # Without Pillow the dict cannot contain PIL images, so it is returned
    # unchanged. NOTE(review): raw bytes values are then NOT hashed either.
    return (
        {key: _process_value(val) for key, val in data_dict.items()}
        if importlib.util.find_spec("PIL")
        else data_dict
    )
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842


class RemoteTokenizer:
    """
    Minimal robust tokenizer that uses vLLM server's tokenizer endpoints.

    Wraps a server exposing ``/tokenizer_info``, ``/tokenize`` and
    ``/detokenize`` behind a tokenizer-like interface (encode / decode /
    apply_chat_template). All HTTP calls go through a pooled
    ``requests.Session`` with simple retry logic.
    """

    def __init__(
        self,
        base_url: str,
        timeout: int = 30,
        verify_certificate: bool = True,
        ca_cert_path: Optional[str] = None,
        auth_token: Optional[str] = None,
        max_retries: int = 3,
    ):
        # Per-request timeout (seconds) and retry attempt count.
        self.timeout = timeout
        self.max_retries = max_retries
        # RLock guards lazy fetching of tokenizer info / chat template.
        self._lock = threading.RLock()
        self._tokenizer_info = None
        self._chat_template_obj = None

        # Certificate logic: a CA bundle path when provided (and verification
        # is on), otherwise the verify flag itself.
        self.cert_config = (
            ca_cert_path if verify_certificate and ca_cert_path else verify_certificate
        )

        # Auth header logic
        self.headers = {"Content-Type": "application/json"}
        if auth_token:
            self.headers["Authorization"] = f"Bearer {auth_token}"

        # Normalize base URL - remove API endpoints to get server base
        self.base_url = (
            base_url.replace("/v1/completions", "")
            .replace("/v1/chat/completions", "")
            .rstrip("/")
        )

        # Use a session for connection pooling
        self.session = requests.Session()
        self.session.headers.update(self.headers)

        # Validate server supports tokenizer_info endpoint
        self._validate_server()

    def _request_with_retries(self, method, url, **kwargs):
        """Issue an HTTP request, retrying up to ``max_retries`` times, and
        raise ``RuntimeError`` wrapping the last exception on total failure."""
        last_exc = None
        for _ in range(self.max_retries):
            try:
                # NOTE(review): kwargs.pop removes "timeout" on the first
                # attempt, so later retries fall back to self.timeout.
                resp = self.session.request(
                    method,
                    url,
                    timeout=kwargs.pop("timeout", self.timeout),
                    verify=self.cert_config,
                    **kwargs,
                )
                resp.raise_for_status()
                return resp
            except requests.RequestException as e:
                last_exc = e
        raise RuntimeError(
            f"RemoteTokenizer: {method} {url} failed after {self.max_retries} attempts: {last_exc}"
        )

    def _validate_server(self):
        """Fail fast if the server lacks the /tokenizer_info endpoint."""
        url = f"{self.base_url}/tokenizer_info"
        resp = self._request_with_retries("GET", url)
        # NOTE(review): raise_for_status() already rejects non-2xx responses,
        # so this explicit status check is effectively unreachable.
        if resp.status_code != 200:
            raise RuntimeError(
                f"Server does not support tokenizer_info endpoint. Status: {resp.status_code}"
            )

    @property
    def tokenizer_info(self) -> dict:
        """Server tokenizer metadata (fetched once, cached under the lock)."""
        with self._lock:
            if self._tokenizer_info is None:
                url = f"{self.base_url}/tokenizer_info"
                resp = self._request_with_retries("GET", url)
                self._tokenizer_info = resp.json()
            return self._tokenizer_info

    @property
    def eos_token(self) -> Optional[str]:
        return self.tokenizer_info.get("eos_token")

    @property
    def bos_token(self) -> Optional[str]:
        return self.tokenizer_info.get("bos_token")

    @property
    def pad_token(self) -> Optional[str]:
        return self.tokenizer_info.get("pad_token")

    @property
    def eos_token_id(self) -> Optional[int]:
        # NOTE(review): derives the id by encoding the token string and taking
        # the first id; assumes the token encodes to at least one id (an empty
        # result would raise IndexError).
        if self.eos_token is None:
            return None
        return self.encode(self.eos_token)[0]

    @property
    def bos_token_id(self) -> Optional[int]:
        # Same derivation caveat as eos_token_id.
        if self.bos_token is None:
            return None
        return self.encode(self.bos_token)[0]

    @property
    def eot_token(self) -> Optional[int]:
        # End-of-turn is aliased to end-of-sequence here.
        return self.eos_token_id

    def encode(self, text: str) -> List[int]:
        """Tokenize ``text`` via /tokenize (without special tokens)."""
        url = f"{self.base_url}/tokenize"
        payload = {"prompt": text, "add_special_tokens": False}
        resp = self._request_with_retries("POST", url, json=payload)
        tokens = resp.json().get("tokens")
        if not isinstance(tokens, list):
            raise RuntimeError("Malformed response from /tokenize endpoint.")
        return tokens

    def decode(self, tokens: List[int]) -> str:
        """Detokenize ``tokens`` back into a string via /detokenize."""
        url = f"{self.base_url}/detokenize"
        payload = {"tokens": tokens}
        resp = self._request_with_retries("POST", url, json=payload)
        prompt = resp.json().get("prompt")
        if not isinstance(prompt, str):
            raise RuntimeError("Malformed response from /detokenize endpoint.")
        return prompt

    def batch_decode(self, tokens_list: List[List[int]]) -> List[str]:
        """Decode each token list sequentially (one request per entry)."""
        return [self.decode(tokens) for tokens in tokens_list]

    def apply_chat_template(
        self, chat_history: list, add_generation_prompt: bool = True, **kwargs
    ) -> str:
        """Render ``chat_history`` with the server-provided Jinja chat
        template (compiled once, cached under the lock)."""
        with self._lock:
            if self._chat_template_obj is None:
                template_str = self.tokenizer_info.get("chat_template")
                if not template_str:
                    raise ValueError("No chat template available from server")
                self._chat_template_obj = env.from_string(template_str)
        return self._chat_template_obj.render(
            messages=chat_history, add_generation_prompt=add_generation_prompt, **kwargs
        )

    def __call__(self, text: str, add_special_tokens: bool = False, **kwargs) -> dict:
        # HF-tokenizer-style call interface.
        # NOTE(review): add_special_tokens and kwargs are accepted but ignored;
        # encode() always sends add_special_tokens=False.
        tokens = self.encode(text)
        return {"input_ids": tokens}


def check_remote_tokenizer_support(
    base_url: str,
    timeout: int = 5,
    verify_certificate: bool = True,
    ca_cert_path: Optional[str] = None,
    auth_token: Optional[str] = None,
    max_retries: int = 3,
) -> bool:
    """
    Check if server supports remote tokenizer endpoints.
    Returns True if both /tokenizer_info and /tokenize endpoints are available and functional, False otherwise.
    """
    if not base_url:
        return False

    # Strip OpenAI-style API suffixes to obtain the server base URL.
    server_base = (
        base_url.replace("/v1/completions", "")
        .replace("/v1/chat/completions", "")
        .rstrip("/")
    )
    # CA bundle path when provided (and verification on), else the verify flag.
    cert_config = (
        ca_cert_path if verify_certificate and ca_cert_path else verify_certificate
    )
    headers = {"Content-Type": "application/json"}
    if auth_token:
        headers["Authorization"] = f"Bearer {auth_token}"

    session = requests.Session()
    session.headers.update(headers)

    def _request_with_retries(method, url, **kwargs):
        # Best-effort probe: returns the response on success, None after
        # max_retries failures (errors are swallowed on purpose here).
        for _ in range(max_retries):
            try:
                resp = session.request(
                    method,
                    url,
                    timeout=kwargs.pop("timeout", timeout),
                    verify=cert_config,
                    **kwargs,
                )
                resp.raise_for_status()
                return resp
            except requests.RequestException:
                pass
        return None

    # Check /tokenizer_info
    info_url = f"{server_base}/tokenizer_info"
    resp = _request_with_retries("GET", info_url)
    if not resp:
        return False
    # NOTE(review): resp.json() raises on a non-JSON body and would propagate
    # out of this "returns bool" probe — confirm whether that is acceptable.
    info = resp.json()
    if not isinstance(info, dict) or "eos_token" not in info:
        return False

    # Check /tokenize
    tokenize_url = f"{server_base}/tokenize"
    test_payload = {"prompt": "test", "add_special_tokens": False}
    resp = _request_with_retries("POST", tokenize_url, json=test_payload)
    if not resp:
        return False
    tokens = resp.json().get("tokens")
    if not isinstance(tokens, list):
        return False

    return True