Unverified Commit c65a2e33 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python] add type hints to logging functions in basic.py (#4527)

* [python] add type hints to logging functions in basic.py

* add hints on wrapper
parent 73f7d5d6
...@@ -12,7 +12,7 @@ from os import SEEK_END ...@@ -12,7 +12,7 @@ from os import SEEK_END
from os.path import getsize from os.path import getsize
from pathlib import Path from pathlib import Path
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
import numpy as np import numpy as np
import scipy.sparse import scipy.sparse
...@@ -34,17 +34,17 @@ def _get_sample_count(total_nrow: int, params: str): ...@@ -34,17 +34,17 @@ def _get_sample_count(total_nrow: int, params: str):
class _DummyLogger: class _DummyLogger:
def info(self, msg): def info(self, msg: str) -> None:
print(msg) print(msg)
def warning(self, msg): def warning(self, msg: str) -> None:
warnings.warn(msg, stacklevel=3) warnings.warn(msg, stacklevel=3)
_LOGGER = _DummyLogger() _LOGGER: Union[_DummyLogger, Logger] = _DummyLogger()
def register_logger(logger): def register_logger(logger: Logger) -> None:
"""Register custom logger. """Register custom logger.
Parameters Parameters
...@@ -58,12 +58,12 @@ def register_logger(logger): ...@@ -58,12 +58,12 @@ def register_logger(logger):
_LOGGER = logger _LOGGER = logger
def _normalize_native_string(func): def _normalize_native_string(func: Callable[[str], None]) -> Callable[[str], None]:
"""Join log messages from native library which come by chunks.""" """Join log messages from native library which come by chunks."""
msg_normalized = [] msg_normalized: List[str] = []
@wraps(func) @wraps(func)
def wrapper(msg): def wrapper(msg: str) -> None:
nonlocal msg_normalized nonlocal msg_normalized
if msg.strip() == '': if msg.strip() == '':
msg = ''.join(msg_normalized) msg = ''.join(msg_normalized)
...@@ -75,20 +75,20 @@ def _normalize_native_string(func): ...@@ -75,20 +75,20 @@ def _normalize_native_string(func):
return wrapper return wrapper
def _log_info(msg): def _log_info(msg: str) -> None:
_LOGGER.info(msg) _LOGGER.info(msg)
def _log_warning(msg): def _log_warning(msg: str) -> None:
_LOGGER.warning(msg) _LOGGER.warning(msg)
@_normalize_native_string @_normalize_native_string
def _log_native(msg): def _log_native(msg: str) -> None:
_LOGGER.info(msg) _LOGGER.info(msg)
def _log_callback(msg): def _log_callback(msg: bytes) -> None:
"""Redirect logs from native library into Python.""" """Redirect logs from native library into Python."""
_log_native(str(msg.decode('utf-8'))) _log_native(str(msg.decode('utf-8')))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment