Commit 649ca8fc authored by Baber
Browse files

fix logging

parent 61520ad6
import logging
import os
__version__ = "0.4.9.1"
......
from lm_eval._cli.eval import Eval
from lm_eval.utils import setup_logging
def cli_evaluate() -> None:
    """Entry point for the command-line interface (subcommands and legacy usage).

    Logging is configured with ``setup_logging()`` before anything else; the
    ``Eval`` parser is then built, the command-line arguments are parsed, and
    the parsed arguments are handed back to the parser for execution.
    """
    # Configure logging first so it is in place before the parser is created.
    setup_logging()

    cli = Eval()
    parsed_args = cli.parse_args()
    cli.execute(parsed_args)
......
......@@ -434,6 +434,7 @@ class Run(SubCommand):
evaluation_tracker.recreate_metadata_card()
# Print results
cfg.model_args.pop("trust_remote_code", None)
print(
f"{cfg.model} ({cfg.model_args}), gen_kwargs: ({cfg.gen_kwargs}), "
f"limit: {cfg.limit}, num_fewshot: {cfg.num_fewshot}, "
......
import json
import logging
import warnings
from argparse import Namespace
from dataclasses import asdict, dataclass, field
from pathlib import Path
......@@ -14,7 +13,7 @@ from lm_eval.utils import simple_parse_args_string
if TYPE_CHECKING:
from lm_eval.tasks import TaskManager
eval_logger = logging.getLogger(__name__)
DICT_KEYS = [
"wandb_args",
"wandb_config_args",
......@@ -274,7 +273,7 @@ class EvaluatorConfig:
def _validate_arguments(self) -> None:
"""Validate configuration arguments and cross-field constraints."""
if self.limit:
warnings.warn(
eval_logger.warning(
"--limit SHOULD ONLY BE USED FOR TESTING. "
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
......@@ -369,9 +368,6 @@ class EvaluatorConfig:
def _apply_trust_remote_code(self) -> None:
"""Apply trust_remote_code setting if enabled."""
if self.trust_remote_code:
eval_logger = logging.getLogger(__name__)
eval_logger.info("Setting HF_DATASETS_TRUST_REMOTE_CODE=true")
# HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
# because it's already been determined based on the prior env var before launching our
# script--`datasets` gets imported by lm_eval internally before these lines can update the env.
......
......@@ -31,11 +31,11 @@ from lm_eval.loggers import EvaluationTracker
from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
from lm_eval.tasks import TaskManager, get_task_dict
from lm_eval.utils import (
get_logger,
handle_non_serializable,
hash_dict_images,
hash_string,
positional_deprecated,
setup_logging,
simple_parse_args_string,
wrap_text,
)
......@@ -149,7 +149,7 @@ def simple_evaluate(
Dictionary of results
"""
if verbosity is not None:
setup_logging(verbosity=verbosity)
get_logger(verbosity)
start_date = time.time()
if limit is not None and samples is not None:
......@@ -372,8 +372,6 @@ def simple_evaluate(
verbosity=verbosity,
confirm_run_unsafe_code=confirm_run_unsafe_code,
)
if verbosity is not None:
setup_logging(verbosity=verbosity)
if lm.rank == 0:
if isinstance(model, str):
......
......@@ -30,7 +30,7 @@ class TaskManager:
metadata: Optional[dict] = None,
) -> None:
if verbosity is not None:
utils.setup_logging(verbosity)
utils.get_logger(verbosity)
self.include_path = include_path
self.metadata = metadata
self._task_index = self.initialize_tasks(
......
......@@ -2,6 +2,7 @@ from __future__ import annotations
import collections
import fnmatch
import functools
import hashlib
import importlib.util
import inspect
......@@ -14,7 +15,7 @@ from dataclasses import asdict, is_dataclass
from functools import lru_cache, partial, wraps
from itertools import islice
from pathlib import Path
from typing import Any, Callable
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
import numpy as np
import yaml
......@@ -27,8 +28,6 @@ HIGHER_IS_BETTER_SYMBOLS = {
True: "↑",
False: "↓",
}
def wrap_text(string: str, width: int = 140, **kwargs) -> Optional[str]:
"""
Wraps the given string to the specified width.
......@@ -46,8 +45,76 @@ def wrap_text(string: str, width: int = 140, **kwargs) -> Optional[str]:
)
def setup_logging(verbosity=logging.INFO):
# Configure the root logger
def get_logger(level: "Optional[str | int]" = None) -> logging.Logger:
    """
    Return the package-level ``"lm_eval"`` logger, optionally setting its level.

    If neither the ``"lm_eval"`` logger nor any of its ancestors has a handler
    yet, a plain ``logging.StreamHandler`` is attached and the level defaults
    to ``INFO``. When handlers already exist, none are added and the level is
    left alone unless ``level`` is given.

    Args:
        level: Level name such as ``"INFO"`` or ``"debug"`` (case-insensitive),
            or a numeric level such as ``logging.DEBUG``. ``None`` leaves the
            current level unchanged.

    Example:
        >>> logger = get_logger("INFO")
        >>> logger.info("Log this")  # emitted through the attached handler(s)

    Returns:
        logging.Logger: The ``"lm_eval"`` logger.

    Raises:
        ValueError: If ``level`` is a string that does not name a logging level.
    """
    logger = logging.getLogger("lm_eval")

    # hasHandlers() also inspects ancestor loggers, so nothing is installed
    # when the host application has already configured logging.
    if not logger.hasHandlers():
        logger.addHandler(logging.StreamHandler())
        # Default level only when we install the handler ourselves.
        logger.setLevel(logging.INFO)

    if level is not None:
        if isinstance(level, str):
            # Resolve the name against the logging module, but accept only
            # real numeric levels (rejects e.g. "BASIC_FORMAT" or typos).
            resolved = getattr(logging, level.upper(), None)
            if not isinstance(resolved, int):
                raise ValueError(f"Unknown logging level: {level!r}")
        else:
            # Numeric levels (e.g. logging.DEBUG) are passed through as-is.
            resolved = level
        logger.setLevel(resolved)

    return logger
def setup_logging(verbosity=logging.INFO, suppress_third_party=True):
"""
Configure logging for the lm_eval CLI application.
WARNING: This function is intended for CLI use only. Library users should
use get_logger() instead to avoid interfering with their application's
logging configuration.
Args:
verbosity: Log level (int) or string name. Can be overridden by LOGLEVEL env var.
suppress_third_party: Whether to suppress verbose third-party library logs.
Returns:
logging.Logger: The configured lm_eval logger instance.
"""
# Validate verbosity parameter
if isinstance(verbosity, str):
level_map = {
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}
verbosity = level_map.get(verbosity.upper(), logging.INFO)
elif not isinstance(verbosity, int):
verbosity = logging.INFO
# Get log level from environment or use default
if log_level_env := os.environ.get("LOGLEVEL", None):
level_map = {
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}
log_level = level_map.get(log_level_env.upper(), verbosity)
else:
log_level = verbosity
# Get the lm_eval logger directly
logger = logging.getLogger("lm_eval")
# Configure custom formatter
class CustomFormatter(logging.Formatter):
def format(self, record):
if record.name.startswith("lm_eval."):
......@@ -59,32 +126,27 @@ def setup_logging(verbosity=logging.INFO):
datefmt="%Y-%m-%d:%H:%M:%S",
)
log_level = os.environ.get("LOGLEVEL", verbosity) or verbosity
level_map = {
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}
log_level = level_map.get(str(log_level).upper(), logging.INFO)
if not logging.root.handlers:
# Check if handler already exists to prevent duplicates
has_stream_handler = any(
isinstance(h, logging.StreamHandler) for h in logger.handlers
)
if not has_stream_handler:
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
# For CLI use, we disable propagation to avoid duplicate messages
logger.propagate = False
root_logger = logging.getLogger()
root_logger.addHandler(handler)
root_logger.setLevel(log_level)
# Set the logger level
logger.setLevel(log_level)
if log_level == logging.DEBUG:
third_party_loggers = ["urllib3", "filelock", "fsspec"]
for logger_name in third_party_loggers:
logging.getLogger(logger_name).setLevel(logging.INFO)
else:
logging.getLogger().setLevel(log_level)
# Optionally suppress verbose third-party library logs
if suppress_third_party and log_level == logging.DEBUG:
third_party_loggers = ["urllib3", "filelock", "fsspec"]
for logger_name in third_party_loggers:
logging.getLogger(logger_name).setLevel(logging.INFO)
return logger
def hash_string(string: str) -> str:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment