Commit dbe4c391 authored by Baber
Browse files

improve logging

parent 442ce51a
import logging
import os
__version__ = "0.4.9" __version__ = "0.4.9"
......
...@@ -13,7 +13,7 @@ from lm_eval.utils import simple_parse_args_string ...@@ -13,7 +13,7 @@ from lm_eval.utils import simple_parse_args_string
if TYPE_CHECKING: if TYPE_CHECKING:
from lm_eval.tasks import TaskManager from lm_eval.tasks import TaskManager
eval_logger = logging.getLogger(__name__)
DICT_KEYS = [ DICT_KEYS = [
"wandb_args", "wandb_args",
"wandb_config_args", "wandb_config_args",
...@@ -273,7 +273,7 @@ class EvaluatorConfig: ...@@ -273,7 +273,7 @@ class EvaluatorConfig:
def _validate_arguments(self) -> None: def _validate_arguments(self) -> None:
"""Validate configuration arguments and cross-field constraints.""" """Validate configuration arguments and cross-field constraints."""
if self.limit: if self.limit:
logging.warning( eval_logger.warning(
"--limit SHOULD ONLY BE USED FOR TESTING. " "--limit SHOULD ONLY BE USED FOR TESTING. "
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT." "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
) )
...@@ -368,9 +368,6 @@ class EvaluatorConfig: ...@@ -368,9 +368,6 @@ class EvaluatorConfig:
def _apply_trust_remote_code(self) -> None: def _apply_trust_remote_code(self) -> None:
"""Apply trust_remote_code setting if enabled.""" """Apply trust_remote_code setting if enabled."""
if self.trust_remote_code: if self.trust_remote_code:
eval_logger = logging.getLogger(__name__)
eval_logger.info("Setting HF_DATASETS_TRUST_REMOTE_CODE=true")
# HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally, # HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
# because it's already been determined based on the prior env var before launching our # because it's already been determined based on the prior env var before launching our
# script--`datasets` gets imported by lm_eval internally before these lines can update the env. # script--`datasets` gets imported by lm_eval internally before these lines can update the env.
......
...@@ -28,10 +28,10 @@ from lm_eval.loggers import EvaluationTracker ...@@ -28,10 +28,10 @@ from lm_eval.loggers import EvaluationTracker
from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
from lm_eval.tasks import TaskManager, get_task_dict from lm_eval.tasks import TaskManager, get_task_dict
from lm_eval.utils import ( from lm_eval.utils import (
get_logger,
handle_non_serializable, handle_non_serializable,
hash_string, hash_string,
positional_deprecated, positional_deprecated,
setup_logging,
simple_parse_args_string, simple_parse_args_string,
) )
...@@ -145,7 +145,7 @@ def simple_evaluate( ...@@ -145,7 +145,7 @@ def simple_evaluate(
Dictionary of results Dictionary of results
""" """
if verbosity is not None: if verbosity is not None:
setup_logging(verbosity=verbosity) get_logger(verbosity)
start_date = time.time() start_date = time.time()
if limit is not None and samples is not None: if limit is not None and samples is not None:
...@@ -355,8 +355,6 @@ def simple_evaluate( ...@@ -355,8 +355,6 @@ def simple_evaluate(
verbosity=verbosity, verbosity=verbosity,
confirm_run_unsafe_code=confirm_run_unsafe_code, confirm_run_unsafe_code=confirm_run_unsafe_code,
) )
if verbosity is not None:
setup_logging(verbosity=verbosity)
if lm.rank == 0: if lm.rank == 0:
if isinstance(model, str): if isinstance(model, str):
......
...@@ -30,7 +30,7 @@ class TaskManager: ...@@ -30,7 +30,7 @@ class TaskManager:
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
) -> None: ) -> None:
if verbosity is not None: if verbosity is not None:
utils.setup_logging(verbosity) utils.get_logger(verbosity)
self.include_path = include_path self.include_path = include_path
self.metadata = metadata self.metadata = metadata
self._task_index = self.initialize_tasks( self._task_index = self.initialize_tasks(
......
...@@ -26,8 +26,75 @@ HIGHER_IS_BETTER_SYMBOLS = { ...@@ -26,8 +26,75 @@ HIGHER_IS_BETTER_SYMBOLS = {
} }
def get_logger(level: Optional[str] = None) -> logging.Logger:
    """
    Get a logger with a stream handler that captures all lm_eval logs.

    The "lm_eval" logger is shared by the whole package, so repeated calls
    return the same object and never stack duplicate handlers.

    Args:
        level (Optional[str]): The logging level name, case-insensitive
            (e.g. "INFO", "debug"). A numeric level such as
            ``logging.DEBUG`` is tolerated as well. When ``None`` the
            current level is left untouched.

    Example:
        >>> logger = get_logger("INFO")
        >>> logger.info("Log this")  # the handler added here writes "Log this" to stderr

    Returns:
        logging.Logger: The logger.

    Raises:
        ValueError: If ``level`` is a string that is not a known level name.
        TypeError: If ``level`` is neither a string nor an integer.
    """
    logger = logging.getLogger("lm_eval")

    # hasHandlers() also looks at ancestor loggers, so an application that
    # already configured the root logger keeps full control of the output.
    if not logger.hasHandlers():
        logger.addHandler(logging.StreamHandler())
        logger.setLevel(logging.INFO)

    if level is not None:
        if isinstance(level, str):
            # Normalise so " debug " and "DEBUG" are treated alike.
            level = level.strip().upper()
        # Logger.setLevel validates the value itself: it accepts ints and
        # registered level names, and raises ValueError for unknown names
        # instead of the opaque AttributeError a getattr() lookup produced.
        logger.setLevel(level)

    return logger
def setup_logging(verbosity=logging.INFO, suppress_third_party=True):
"""
Configure logging for the lm_eval CLI application.
WARNING: This function is intended for CLI use only. Library users should
use get_logger() instead to avoid interfering with their application's
logging configuration.
Args:
verbosity: Log level (int) or string name. Can be overridden by LOGLEVEL env var.
suppress_third_party: Whether to suppress verbose third-party library logs.
Returns:
logging.Logger: The configured lm_eval logger instance.
"""
# Validate verbosity parameter
if isinstance(verbosity, str):
level_map = {
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}
verbosity = level_map.get(verbosity.upper(), logging.INFO)
elif not isinstance(verbosity, int):
verbosity = logging.INFO
# Get log level from environment or use default
if log_level_env := os.environ.get("LOGLEVEL", None):
level_map = {
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}
log_level = level_map.get(log_level_env.upper(), verbosity)
else:
log_level = verbosity
# Get the lm_eval logger directly
logger = logging.getLogger("lm_eval")
# Configure custom formatter
class CustomFormatter(logging.Formatter): class CustomFormatter(logging.Formatter):
def format(self, record): def format(self, record):
if record.name.startswith("lm_eval."): if record.name.startswith("lm_eval."):
...@@ -39,32 +106,27 @@ def setup_logging(verbosity=logging.INFO): ...@@ -39,32 +106,27 @@ def setup_logging(verbosity=logging.INFO):
datefmt="%Y-%m-%d:%H:%M:%S", datefmt="%Y-%m-%d:%H:%M:%S",
) )
log_level = os.environ.get("LOGLEVEL", verbosity) or verbosity # Check if handler already exists to prevent duplicates
has_stream_handler = any(
level_map = { isinstance(h, logging.StreamHandler) for h in logger.handlers
"DEBUG": logging.DEBUG, )
"INFO": logging.INFO, if not has_stream_handler:
"WARNING": logging.WARNING,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}
log_level = level_map.get(str(log_level).upper(), logging.INFO)
if not logging.root.handlers:
handler = logging.StreamHandler() handler = logging.StreamHandler()
handler.setFormatter(formatter) handler.setFormatter(formatter)
logger.addHandler(handler)
# For CLI use, we disable propagation to avoid duplicate messages
logger.propagate = False
root_logger = logging.getLogger() # Set the logger level
root_logger.addHandler(handler) logger.setLevel(log_level)
root_logger.setLevel(log_level)
if log_level == logging.DEBUG: # Optionally suppress verbose third-party library logs
third_party_loggers = ["urllib3", "filelock", "fsspec"] if suppress_third_party and log_level == logging.DEBUG:
for logger_name in third_party_loggers: third_party_loggers = ["urllib3", "filelock", "fsspec"]
logging.getLogger(logger_name).setLevel(logging.INFO) for logger_name in third_party_loggers:
else: logging.getLogger(logger_name).setLevel(logging.INFO)
logging.getLogger().setLevel(log_level)
return logger
def hash_string(string: str) -> str: def hash_string(string: str) -> str:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment