Unverified commit 49c61a4a authored by Philipp Schmid, committed by GitHub
Browse files

Extend trainer logging for sm (#10633)

* renamed logging to hf_logging

* changed logging from hf_logging to logging and logging to native_logging

* removed everything trying to fix import Trainer error

* adding imports again

* added custom add_handler function to logging.py

* make style

* added remove_handler

* added another conditional to assert
parent 1aa9c13f
...@@ -23,8 +23,10 @@ import math ...@@ -23,8 +23,10 @@ import math
import os import os
import re import re
import shutil import shutil
import sys
import time import time
import warnings import warnings
from logging import StreamHandler
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
...@@ -59,6 +61,7 @@ from .file_utils import ( ...@@ -59,6 +61,7 @@ from .file_utils import (
is_in_notebook, is_in_notebook,
is_sagemaker_distributed_available, is_sagemaker_distributed_available,
is_torch_tpu_available, is_torch_tpu_available,
is_training_run_on_sagemaker,
) )
from .modeling_utils import PreTrainedModel, unwrap_model from .modeling_utils import PreTrainedModel, unwrap_model
from .optimization import Adafactor, AdamW, get_scheduler from .optimization import Adafactor, AdamW, get_scheduler
...@@ -149,6 +152,10 @@ if is_sagemaker_distributed_available(): ...@@ -149,6 +152,10 @@ if is_sagemaker_distributed_available():
else: else:
import torch.distributed as dist import torch.distributed as dist
if is_training_run_on_sagemaker():
logging.add_handler(StreamHandler(sys.stdout))
if TYPE_CHECKING: if TYPE_CHECKING:
import optuna import optuna
......
...@@ -195,6 +195,24 @@ def enable_default_handler() -> None: ...@@ -195,6 +195,24 @@ def enable_default_handler() -> None:
_get_library_root_logger().addHandler(_default_handler) _get_library_root_logger().addHandler(_default_handler)
def add_handler(handler: logging.Handler) -> None:
    """Adds a handler to the HuggingFace Transformers's root logger."""
    _configure_library_root_logger()

    # A None handler is a caller error; fail loudly rather than silently no-op.
    assert handler is not None
    root_logger = _get_library_root_logger()
    root_logger.addHandler(handler)
def remove_handler(handler: logging.Handler) -> None:
    """Removes the given handler from the HuggingFace Transformers's root logger.

    Raises:
        AssertionError: if ``handler`` is ``None`` or is not currently attached
            to the library root logger.
    """
    _configure_library_root_logger()

    # Bug fix: the original asserted `handler not in ... .handlers`, which is
    # inverted — it rejected exactly the handlers that CAN be removed and made
    # the subsequent removeHandler() call a guaranteed no-op. The handler must
    # already be registered for removal to be meaningful.
    assert handler is not None and handler in _get_library_root_logger().handlers
    _get_library_root_logger().removeHandler(handler)
def disable_propagation() -> None: def disable_propagation() -> None:
""" """
Disable propagation of the library log outputs. Note that log propagation is disabled by default. Disable propagation of the library log outputs. Note that log propagation is disabled by default.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment