utils.py 5.07 KB
Newer Older
1
2
3
4
import logging
import os
import re
import subprocess
Baber Abbasi's avatar
Baber Abbasi committed
5
from importlib.metadata import version
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
from torch.utils.collect_env import get_pretty_env_info
from transformers import __version__ as trans_version


logger = logging.getLogger(__name__)


def remove_none_pattern(input_string: str) -> Tuple[str, bool]:
    """Strip a trailing ',none' suffix from input_string, if present.

    The suffix is located with the regular expression ``,none$``. As with any
    ``$`` anchor used without ``re.MULTILINE``, this also matches just before a
    single trailing newline (``"acc,none\\n"`` becomes ``"acc\\n"``).

    Args:
        input_string (str): The string to clean.

    Returns:
        Tuple[str, bool]: The string with the ',none' suffix removed, and True if
            a substitution was made (False if the string is returned unchanged).
    """
    # subn reports how many substitutions happened; the anchored pattern can
    # match at most once, so "count > 0" is exactly "the string changed".
    cleaned, substitutions = re.subn(r",none$", "", input_string)
    return cleaned, substitutions > 0


def _handle_non_serializable(o: Any) -> Union[int, str, list]:
    """Handle non-serializable objects by converting them to serializable types.

    Args:
        o (Any): The object to be handled.

    Returns:
        Union[int, str, list]: The converted object. If the object is of type np.int64 or np.int32,
            it will be converted to int. If the object is of type set, it will be converted
            to a list. Otherwise, it will be converted to str.
    """
    if isinstance(o, np.int64) or isinstance(o, np.int32):
        return int(o)
    elif isinstance(o, set):
        return list(o)
    else:
        return str(o)


def _resolve_git_ref(git_folder: Path, ref_name: str) -> Optional[str]:
    """Return the commit hash that ref_name (e.g. "refs/heads/main") points to, or None.

    A loose ref file is checked first, then ``packed-refs``. For linked worktrees
    (whose git dir contains a ``commondir`` file) the shared git dir is searched too.
    """
    search_dirs = [git_folder]
    commondir_file = Path(git_folder, "commondir")
    if commondir_file.is_file():
        # Worktree git dirs keep branch refs in the main repository's git dir;
        # "commondir" holds that dir's path, usually relative to git_folder.
        common_dir = commondir_file.read_text(encoding="utf-8").strip()
        search_dirs.append(Path(git_folder, common_dir))
    for git_dir in search_dirs:
        loose_ref = Path(git_dir, ref_name)
        if loose_ref.is_file():
            return loose_ref.read_text(encoding="utf-8").strip()
        packed_refs = Path(git_dir, "packed-refs")
        if packed_refs.is_file():
            for line in packed_refs.read_text(encoding="utf-8").splitlines():
                # Skip blank lines, the "# pack-refs with: ..." header and
                # "^<hash>" peeled-tag lines; data lines are "<hash> <refname>".
                if not line or line.startswith(("#", "^")):
                    continue
                commit, _, name = line.partition(" ")
                if name.strip() == ref_name:
                    return commit
    return None


def get_commit_from_path(repo_path: Union[Path, str]) -> Optional[str]:
    """Read the checked-out commit hash of the repository at repo_path without running git.

    Handles:
      * a regular ``.git`` directory whose HEAD names a branch ref (loose or packed);
      * a detached HEAD, where the HEAD file contains the commit hash itself;
      * a ``.git`` *file* containing ``gitdir: <path>`` (submodules, linked
        worktrees), with relative or absolute paths, including paths with spaces.

    Args:
        repo_path (Union[Path, str]): Directory expected to contain ``.git``.

    Returns:
        Optional[str]: The commit hash, or None if there is no HEAD file, the
            ref cannot be resolved, or reading fails (failures are logged at
            debug level).
    """
    try:
        git_folder = Path(repo_path, ".git")
        if git_folder.is_file():
            # ".git" is a pointer file: "gitdir: <path>". Strip the prefix rather
            # than splitting on spaces so directory names containing spaces survive.
            pointer = git_folder.read_text(encoding="utf-8").splitlines()[0].strip()
            if pointer.startswith("gitdir:"):
                pointer = pointer[len("gitdir:") :].strip()
            git_folder = Path(git_folder.parent, pointer)

        head_file = Path(git_folder, "HEAD")
        if not head_file.exists():
            return None
        head = head_file.read_text(encoding="utf-8").strip()

        if head.startswith("ref:"):
            ref_name = head[len("ref:") :].strip()
            git_hash = _resolve_git_ref(git_folder, ref_name)
            if git_hash is None:
                logger.debug(
                    f"Failed to retrieve a Git commit hash from path: {str(repo_path)}. "
                    f"Error: could not resolve ref {ref_name}"
                )
            return git_hash

        # Detached HEAD: the file holds the commit hash directly.
        return head or None
    except Exception as err:
        logger.debug(
            f"Failed to retrieve a Git commit hash from path: {str(repo_path)}. Error: {err}"
        )
        return None


