Unverified Commit 77abd1e7 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Centralize logging (#6434)



* Logging

* Style

* hf_logging > utils.logging

* Address @thomwolf's comments

* Update test

* Update src/transformers/benchmark/benchmark_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Revert bad change
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 461ae868
...@@ -17,8 +17,6 @@ else: ...@@ -17,8 +17,6 @@ else:
absl.logging.set_stderrthreshold("info") absl.logging.set_stderrthreshold("info")
absl.logging._warn_preinit_stderr = False absl.logging._warn_preinit_stderr = False
import logging
# Configurations # Configurations
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
...@@ -184,9 +182,10 @@ from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer ...@@ -184,9 +182,10 @@ from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
from .trainer_utils import EvalPrediction, set_seed from .trainer_utils import EvalPrediction, set_seed
from .training_args import TrainingArguments from .training_args import TrainingArguments
from .training_args_tf import TFTrainingArguments from .training_args_tf import TFTrainingArguments
from .utils import logging
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_sklearn_available(): if is_sklearn_available():
......
import logging
import math import math
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from .utils import logging
logger = logging.getLogger(__name__)
logger = logging.get_logger(__name__)
def swish(x): def swish(x):
......
...@@ -18,13 +18,13 @@ ...@@ -18,13 +18,13 @@
""" """
import logging
import timeit import timeit
from typing import Callable, Optional from typing import Callable, Optional
from ..configuration_utils import PretrainedConfig from ..configuration_utils import PretrainedConfig
from ..file_utils import is_py3nvml_available, is_torch_available from ..file_utils import is_py3nvml_available, is_torch_available
from ..modeling_auto import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING from ..modeling_auto import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING
from ..utils import logging
from .benchmark_utils import ( from .benchmark_utils import (
Benchmark, Benchmark,
Memory, Memory,
...@@ -45,7 +45,7 @@ if is_py3nvml_available(): ...@@ -45,7 +45,7 @@ if is_py3nvml_available():
import py3nvml.py3nvml as nvml import py3nvml.py3nvml as nvml
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
class PyTorchBenchmark(Benchmark): class PyTorchBenchmark(Benchmark):
......
...@@ -14,11 +14,11 @@ ...@@ -14,11 +14,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Tuple from typing import Tuple
from ..file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required from ..file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
from ..utils import logging
from .benchmark_args_utils import BenchmarkArguments from .benchmark_args_utils import BenchmarkArguments
...@@ -29,7 +29,7 @@ if is_torch_tpu_available(): ...@@ -29,7 +29,7 @@ if is_torch_tpu_available():
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
@dataclass @dataclass
......
...@@ -14,11 +14,11 @@ ...@@ -14,11 +14,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Tuple from typing import Tuple
from ..file_utils import cached_property, is_tf_available, tf_required from ..file_utils import cached_property, is_tf_available, tf_required
from ..utils import logging
from .benchmark_args_utils import BenchmarkArguments from .benchmark_args_utils import BenchmarkArguments
...@@ -26,7 +26,7 @@ if is_tf_available(): ...@@ -26,7 +26,7 @@ if is_tf_available():
import tensorflow as tf import tensorflow as tf
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
@dataclass @dataclass
......
...@@ -16,13 +16,14 @@ ...@@ -16,13 +16,14 @@
import dataclasses import dataclasses
import json import json
import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from time import time from time import time
from typing import List from typing import List
from ..utils import logging
logger = logging.getLogger(__name__)
logger = logging.get_logger(__name__)
def list_field(default=None, metadata=None): def list_field(default=None, metadata=None):
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
""" """
import logging
import random import random
import timeit import timeit
from functools import wraps from functools import wraps
...@@ -27,6 +26,7 @@ from typing import Callable, Optional ...@@ -27,6 +26,7 @@ from typing import Callable, Optional
from ..configuration_utils import PretrainedConfig from ..configuration_utils import PretrainedConfig
from ..file_utils import is_py3nvml_available, is_tf_available from ..file_utils import is_py3nvml_available, is_tf_available
from ..modeling_tf_auto import TF_MODEL_MAPPING, TF_MODEL_WITH_LM_HEAD_MAPPING from ..modeling_tf_auto import TF_MODEL_MAPPING, TF_MODEL_WITH_LM_HEAD_MAPPING
from ..utils import logging
from .benchmark_utils import ( from .benchmark_utils import (
Benchmark, Benchmark,
Memory, Memory,
...@@ -46,7 +46,7 @@ if is_tf_available(): ...@@ -46,7 +46,7 @@ if is_tf_available():
if is_py3nvml_available(): if is_py3nvml_available():
import py3nvml.py3nvml as nvml import py3nvml.py3nvml as nvml
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
def run_with_tf_optimizations(do_eager_mode: bool, use_xla: bool): def run_with_tf_optimizations(do_eager_mode: bool, use_xla: bool):
......
...@@ -7,7 +7,6 @@ Copyright by the AllenNLP authors. ...@@ -7,7 +7,6 @@ Copyright by the AllenNLP authors.
import copy import copy
import csv import csv
import linecache import linecache
import logging
import os import os
import platform import platform
import sys import sys
...@@ -22,6 +21,7 @@ from transformers import AutoConfig, PretrainedConfig ...@@ -22,6 +21,7 @@ from transformers import AutoConfig, PretrainedConfig
from transformers import __version__ as version from transformers import __version__ as version
from ..file_utils import is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available from ..file_utils import is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available
from ..utils import logging
from .benchmark_args_utils import BenchmarkArguments from .benchmark_args_utils import BenchmarkArguments
...@@ -43,7 +43,7 @@ else: ...@@ -43,7 +43,7 @@ else:
from signal import SIGKILL from signal import SIGKILL
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_is_memory_tracing_enabled = False _is_memory_tracing_enabled = False
...@@ -94,7 +94,7 @@ def separate_process_wrapper_fn(func: Callable[[], None], do_multi_processing: b ...@@ -94,7 +94,7 @@ def separate_process_wrapper_fn(func: Callable[[], None], do_multi_processing: b
return result return result
if do_multi_processing: if do_multi_processing:
logging.info("fFunction {func} is executed in its own process...") logger.info(f"Function {func} is executed in its own process...")
return multi_process_func return multi_process_func
else: else:
return func return func
......
from argparse import ArgumentParser, Namespace from argparse import ArgumentParser, Namespace
from logging import getLogger
from transformers.commands import BaseTransformersCLICommand from transformers.commands import BaseTransformersCLICommand
from ..utils import logging
def convert_command_factory(args: Namespace): def convert_command_factory(args: Namespace):
""" """
...@@ -52,7 +53,7 @@ class ConvertCommand(BaseTransformersCLICommand): ...@@ -52,7 +53,7 @@ class ConvertCommand(BaseTransformersCLICommand):
finetuning_task_name: str, finetuning_task_name: str,
*args *args
): ):
self._logger = getLogger("transformers-cli/converting") self._logger = logging.get_logger("transformers-cli/converting")
self._logger.info("Loading model {}".format(model_type)) self._logger.info("Loading model {}".format(model_type))
self._model_type = model_type self._model_type = model_type
......
import logging
from argparse import ArgumentParser from argparse import ArgumentParser
from transformers.commands import BaseTransformersCLICommand from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline from transformers.pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
from ..utils import logging
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def try_infer_format_from_ext(path: str): def try_infer_format_from_ext(path: str):
......
import logging
from argparse import ArgumentParser, Namespace from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional from typing import Any, List, Optional
...@@ -6,6 +5,8 @@ from transformers import Pipeline ...@@ -6,6 +5,8 @@ from transformers import Pipeline
from transformers.commands import BaseTransformersCLICommand from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, pipeline from transformers.pipelines import SUPPORTED_TASKS, pipeline
from ..utils import logging
try: try:
from fastapi import Body, FastAPI, HTTPException from fastapi import Body, FastAPI, HTTPException
......
import os import os
from argparse import ArgumentParser, Namespace from argparse import ArgumentParser, Namespace
from logging import getLogger
from transformers import SingleSentenceClassificationProcessor as Processor from transformers import SingleSentenceClassificationProcessor as Processor
from transformers import TextClassificationPipeline, is_tf_available, is_torch_available from transformers import TextClassificationPipeline, is_tf_available, is_torch_available
from transformers.commands import BaseTransformersCLICommand from transformers.commands import BaseTransformersCLICommand
from ..utils import logging
if not is_tf_available() and not is_torch_available(): if not is_tf_available() and not is_torch_available():
raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training") raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
...@@ -76,7 +77,7 @@ class TrainCommand(BaseTransformersCLICommand): ...@@ -76,7 +77,7 @@ class TrainCommand(BaseTransformersCLICommand):
train_parser.set_defaults(func=train_command_factory) train_parser.set_defaults(func=train_command_factory)
def __init__(self, args: Namespace): def __init__(self, args: Namespace):
self.logger = getLogger("transformers-cli/training") self.logger = logging.get_logger("transformers-cli/training")
self.framework = "tf" if is_tf_available() else "torch" self.framework = "tf" if is_tf_available() else "torch"
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
""" Auto Config class. """ """ Auto Config class. """
import logging
from collections import OrderedDict from collections import OrderedDict
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
...@@ -45,9 +44,6 @@ from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP ...@@ -45,9 +44,6 @@ from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
logger = logging.getLogger(__name__)
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict( ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
(key, value) (key, value)
for pretrained_map in [ for pretrained_map in [
......
...@@ -14,14 +14,12 @@ ...@@ -14,14 +14,12 @@
# limitations under the License. # limitations under the License.
""" BART configuration """ """ BART configuration """
import logging
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .file_utils import add_start_docstrings_to_callable from .file_utils import add_start_docstrings_to_callable
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
BART_PRETRAINED_CONFIG_ARCHIVE_MAP = { BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/bart-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-base/config.json", "facebook/bart-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-base/config.json",
......
...@@ -15,13 +15,11 @@ ...@@ -15,13 +15,11 @@
# limitations under the License. # limitations under the License.
""" BERT model configuration """ """ BERT model configuration """
import logging
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
......
...@@ -15,13 +15,11 @@ ...@@ -15,13 +15,11 @@
# limitations under the License. # limitations under the License.
""" CamemBERT configuration """ """ CamemBERT configuration """
import logging
from .configuration_roberta import RobertaConfig from .configuration_roberta import RobertaConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json", "camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json",
......
...@@ -14,13 +14,11 @@ ...@@ -14,13 +14,11 @@
# limitations under the License. # limitations under the License.
""" Salesforce CTRL configuration """ """ Salesforce CTRL configuration """
import logging
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-config.json"} CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-config.json"}
......
...@@ -14,13 +14,11 @@ ...@@ -14,13 +14,11 @@
# limitations under the License. # limitations under the License.
""" DistilBERT model configuration """ """ DistilBERT model configuration """
import logging
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"distilbert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json", "distilbert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
......
...@@ -14,13 +14,11 @@ ...@@ -14,13 +14,11 @@
# limitations under the License. # limitations under the License.
""" DPR model configuration """ """ DPR model configuration """
import logging
from .configuration_bert import BertConfig from .configuration_bert import BertConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
DPR_PRETRAINED_CONFIG_ARCHIVE_MAP = { DPR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-ctx_encoder-single-nq-base/config.json", "facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-ctx_encoder-single-nq-base/config.json",
......
...@@ -15,13 +15,11 @@ ...@@ -15,13 +15,11 @@
# limitations under the License. # limitations under the License.
""" ELECTRA model configuration """ """ ELECTRA model configuration """
import logging
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP = { ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google/electra-small-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-small-generator/config.json", "google/electra-small-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-small-generator/config.json",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment