Unverified Commit 59cf408a authored by KonradSzafer, committed by GitHub

evaluation tracker implementation (#1766)

* evaluation tracker implementation

* OVModelForCausalLM test fix

* typo fix

* moved methods args

* multiple args in one flag

* loggers moved to dedicated dir

* improved filename sanitization
parent e6394715
@@ -301,10 +301,23 @@ lm_eval --model hf \
We support wildcards in task names; for example, you can run all of the machine-translated lambada tasks via `--task lambada_openai_mt_*`.

## Saving Results

To save evaluation results, provide an `--output_path`. We also support logging model responses with the `--log_samples` flag for post-hoc analysis.
Additionally, one can provide a directory with `--use_cache` to cache the results of prior runs. This allows you to avoid repeated execution of the same (model, task) pairs for re-scoring.
To push results and samples to the Hugging Face Hub, first ensure an access token with write access is set in the `HF_TOKEN` environment variable. Then use the `--hf_hub_log_args` flag to specify the organization, repository name, repository visibility, and whether to push results and samples to the Hub. For example:
```bash
lm_eval --model hf \
--model_args pretrained=model-name-or-path,autogptq=model.safetensors,gptq_use_triton=True \
--tasks hellaswag \
--log_samples \
--output_path results \
--hf_hub_log_args hub_results_org=EleutherAI,hub_repo_name=lm-eval-results,push_results_to_hub=True,push_samples_to_hub=True,public_repo=False
```
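Internally, the comma-separated string passed to `--hf_hub_log_args` is parsed with `simple_parse_args_string` and forwarded as keyword arguments to the new `EvaluationTracker`. A minimal sketch of that mapping (the organization and repository names below are placeholders):

```python
from lm_eval.logging import EvaluationTracker
from lm_eval.utils import simple_parse_args_string

# Placeholder values; substitute your own organization and repository names.
args_string = (
    "hub_results_org=EleutherAI,hub_repo_name=lm-eval-results,"
    "push_results_to_hub=True,push_samples_to_hub=True,public_repo=False"
)
tracker_kwargs = simple_parse_args_string(args_string)  # -> dict of typed key/value pairs
evaluation_tracker = EvaluationTracker(output_path="results", **tracker_kwargs)
```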
For a full list of supported arguments, check out the [interface](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md) guide in our documentation!

## Visualizing Results
...
@@ -2,31 +2,16 @@ import argparse
import json
import logging
import os
import re
import sys
from argparse import Namespace
from functools import partial
from pathlib import Path
from typing import Union
import numpy as np
from lm_eval import evaluator, utils
from lm_eval.evaluator import request_caching_arg_to_dict
from lm_eval.logging import EvaluationTracker, WandbLogger
from lm_eval.tasks import TaskManager
from lm_eval.utils import handle_non_serializable, make_table, simple_parse_args_string
DEFAULT_RESULTS_FILE = "results.json"
def _handle_non_serializable(o):
if isinstance(o, np.int64) or isinstance(o, np.int32):
return int(o)
elif isinstance(o, set):
return list(o)
else:
return str(o)
def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","):
@@ -203,6 +188,12 @@ def setup_parser() -> argparse.ArgumentParser:
default="",
help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval`",
)
parser.add_argument(
"--hf_hub_log_args",
type=str,
default="",
help="Comma separated string arguments passed to Hugging Face Hub's log function, e.g. `hub_results_org=EleutherAI,hub_repo_name=lm-eval-results`",
)
parser.add_argument(
"--predict_only",
"-x",
@@ -228,7 +219,6 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
)
return parser
@@ -251,6 +241,15 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
eval_logger.info(f"Verbosity set to {args.verbosity}")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# update the evaluation tracker args with the output path and the HF token
args.hf_hub_log_args = f"output_path={args.output_path},token={os.environ.get('HF_TOKEN')},{args.hf_hub_log_args}"
evaluation_tracker_args = simple_parse_args_string(args.hf_hub_log_args)
evaluation_tracker = EvaluationTracker(**evaluation_tracker_args)
evaluation_tracker.general_config_tracker.log_experiment_args(
model_source=args.model,
model_args=args.model_args,
)
if args.predict_only:
args.log_samples = True
if (args.log_samples or args.predict_only) and not args.output_path:
@@ -262,6 +261,19 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
eval_logger.info(f"Including path: {args.include_path}")
task_manager = TaskManager(args.verbosity, include_path=args.include_path)
evaluation_tracker_args = Namespace(**evaluation_tracker_args)
if (
evaluation_tracker_args.push_results_to_hub
or evaluation_tracker_args.push_samples_to_hub
) and not evaluation_tracker_args.hub_results_org:
raise ValueError(
"If push_results_to_hub or push_samples_to_hub is set, results_org must be specified."
)
if evaluation_tracker_args.push_samples_to_hub and not args.log_samples:
eval_logger.warning(
"Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
)
if args.limit:
eval_logger.warning(
" --limit SHOULD ONLY BE USED FOR TESTING."
@@ -306,24 +318,6 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
f"Tasks not found: {missing}. Try `lm-eval --tasks list` for list of available tasks, or '--verbosity DEBUG' to troubleshoot task registration issues."
)
if args.output_path:
path = Path(args.output_path)
# check if file or 'dir/results.json' exists
if path.is_file():
raise FileExistsError(f"File already exists at {path}")
output_path_file = path.joinpath(DEFAULT_RESULTS_FILE)
if output_path_file.is_file():
eval_logger.warning(
f"File {output_path_file} already exists. Results will be overwritten."
)
# if path json then get parent dir
elif path.suffix in (".json", ".jsonl"):
output_path_file = path
path.parent.mkdir(parents=True, exist_ok=True)
path = path.parent
else:
path.mkdir(parents=True, exist_ok=True)
# Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
if args.trust_remote_code:
os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = str(args.trust_remote_code)
@@ -365,7 +359,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if args.log_samples:
samples = results.pop("samples")
dumped = json.dumps(
results, indent=2, default=handle_non_serializable, ensure_ascii=False
)
if args.show_config:
print(dumped)
@@ -382,23 +376,13 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
except Exception as e:
eval_logger.info(f"Logging to Weights and Biases failed due to {e}")
evaluation_tracker.save_results_aggregated(results=results, samples=samples)
output_path_file.open("w", encoding="utf-8").write(dumped)
if args.log_samples:
for task_name, config in results["configs"].items():
evaluation_tracker.save_results_samples(
task_name=task_name, samples=samples[task_name]
)
filename = path.joinpath(f"{output_name}.jsonl")
samples_dumped = json.dumps(
samples[task_name],
indent=2,
default=_handle_non_serializable,
ensure_ascii=False,
)
filename.write_text(samples_dumped, encoding="utf-8")
print(
f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
...
import itertools
import json
import logging
import random
import time
@@ -20,9 +21,15 @@ from lm_eval.evaluator_utils import (
print_writeout,
run_task_tests,
)
from lm_eval.logging.utils import add_env_info, get_git_commit_hash
from lm_eval.tasks import TaskManager, get_task_dict
from lm_eval.utils import (
eval_logger,
handle_non_serializable,
hash_string,
positional_deprecated,
simple_parse_args_string,
)
if TYPE_CHECKING:
@@ -272,6 +279,13 @@ def simple_evaluate(
results["config"] = {
"model": model_name,
"model_args": model_args,
}
# add more detailed model info if available
if isinstance(lm, lm_eval.models.huggingface.HFLM):
results["config"].update(lm.get_model_info())
# add info about execution
results["config"].update(
{
"batch_size": batch_size, "batch_size": batch_size,
"batch_sizes": ( "batch_sizes": (
list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else [] list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
...@@ -282,6 +296,7 @@ def simple_evaluate( ...@@ -282,6 +296,7 @@ def simple_evaluate(
"bootstrap_iters": bootstrap_iters, "bootstrap_iters": bootstrap_iters,
"gen_kwargs": gen_kwargs, "gen_kwargs": gen_kwargs,
} }
)
results["git_hash"] = get_git_commit_hash() results["git_hash"] = get_git_commit_hash()
results["date"] = start_date results["date"] = start_date
add_env_info(results) # additional environment info to results add_env_info(results) # additional environment info to results
...@@ -349,7 +364,6 @@ def evaluate( ...@@ -349,7 +364,6 @@ def evaluate(
eval_logger.debug( eval_logger.debug(
f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}" f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
) )
if write_out: if write_out:
print_writeout(task) print_writeout(task)
# aggregate Instances by LM method requested to get output. # aggregate Instances by LM method requested to get output.
...@@ -435,6 +449,16 @@ def evaluate( ...@@ -435,6 +449,16 @@ def evaluate(
"filtered_resps": [ "filtered_resps": [
req.filtered_resps[filter_key] for req in requests req.filtered_resps[filter_key] for req in requests
], ],
"doc_hash": hash_string(
json.dumps(
requests[0].doc,
indent=2,
default=handle_non_serializable,
ensure_ascii=False,
)
),
"prompt_hash": hash_string(requests[0].arguments[0]),
"target_hash": hash_string(str(target)),
}
example.update(metrics)
task_output.logged_samples.append(example)
@@ -565,6 +589,13 @@ def evaluate(
"configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())),
"n-shot": dict(sorted(num_fewshot.items())),
"n-samples": {
task_output.task_name: {
"original": len(task_output.task.eval_docs),
"effective": min(limit, len(task_output.task.eval_docs)),
}
for task_output in eval_tasks
},
}
if log_samples:
results_dict["samples"] = dict(samples)
...
from .evaluation_tracker import EvaluationTracker
from .wandb_logger import WandbLogger
import json
import re
import time
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from huggingface_hub import HfApi
from lm_eval.utils import (
eval_logger,
handle_non_serializable,
hash_string,
)
@dataclass(init=False)
class GeneralConfigTracker:
"""
Tracker for the evaluation parameters.
Attributes:
model_source (str): Source of the model (e.g. Hugging Face, GGUF, etc.)
model_name (str): Name of the model.
model_name_sanitized (str): Sanitized model name for directory creation.
start_time (float): Start time of the experiment. Logged at class init.
end_time (float): End time of the experiment. Logged when calling [`GeneralConfigTracker.log_end_time`]
total_evaluation_time_seconds (str): Inferred total evaluation time in seconds (from the start and end times).
"""
model_source: str = None
model_name: str = None
model_name_sanitized: str = None
start_time: float = None
end_time: float = None
total_evaluation_time_seconds: str = None
def __init__(self) -> None:
"""Starts the evaluation timer."""
self.start_time = time.perf_counter()
@staticmethod
def _get_model_name(model_args: str) -> str:
"""Extracts the model name from the model arguments."""
def extract_model_name(model_args: str, key: str) -> str:
"""Extracts the model name from the model arguments using a key."""
args_after_key = model_args.split(key)[1]
return args_after_key.split(",")[0]
# order does matter, e.g. peft and delta are provided together with pretrained
prefixes = ["peft=", "delta=", "pretrained=", "model=", "path=", "engine="]
for prefix in prefixes:
if prefix in model_args:
return extract_model_name(model_args, prefix)
return ""
def log_experiment_args(
self,
model_source: str,
model_args: str,
) -> None:
"""Logs model parameters and job ID."""
self.model_source = model_source
self.model_name = GeneralConfigTracker._get_model_name(model_args)
self.model_name_sanitized = re.sub(
r"[\"<>:/\|\\?\*\[\]]+", "__", self.model_name
)
def log_end_time(self) -> None:
"""Logs the end time of the evaluation and calculates the total evaluation time."""
self.end_time = time.perf_counter()
self.total_evaluation_time_seconds = str(self.end_time - self.start_time)
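# Illustrative usage sketch (not part of the original file): GeneralConfigTracker
# derives the model name from the "pretrained="/"peft="/... prefix found in
# model_args and sanitizes it for use as a directory name. Values are examples.
def _example_general_config_tracker_usage() -> None:
    tracker = GeneralConfigTracker()
    tracker.log_experiment_args(
        model_source="hf",
        model_args="pretrained=EleutherAI/pythia-1b,dtype=float16",
    )
    assert tracker.model_name == "EleutherAI/pythia-1b"
    assert tracker.model_name_sanitized == "EleutherAI__pythia-1b"
    tracker.log_end_time()  # populates total_evaluation_time_seconds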
class EvaluationTracker:
"""
Keeps track of and saves relevant information about the evaluation process.
Compiles the data from trackers and writes it to files, which can be published to the Hugging Face hub if requested.
"""
def __init__(
self,
output_path: str = "",
hub_results_org: str = "",
hub_repo_name: str = "",
push_results_to_hub: bool = False,
push_samples_to_hub: bool = False,
public_repo: bool = False,
token: str = "",
) -> None:
"""
Creates all the necessary loggers for evaluation tracking.
Args:
output_path (str): Path to save the results. If not provided, the results won't be saved.
hub_results_org (str): The Hugging Face organisation to push the results to. If not provided, the results won't be pushed.
hub_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will be pushed to `lm-eval-results`.
push_results_to_hub (bool): Whether to push the results to the Hugging Face hub.
push_samples_to_hub (bool): Whether to push the samples to the Hugging Face hub.
public_repo (bool): Whether to push the results to a public or private repository.
token (str): Token to use when pushing to the Hugging Face hub. This token should have write access to `hub_results_org`.
"""
self.general_config_tracker = GeneralConfigTracker()
self.output_path = output_path
self.hub_results_org = hub_results_org
hub_repo_name = hub_repo_name if hub_repo_name else "lm-eval-results"
self.hub_results_repo = f"{hub_results_org}/{hub_repo_name}"
self.hub_results_repo_private = f"{hub_results_org}/{hub_repo_name}-private"
self.push_results_to_hub = push_results_to_hub
self.push_samples_to_hub = push_samples_to_hub
self.public_repo = public_repo
self.api = HfApi(token=token) if token else None
def save_results_aggregated(
self,
results: dict,
samples: dict,
) -> None:
"""
Saves the aggregated results and samples to the output path and pushes them to the Hugging Face hub if requested.
Args:
results (dict): The aggregated results to save.
samples (dict): The samples results to save.
"""
self.general_config_tracker.log_end_time()
if self.output_path:
try:
eval_logger.info("Saving results aggregated")
# calculate cumulative hash for each task
task_hashes = {}
for task_name, task_samples in samples.items():
sample_hashes = [
s["doc_hash"] + s["prompt_hash"] + s["target_hash"]
for s in task_samples
]
task_hashes[task_name] = hash_string("".join(sample_hashes))
# update initial results dict
results.update({"task_hashes": task_hashes})
results.update(asdict(self.general_config_tracker))
dumped = json.dumps(
results,
indent=2,
default=handle_non_serializable,
ensure_ascii=False,
)
path = Path(self.output_path if self.output_path else Path.cwd())
path = path.joinpath(self.general_config_tracker.model_name_sanitized)
path.mkdir(parents=True, exist_ok=True)
self.date_id = datetime.now().isoformat().replace(":", "-")
file_results_aggregated = path.joinpath(f"results_{self.date_id}.json")
file_results_aggregated.open("w", encoding="utf-8").write(dumped)
if self.api and self.push_results_to_hub:
self.api.create_repo(
repo_id=self.hub_results_repo
if self.public_repo
else self.hub_results_repo_private,
repo_type="dataset",
private=not self.public_repo,
exist_ok=True,
)
self.api.upload_folder(
repo_id=self.hub_results_repo
if self.public_repo
else self.hub_results_repo_private,
folder_path=str(path),
path_in_repo=self.general_config_tracker.model_name_sanitized,
repo_type="dataset",
commit_message=f"Adding aggregated results for {self.general_config_tracker.model_name}",
)
except Exception as e:
eval_logger.warning("Could not save results aggregated")
eval_logger.info(repr(e))
else:
eval_logger.info(
"Output path not provided, skipping saving results aggregated"
)
def save_results_samples(
self,
task_name: str,
samples: dict,
) -> None:
"""
Saves the samples results to the output path and pushes them to the Hugging Face hub if requested.
Args:
task_name (str): The task name to save the samples for.
samples (dict): The samples results to save.
"""
if self.output_path:
try:
eval_logger.info("Saving samples results")
samples_dumped = json.dumps(
samples,
indent=2,
default=handle_non_serializable,
ensure_ascii=False,
)
path = Path(self.output_path if self.output_path else Path.cwd())
path = path.joinpath(self.general_config_tracker.model_name_sanitized)
path.mkdir(parents=True, exist_ok=True)
file_results_samples = path.joinpath(
f"samples_{task_name}_{self.date_id}.json"
)
file_results_samples.write_text(samples_dumped, encoding="utf-8")
if self.api and self.push_samples_to_hub:
self.api.create_repo(
self.hub_results_repo
if self.public_repo
else self.hub_results_repo_private,
repo_type="dataset",
private=not self.public_repo,
exist_ok=True,
)
self.api.upload_folder(
repo_id=self.hub_results_repo
if self.public_repo
else self.hub_results_repo_private,
folder_path=str(path),
path_in_repo=self.general_config_tracker.model_name_sanitized,
repo_type="dataset",
commit_message=f"Adding samples results for {task_name} to {self.general_config_tracker.model_name}",
)
except Exception as e:
eval_logger.warning("Could not save sample results")
eval_logger.info(repr(e))
else:
eval_logger.info("Output path not provided, skipping saving sample results")
import logging
import os
import re
import subprocess
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
from torch.utils.collect_env import get_pretty_env_info
from transformers import __version__ as trans_version
logger = logging.getLogger(__name__)
def remove_none_pattern(input_string: str) -> Tuple[str, bool]:
"""Remove the ',none' substring from the input_string if it exists at the end.
Args:
input_string (str): The input string from which to remove the ',none' substring.
Returns:
Tuple[str, bool]: A tuple containing the modified input_string with the ',none' substring removed
and a boolean indicating whether the modification was made (True) or not (False).
"""
# Define the pattern to match ',none' at the end of the string
pattern = re.compile(r",none$")
# Use sub() to replace ',none' with an empty string
result = re.sub(pattern, "", input_string)
# check if the input_string changed
removed = result != input_string
return result, removed
def _handle_non_serializable(o: Any) -> Union[int, str, list]:
"""Handle non-serializable objects by converting them to serializable types.
Args:
o (Any): The object to be handled.
Returns:
Union[int, str, list]: The converted object. If the object is of type np.int64 or np.int32,
it will be converted to int. If the object is of type set, it will be converted
to a list. Otherwise, it will be converted to str.
"""
if isinstance(o, np.int64) or isinstance(o, np.int32):
return int(o)
elif isinstance(o, set):
return list(o)
else:
return str(o)
def get_commit_from_path(repo_path: Union[Path, str]) -> Optional[str]:
try:
git_folder = Path(repo_path, ".git")
if git_folder.is_file():
git_folder = Path(
git_folder.parent,
git_folder.read_text(encoding="utf-8").split("\n")[0].split(" ")[-1],
)
if Path(git_folder, "HEAD").exists():
head_name = (
Path(git_folder, "HEAD")
.read_text(encoding="utf-8")
.split("\n")[0]
.split(" ")[-1]
)
head_ref = Path(git_folder, head_name)
git_hash = head_ref.read_text(encoding="utf-8").replace("\n", "")
else:
git_hash = None
except Exception as err:
logger.debug(
f"Failed to retrieve a Git commit hash from path: {str(repo_path)}. Error: {err}"
)
return None
return git_hash
def get_git_commit_hash():
"""
Gets the git commit hash of your current repo (if it exists).
Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42
"""
try:
git_hash = subprocess.check_output(["git", "describe", "--always"]).strip()
git_hash = git_hash.decode()
except (subprocess.CalledProcessError, FileNotFoundError):
# FileNotFoundError occurs when git not installed on system
git_hash = get_commit_from_path(os.getcwd()) # git hash of repo if exists
return git_hash
def add_env_info(storage: Dict[str, Any]):
try:
pretty_env_info = get_pretty_env_info()
except Exception as err:
pretty_env_info = str(err)
transformers_version = trans_version
upper_dir_commit = get_commit_from_path(
Path(os.getcwd(), "..")
) # git hash of upper repo if exists
added_info = {
"pretty_env_info": pretty_env_info,
"transformers_version": transformers_version,
"upper_git_hash": upper_dir_commit, # in case this repo is submodule
}
storage.update(added_info)
import copy
import json
import logging
from typing import Any, Dict, List, Literal, Tuple
import re
import subprocess
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import numpy as np
import pandas as pd
from packaging.version import Version
from torch.utils.collect_env import get_pretty_env_info
from transformers import __version__ as trans_version
from lm_eval.logging.utils import _handle_non_serializable, remove_none_pattern
logger = logging.getLogger(__name__)
def remove_none_pattern(input_string: str) -> Tuple[str, bool]:
"""Remove the ',none' substring from the input_string if it exists at the end.
Args:
input_string (str): The input string from which to remove the ',none' substring.
Returns:
Tuple[str, bool]: A tuple containing the modified input_string with the ',none' substring removed
and a boolean indicating whether the modification was made (True) or not (False).
"""
# Define the pattern to match ',none' at the end of the string
pattern = re.compile(r",none$")
# Use sub() to replace ',none' with an empty string
result = re.sub(pattern, "", input_string)
# check if the input_string changed
logger = logging.getLogger(__name__)
removed = result != input_string
return result, removed
def _handle_non_serializable(o: Any) -> Union[int, str, list]:
"""Handle non-serializable objects by converting them to serializable types.
Args:
o (Any): The object to be handled.
Returns:
Union[int, str, list]: The converted object. If the object is of type np.int64 or np.int32,
it will be converted to int. If the object is of type set, it will be converted
to a list. Otherwise, it will be converted to str.
"""
if isinstance(o, np.int64) or isinstance(o, np.int32):
return int(o)
elif isinstance(o, set):
return list(o)
else:
return str(o)
def get_wandb_printer() -> Literal["Printer"]:
@@ -395,61 +350,3 @@ class WandbLogger:
self._log_samples_as_artifact(eval_preds, task_name)
self.run.log({f"{group}_eval_results": grouped_df})
def get_commit_from_path(repo_path: Union[Path, str]) -> Optional[str]:
try:
git_folder = Path(repo_path, ".git")
if git_folder.is_file():
git_folder = Path(
git_folder.parent,
git_folder.read_text(encoding="utf-8").split("\n")[0].split(" ")[-1],
)
if Path(git_folder, "HEAD").exists():
head_name = (
Path(git_folder, "HEAD")
.read_text(encoding="utf-8")
.split("\n")[0]
.split(" ")[-1]
)
head_ref = Path(git_folder, head_name)
git_hash = head_ref.read_text(encoding="utf-8").replace("\n", "")
else:
git_hash = None
except Exception as err:
logger.debug(
f"Failed to retrieve a Git commit hash from path: {str(repo_path)}. Error: {err}"
)
return None
return git_hash
def get_git_commit_hash():
"""
Gets the git commit hash of your current repo (if it exists).
Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42
"""
try:
git_hash = subprocess.check_output(["git", "describe", "--always"]).strip()
git_hash = git_hash.decode()
except (subprocess.CalledProcessError, FileNotFoundError):
# FileNotFoundError occurs when git not installed on system
git_hash = get_commit_from_path(os.getcwd()) # git hash of repo if exists
return git_hash
def add_env_info(storage: Dict[str, Any]):
try:
pretty_env_info = get_pretty_env_info()
except Exception as err:
pretty_env_info = str(err)
transformers_version = trans_version
upper_dir_commit = get_commit_from_path(
Path(os.getcwd(), "..")
) # git hash of upper repo if exists
added_info = {
"pretty_env_info": pretty_env_info,
"transformers_version": transformers_version,
"upper_git_hash": upper_dir_commit, # in case this repo is submodule
}
storage.update(added_info)
@@ -13,6 +13,7 @@ from accelerate import (
InitProcessGroupKwargs,
find_executable_batch_size,
)
from huggingface_hub import HfApi
from packaging import version
from peft import PeftModel
from peft import __version__ as PEFT_VERSION
@@ -278,7 +279,10 @@ class HFLM(TemplateLM):
)
self._max_length = max_length
self.pretrained = pretrained
self.delta = delta
self.peft = peft
self.revision = revision
self.batch_schedule = 1
self.batch_sizes = {}
self.max_batch_size = max_batch_size
@@ -1272,3 +1276,44 @@ class HFLM(TemplateLM):
pbar.close()
return res
def get_model_info(self) -> dict:
"""
Method to get Hugging Face model information for experiment reproducibility.
"""
def get_model_num_params(model) -> int:
if hasattr(model, "num_parameters"):
return model.num_parameters()
if hasattr(model, "parameters"):
return sum(p.numel() for p in model.parameters())
else:
return -1
def get_model_dtype(model) -> str:
if hasattr(model, "dtype"):
return model.dtype
else:
return ""
def get_model_sha(pretrained: str, revision: str) -> str:
try:
model_info = HfApi().model_info(repo_id=pretrained, revision=revision)
return model_info.sha
except Exception as e:
eval_logger.warn(
f"Failed to get model SHA for {pretrained} at revision {revision}. Error: {e}"
)
return ""
model_info = {
"model_num_parameters": get_model_num_params(self._model),
"model_dtype": get_model_dtype(self._model),
"model_revision": self.revision,
"model_sha": get_model_sha(self.pretrained, self.revision),
}
if self.peft:
model_info["peft_sha"] = get_model_sha(self.peft, self.revision)
if self.delta:
model_info["delta_sha"] = get_model_sha(self.delta, self.revision)
return model_info
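# Illustrative sketch (not part of the original file): get_model_info() adds
# reproducibility metadata that simple_evaluate merges into results["config"].
# The checkpoint name below is a placeholder; running this downloads the model.
def _example_get_model_info_usage() -> None:
    lm = HFLM(pretrained="EleutherAI/pythia-160m", revision="main")
    info = lm.get_model_info()
    # expected keys: model_num_parameters, model_dtype, model_revision, model_sha
    # (plus peft_sha / delta_sha when a PEFT adapter or delta weights are used)
    print(info)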
import collections
import fnmatch
import functools
import hashlib
import importlib.util
import inspect
import json
import logging
import os
import re
from dataclasses import asdict, is_dataclass
from itertools import islice
from typing import Any, Callable, List
@@ -24,6 +27,10 @@ eval_logger = logging.getLogger("lm-eval")
SPACING = " " * 47
def hash_string(string: str) -> str:
return hashlib.sha256(string.encode("utf-8")).hexdigest()
def escaped_split(text, sep_char, maxsplit=-1):
"""Split text into a list on occurrences of the given separation
character `sep_char`. The separation character may be escaped by a
@@ -60,6 +67,15 @@ def handle_arg_string(arg):
return arg
def handle_non_serializable(o):
if isinstance(o, np.int64) or isinstance(o, np.int32):
return int(o)
elif isinstance(o, set):
return list(o)
else:
return str(o)
def simple_parse_args_string(args_string):
"""
Parses something like
@@ -166,6 +182,18 @@ def make_disjoint_window(pair):
return a[: len(a) - (len(b) - 1)], b
class EnhancedJSONEncoder(json.JSONEncoder):
"""
Provides a proper json encoding for the loggers and trackers json dumps.
Notably manages the json encoding of dataclasses.
"""
def default(self, o):
if is_dataclass(o):
return asdict(o)
return super().default(o)
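# Illustrative usage sketch (not part of the original file): EnhancedJSONEncoder
# lets dataclass instances (e.g. tracker state) be dumped directly to JSON.
def _example_enhanced_json_encoder_usage() -> str:
    from dataclasses import dataclass

    @dataclass
    class _ExamplePoint:  # hypothetical dataclass used only for this example
        x: int
        y: int

    return json.dumps(_ExamplePoint(x=1, y=2), cls=EnhancedJSONEncoder)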
class Reorderer:
def __init__(self, arr: List[Any], fn: Callable) -> None:
"""Reorder an array according to some function
...