def get_git_commit_hash() -> Optional[str]:
    """
    Gets the git commit hash of your current repo (if it exists).
    Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42

    Runs ``git describe --always`` in the current working directory and returns
    its stripped output. If git is missing or not executable, or the command
    fails (e.g. not inside a repository), falls back to reading ``./.git``
    directly via ``get_commit_from_path``.

    Returns:
        Optional[str]: The ``git describe --always`` output, the hash read from
            ``./.git``, or None when neither is available.
    """
    try:
        # Discard git's stderr: a failure is handled by the fallback below, so
        # git's error text should not be printed to the user's console.
        output = subprocess.check_output(
            ["git", "describe", "--always"], stderr=subprocess.DEVNULL
        )
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError (git not installed) as well as
        # PermissionError (git binary present but not executable).
        return get_commit_from_path(os.getcwd())  # git hash of repo if exists
    return output.strip().decode()


def add_env_info(storage: Dict[str, Any]) -> None:
    """Add environment and version metadata to storage, in place.

    Adds the keys ``pretty_env_info`` (torch's collect_env report),
    ``transformers_version``, ``lm_eval_version`` and ``upper_git_hash``.
    Collection is best-effort: if the environment report or the lm_eval
    version lookup raises, the error message is stored as the value instead.

    Args:
        storage (Dict[str, Any]): Dictionary that is updated in place.
    """
    try:
        pretty_env_info = get_pretty_env_info()
    except Exception as err:
        pretty_env_info = str(err)

    try:
        lm_eval_version = version("lm_eval")
    except Exception as err:
        lm_eval_version = str(err)

    # Git hash of the parent directory, in case this repo is a submodule.
    upper_dir_commit = get_commit_from_path(Path(os.getcwd(), ".."))

    storage.update(
        {
            "pretty_env_info": pretty_env_info,
            "transformers_version": trans_version,
            "lm_eval_version": lm_eval_version,
            "upper_git_hash": upper_dir_commit,  # in case this repo is submodule
        }
    )


def add_tokenizer_info(storage: Dict[str, Any], lm) -> None:
    """Add the special-token metadata of lm's tokenizer to storage, in place.

    When ``lm`` has a truthy ``tokenizer`` attribute, adds:
      * ``tokenizer_pad_token`` / ``tokenizer_eos_token`` / ``tokenizer_bos_token``:
        each a ``[token, str(token_id)]`` pair read from ``lm.tokenizer``;
      * ``eot_token_id`` and ``max_length``: read from ``lm`` itself, or None
        when the attribute is absent.
    If reading any of these raises, the error is logged at debug level and
    nothing is added. If there is no tokenizer, only a debug message is logged.

    Args:
        storage (Dict[str, Any]): Dictionary that is updated in place.
        lm: The model wrapper whose tokenizer metadata is recorded.
    """
    if not getattr(lm, "tokenizer", False):
        # seems gguf and textsynth do not have tokenizer
        logger.debug(
            "LM does not have a 'tokenizer' attribute, not logging tokenizer metadata to results."
        )
        return

    try:
        tokenizer = lm.tokenizer
        tokenizer_info = {
            "tokenizer_pad_token": [
                tokenizer.pad_token,
                str(tokenizer.pad_token_id),
            ],
            "tokenizer_eos_token": [
                tokenizer.eos_token,
                str(tokenizer.eos_token_id),
            ],
            "tokenizer_bos_token": [
                tokenizer.bos_token,
                str(tokenizer.bos_token_id),
            ],
            "eot_token_id": getattr(lm, "eot_token_id", None),
            "max_length": getattr(lm, "max_length", None),
        }
        # Only reached if every attribute read above succeeded.
        storage.update(tokenizer_info)
    except Exception as err:
        logger.debug(
            f"Logging detailed tokenizer info failed with {err}, skipping..."
        )