Unverified Commit 60e72d5f authored by RustingSword's avatar RustingSword Committed by GitHub
Browse files

[python] allow to register any custom logger (fixes #4783) (#4880)



* [python] allow to register any custom logger

* allow customizable logging method name; add unit test

* [python] allow to register any custom logger

* allow customizable logging method name; add unit test

* update tests

* fix lint error

* remove unused method

* fix docstring style
Co-authored-by: default avatargongxudong <gongxudong@kuaishou.com>
parent d163c2c1
......@@ -41,21 +41,37 @@ class _DummyLogger:
warnings.warn(msg, stacklevel=3)
_LOGGER: Union[_DummyLogger, Logger] = _DummyLogger()
_LOGGER: Any = _DummyLogger()
_INFO_METHOD_NAME = "info"
_WARNING_METHOD_NAME = "warning"
def register_logger(logger: Logger) -> None:
def register_logger(
logger: Any, info_method_name: str = "info", warning_method_name: str = "warning"
) -> None:
"""Register custom logger.
Parameters
----------
logger : logging.Logger
logger : Any
Custom logger.
info_method_name : str, optional (default="info")
Method used to log info messages.
warning_method_name : str, optional (default="warning")
Method used to log warning messages.
"""
if not isinstance(logger, Logger):
raise TypeError("Logger should inherit logging.Logger class")
global _LOGGER
def _has_method(logger: Any, method_name: str) -> bool:
return callable(getattr(logger, method_name, None))
if not _has_method(logger, info_method_name) or not _has_method(logger, warning_method_name):
raise TypeError(
f"Logger must provide '{info_method_name}' and '{warning_method_name}' method"
)
global _LOGGER, _INFO_METHOD_NAME, _WARNING_METHOD_NAME
_LOGGER = logger
_INFO_METHOD_NAME = info_method_name
_WARNING_METHOD_NAME = warning_method_name
def _normalize_native_string(func: Callable[[str], None]) -> Callable[[str], None]:
......@@ -76,16 +92,16 @@ def _normalize_native_string(func: Callable[[str], None]) -> Callable[[str], Non
def _log_info(msg: str) -> None:
_LOGGER.info(msg)
getattr(_LOGGER, _INFO_METHOD_NAME)(msg)
def _log_warning(msg: str) -> None:
_LOGGER.warning(msg)
getattr(_LOGGER, _WARNING_METHOD_NAME)(msg)
@_normalize_native_string
def _log_native(msg: str) -> None:
_LOGGER.info(msg)
getattr(_LOGGER, _INFO_METHOD_NAME)(msg)
def _log_callback(msg: bytes) -> None:
......
......@@ -2,6 +2,7 @@
import logging
import numpy as np
import pytest
import lightgbm as lgb
......@@ -97,3 +98,70 @@ WARNING | More than one metric available, picking one to plot.
actual_log_wo_gpu_stuff.append(line)
assert "\n".join(actual_log_wo_gpu_stuff) == expected_log
def test_register_invalid_logger():
class LoggerWithoutInfoMethod:
def warning(self, msg: str) -> None:
print(msg)
class LoggerWithoutWarningMethod:
def info(self, msg: str) -> None:
print(msg)
class LoggerWithAttributeNotCallable:
def __init__(self):
self.info = 1
self.warning = 2
expected_error_message = "Logger must provide 'info' and 'warning' method"
with pytest.raises(TypeError, match=expected_error_message):
lgb.register_logger(LoggerWithoutInfoMethod())
with pytest.raises(TypeError, match=expected_error_message):
lgb.register_logger(LoggerWithoutWarningMethod())
with pytest.raises(TypeError, match=expected_error_message):
lgb.register_logger(LoggerWithAttributeNotCallable())
def test_register_custom_logger():
logged_messages = []
class CustomLogger:
def custom_info(self, msg: str) -> None:
logged_messages.append(msg)
def custom_warning(self, msg: str) -> None:
logged_messages.append(msg)
custom_logger = CustomLogger()
lgb.register_logger(
custom_logger,
info_method_name="custom_info",
warning_method_name="custom_warning"
)
lgb.basic._log_info("info message")
lgb.basic._log_warning("warning message")
expected_log = ["info message", "warning message"]
assert logged_messages == expected_log
logged_messages = []
X = np.array([[1, 2, 3],
[1, 2, 4],
[1, 2, 4],
[1, 2, 3]],
dtype=np.float32)
y = np.array([0, 1, 1, 0])
lgb_data = lgb.Dataset(X, y)
lgb.train(
{'objective': 'binary', 'metric': 'auc'},
lgb_data,
num_boost_round=10,
valid_sets=[lgb_data],
categorical_feature=[1]
)
assert logged_messages, "custom logger was not called"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment