utils.py 3.59 KB
Newer Older
Baber's avatar
Baber committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
import ast
import json
import logging
from typing import Any, Optional, Union


def try_parse_json(value: Union[str, dict, None]) -> Union[str, dict, None]:
    """Parse ``value`` as JSON when possible, otherwise hand it back unchanged.

    Args:
        value: ``None``, an already-parsed dict, or a string that may hold JSON.

    Returns:
        ``None`` and dicts are returned as-is. For a string, the result of
        ``json.loads`` when decoding succeeds (i.e. whatever JSON value the
        text encodes), otherwise the original string.

    Raises:
        ValueError: If decoding fails and the string contains ``{``, which is
            taken as a sign the caller intended JSON but wrote it incorrectly.
            The underlying ``json.JSONDecodeError`` is attached as the cause.
    """
    # Nothing to decode: pass through untouched.
    if value is None or isinstance(value, dict):
        return value
    try:
        return json.loads(value)
    except json.JSONDecodeError as err:
        if "{" in value:
            # Chain the decoder error so its position info shows up in the
            # traceback as the explicit cause.
            raise ValueError(
                f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings."
            ) from err
        # No brace present: treat the input as a plain (non-JSON) string.
        return value


def _int_or_none_list_arg_type(
    min_len: int, max_len: int, defaults: str, value: str, split_char: str = ","
) -> list[Union[int, None]]:
    """Parses a string of integers or 'None' values separated by a specified character into a list.
    Validates the number of items against specified minimum and maximum lengths and fills missing values with defaults."""

    def parse_value(item):
        """Parses an individual item, converting it to an integer or `None`."""
        item = item.strip().lower()
        if item == "none":
            return None
        try:
            return int(item)
        except ValueError:
            raise ValueError(f"{item} is not an integer or None")

    items = [parse_value(v) for v in value.split(split_char)]
    num_items = len(items)

    if num_items == 1:
        items = items * max_len
    elif num_items < min_len or num_items > max_len:
        raise ValueError(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'"
        )
    elif num_items != max_len:
        logging.warning(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'. "
            "Missing values will be filled with defaults."
        )
        default_items = [parse_value(v) for v in defaults.split(split_char)]
        items.extend(default_items[num_items:])

    return items


def request_caching_arg_to_dict(cache_requests: Optional[str]) -> dict[str, bool]:
    """Expand the request-caching mode string into boolean flags.

    ``None`` yields an empty dict. Otherwise three keys are returned:
    ``cache_requests`` (mode is "true" or "refresh"),
    ``rewrite_requests_cache`` (mode is "refresh") and
    ``delete_requests_cache`` (mode is "delete").
    """
    if cache_requests is None:
        return {}

    caching_enabled = cache_requests in {"true", "refresh"}
    wants_refresh = cache_requests == "refresh"
    wants_delete = cache_requests == "delete"
    return dict(
        cache_requests=caching_enabled,
        rewrite_requests_cache=wants_refresh,
        delete_requests_cache=wants_delete,
    )


def check_argument_types(parser: argparse.ArgumentParser) -> None:
    """
    Verify that every registered CLI argument declares a ``type``.

    Actions whose ``dest`` is "help" or "command", and actions that carry a
    ``const`` value, are exempt. Raises ``ValueError`` for the first remaining
    action whose ``type`` is ``None``.
    """
    exempt_dests = ["help", "command"]
    for act in parser._actions:
        # Help, subcommand and const-carrying actions are not checked.
        is_exempt = act.dest in exempt_dests or act.const is not None
        if not is_exempt and act.type is None:
            raise ValueError(f"Argument '{act.dest}' doesn't have a type specified.")


def handle_cli_value_string(arg: str) -> Any:
    """Convert a raw CLI value string into the most specific Python value.

    Conversions are attempted in this order:
      1. "true" / "false" (case-insensitive) -> ``bool``
      2. a string made only of decimal digits -> ``int``
      3. anything ``float()`` accepts -> ``float``
      4. anything ``ast.literal_eval`` accepts (lists, dicts, tuples, ...)
      5. otherwise the original string is returned unchanged.

    Args:
        arg: The raw value text.

    Returns:
        The converted value, or ``arg`` itself when no conversion applies.
    """
    lowered = arg.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False

    # isdecimal() rather than isnumeric(): the latter is also true for
    # characters such as "²" or "½" that int() rejects, which would otherwise
    # escape as an uncaught ValueError.
    if arg.isdecimal():
        return int(arg)

    try:
        return float(arg)
    except ValueError:
        pass

    try:
        return ast.literal_eval(arg)
    except (ValueError, TypeError, SyntaxError):
        # literal_eval can also raise TypeError (e.g. an unhashable dict key);
        # any such input is simply kept as a plain string.
        return arg


def key_val_to_dict(args: str) -> dict:
    """Parse a comma-separated ``key=value`` string into a dictionary.

    Each value is converted with ``handle_cli_value_string``. Only the first
    ``=`` in an item separates key from value, so values may themselves
    contain ``=`` (e.g. ``"revision=a=b"`` yields the value ``"a=b"``).

    Args:
        args: Text such as ``"a=1,b=true"``. An empty string yields ``{}``.

    Returns:
        Mapping of each key to its converted value; later duplicates of a key
        overwrite earlier ones.

    Raises:
        ValueError: If an item contains no ``=``.
    """
    if not args:
        return {}

    parsed = {}
    for item in args.split(","):
        # partition() splits on the first "=" only, keeping any further "="
        # characters as part of the value.
        key, sep, raw_value = item.partition("=")
        if not sep:
            raise ValueError(
                f"Expected 'key=value' but got '{item}' while parsing '{args}'."
            )
        parsed[key] = handle_cli_value_string(raw_value)
    return parsed


def merge_dicts(*dicts):
    """Combine the given dicts into one new dict; later dicts win on key clashes."""
    combined = {}
    for mapping in dicts:
        for key, val in mapping.items():
            combined[key] = val
    return combined