utils.py 3.59 KB
Newer Older
Baber's avatar
Baber committed
1
import argparse
2
import ast
Baber's avatar
Baber committed
3
4
import json
import logging
5
from typing import Any, Optional, Union
Baber's avatar
Baber committed
6
7


Baber's avatar
Baber committed
8
def try_parse_json(value: Union[str, dict, None]) -> Union[str, dict, None]:
    """Try to parse a string as JSON; fall back to the original string.

    Args:
        value: A JSON-encoded string, an already-parsed dict, or None.

    Returns:
        ``None`` and dict inputs are returned unchanged. For a string, the
        result of ``json.loads(value)`` is returned (note this can be any JSON
        type, e.g. a list or a number, not only a dict). If the string is not
        valid JSON and contains no ``{``, the original string is returned.

    Raises:
        ValueError: If the string is not valid JSON but contains ``{``, i.e.
            it looks like a JSON object that failed to parse. The underlying
            ``json.JSONDecodeError`` is attached as the cause.
    """
    if value is None:
        return None
    if isinstance(value, dict):
        return value
    try:
        return json.loads(value)
    except json.JSONDecodeError as err:
        # A "{" suggests the user intended a JSON object, so surface the
        # parse failure instead of silently treating it as a plain string.
        if "{" in value:
            raise ValueError(
                f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings."
            ) from err
        return value


def _int_or_none_list_arg_type(
    min_len: int, max_len: int, defaults: str, value: str, split_char: str = ","
) -> list[Union[int, None]]:
    """Parse a delimited string of integers / 'None' entries into a list.

    Args:
        min_len: Minimum number of items accepted (unless exactly one is given).
        max_len: Maximum number of items accepted; also the target length.
        defaults: Delimited string of default entries, parsed the same way as
            ``value``; used to fill positions missing from ``value``.
        value: The delimited string to parse.
        split_char: Separator between entries.

    Returns:
        The parsed list. A single entry is repeated ``max_len`` times. If the
        number of entries is within ``[min_len, max_len)``, the tail is filled
        from ``defaults`` (entries at index ``len(items)`` onward) and a
        warning is logged.

    Raises:
        ValueError: If an entry is neither an integer nor 'none'
            (case-insensitive), or if the number of entries is outside
            ``[min_len, max_len]`` (and is not exactly one).
    """

    def parse_value(item: str) -> Union[int, None]:
        """Convert one entry to an int, or to None for 'none' (any case)."""
        item = item.strip().lower()
        if item == "none":
            return None
        try:
            return int(item)
        except ValueError as err:
            # Chain explicitly so the original int() failure is kept as cause.
            raise ValueError(f"{item} is not an integer or None") from err

    items = [parse_value(v) for v in value.split(split_char)]
    num_items = len(items)

    if num_items == 1:
        # A single value is broadcast to every position.
        items = items * max_len
    elif num_items < min_len or num_items > max_len:
        raise ValueError(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'"
        )
    elif num_items != max_len:
        logging.warning(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'. "
            "Missing values will be filled with defaults."
        )
        default_items = [parse_value(v) for v in defaults.split(split_char)]
        # Only the positions not supplied by the user come from the defaults.
        items.extend(default_items[num_items:])

    return items


def request_caching_arg_to_dict(cache_requests: Optional[str]) -> dict[str, bool]:
    """Translate the request-caching mode string into boolean flags.

    Returns an empty dict when ``cache_requests`` is None. Otherwise returns
    three flags: ``cache_requests`` (true for "true" or "refresh"),
    ``rewrite_requests_cache`` (true for "refresh") and
    ``delete_requests_cache`` (true for "delete").
    """
    if cache_requests is None:
        return {}

    is_refresh = cache_requests == "refresh"
    is_delete = cache_requests == "delete"
    # "refresh" implies caching is on, in addition to the plain "true" mode.
    use_cache = is_refresh or cache_requests == "true"

    return {
        "cache_requests": use_cache,
        "rewrite_requests_cache": is_refresh,
        "delete_requests_cache": is_delete,
    }


Baber's avatar
Baber committed
73
def check_argument_types(parser: argparse.ArgumentParser) -> None:
    """Check that every value-taking CLI argument declares a ``type``.

    Iterates over the parser's registered actions and raises on the first one
    that has no ``type``. The following actions are skipped:

    * the ``help`` action and the action whose dest is ``command``
      (subcommands),
    * actions with a non-None ``const`` (e.g. ``store_true``/``store_const``),
    * actions that consume no command-line value (``nargs == 0``), such as
      ``count``, ``version`` or ``argparse.BooleanOptionalAction``; a ``type``
      converter is never applied to them, so requiring one is a false alarm.

    Args:
        parser: The argument parser to validate.

    Raises:
        ValueError: If an argument that is not skipped has no ``type``.
    """
    for action in parser._actions:
        # Skip help, subcommands, and const actions
        if action.dest in ("help", "command") or action.const is not None:
            continue
        # Flag-style actions take no value, so there is nothing to type.
        if action.nargs == 0:
            continue
        if action.type is None:
            raise ValueError(f"Argument '{action.dest}' doesn't have a type specified.")
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116


def handle_cli_value_string(arg: str) -> Any:
    """Convert a CLI value string to the most specific Python value.

    Conversion is attempted in this order:

    1. ``"true"`` / ``"false"`` (case-insensitive) -> ``bool``
    2. a string made only of decimal digits -> ``int``
    3. anything ``float()`` accepts -> ``float`` (this includes signed
       numbers such as ``"-5"``, which therefore become ``-5.0``)
    4. anything ``ast.literal_eval`` accepts (lists, dicts, tuples, quoted
       strings, ...) -> the evaluated literal
    5. otherwise the original string is returned unchanged.

    Args:
        arg: The raw value string.

    Returns:
        The converted value, or ``arg`` itself if no conversion applies.
    """
    lowered = arg.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    # isdecimal(), not isnumeric(): the latter also accepts characters such as
    # superscripts or vulgar fractions, which int() rejects with ValueError.
    if arg.isdecimal():
        return int(arg)
    try:
        return float(arg)
    except ValueError:
        pass
    try:
        return ast.literal_eval(arg)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # literal_eval may raise any of these on malformed input; in every
        # such case the value is treated as a plain string.
        return arg


def key_val_to_dict(args: str) -> dict:
    """Parse a ``key=value,key=value`` argument string into a dictionary.

    Each comma-separated item is split on its first ``=`` only, so values may
    themselves contain ``=``. Values are converted with
    ``handle_cli_value_string``; keys are kept as-is.

    Args:
        args: The argument string; an empty (falsy) value yields ``{}``.

    Returns:
        A dict mapping each key to its converted value. If a key is repeated,
        the last occurrence wins.

    Raises:
        ValueError: If an item contains no ``=``.
    """
    if not args:
        return {}

    parsed = {}
    for item in args.split(","):
        # partition splits on the first "=" only, keeping any later "=" in
        # the value (e.g. "prompt=a=b" -> key "prompt", value "a=b").
        key, sep, raw_value = item.partition("=")
        if not sep:
            raise ValueError(
                f"Expected 'key=value' but got '{item}' in argument string '{args}'"
            )
        parsed[key] = handle_cli_value_string(raw_value)
    return parsed


def merge_dicts(*dicts):
    """Combine any number of dicts into one new dict.

    Dicts are applied left to right, so on duplicate keys the value from the
    later dict wins. The inputs are not modified.
    """
    merged = {}
    for mapping in dicts:
        for key, value in mapping.items():
            merged[key] = value
    return merged