Unverified Commit 547d8dd8 authored by Sérgio Agostinho's avatar Sérgio Agostinho Committed by GitHub
Browse files

Don't touch nor send messages to the root logger. (#1380)



---------
Signed-off-by: default avatarSérgio Agostinho <sagostinho@nvidia.com>
parent f8eddcf9
...@@ -12,6 +12,8 @@ from importlib.metadata import version ...@@ -12,6 +12,8 @@ from importlib.metadata import version
from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import get_te_path, is_package_installed
from transformer_engine.common import _get_sys_extension from transformer_engine.common import _get_sys_extension
_logger = logging.getLogger(__name__)
def _load_library(): def _load_library():
"""Load shared library with Transformer Engine C extensions""" """Load shared library with Transformer Engine C extensions"""
...@@ -36,7 +38,7 @@ def _load_library(): ...@@ -36,7 +38,7 @@ def _load_library():
if is_package_installed("transformer-engine-cu12"): if is_package_installed("transformer-engine-cu12"):
if not is_package_installed(module_name): if not is_package_installed(module_name):
logging.info( _logger.info(
"Could not find package %s. Install transformer-engine using 'pip" "Could not find package %s. Install transformer-engine using 'pip"
" install transformer-engine[jax]==VERSION'", " install transformer-engine[jax]==VERSION'",
module_name, module_name,
......
...@@ -19,6 +19,8 @@ import torch ...@@ -19,6 +19,8 @@ import torch
from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import get_te_path, is_package_installed
from transformer_engine.common import _get_sys_extension from transformer_engine.common import _get_sys_extension
_logger = logging.getLogger(__name__)
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def torch_version() -> tuple[int, ...]: def torch_version() -> tuple[int, ...]:
...@@ -49,7 +51,7 @@ def _load_library(): ...@@ -49,7 +51,7 @@ def _load_library():
if is_package_installed("transformer-engine-cu12"): if is_package_installed("transformer-engine-cu12"):
if not is_package_installed(module_name): if not is_package_installed(module_name):
logging.info( _logger.info(
"Could not find package %s. Install transformer-engine using 'pip" "Could not find package %s. Install transformer-engine using 'pip"
" install transformer-engine[pytorch]==VERSION'", " install transformer-engine[pytorch]==VERSION'",
module_name, module_name,
......
...@@ -98,7 +98,7 @@ _log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2] ...@@ -98,7 +98,7 @@ _log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2]
_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s") _formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s")
_stream_handler = logging.StreamHandler() _stream_handler = logging.StreamHandler()
_stream_handler.setFormatter(_formatter) _stream_handler.setFormatter(_formatter)
fa_logger = logging.getLogger() fa_logger = logging.getLogger(__name__)
fa_logger.setLevel(_log_level) fa_logger.setLevel(_log_level)
if not fa_logger.hasHandlers(): if not fa_logger.hasHandlers():
fa_logger.addHandler(_stream_handler) fa_logger.addHandler(_stream_handler)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment