Unverified Commit 77abd1e7 authored by Lysandre Debut, committed by GitHub

Centralize logging (#6434)



* Logging

* Style

* hf_logging > utils.logging

* Address @thomwolf's comments

* Update test

* Update src/transformers/benchmark/benchmark_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Revert bad change
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 461ae868
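
Every hunk below applies the same pattern: drop the direct standard library `import logging` (or `from logging import getLogger`) and fetch loggers from the new centralized `transformers.utils.logging` module instead. A minimal sketch of the new pattern, for orientation (the external import path and the `logger.info` call are illustrative; inside the package the import is relative, e.g. `from .utils import logging`):

from transformers.utils import logging  # inside the package: from .utils import logging

# get_logger wraps the standard library's logging.getLogger and ensures the
# library-wide logging configuration is applied once, centrally.
logger = logging.get_logger(__name__)

logger.info("Routed through the centralized transformers logger.")
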
@@ -17,8 +17,6 @@ else:
absl.logging.set_stderrthreshold("info")
absl.logging._warn_preinit_stderr = False
-import logging
# Configurations
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
@@ -184,9 +182,10 @@ from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
from .trainer_utils import EvalPrediction, set_seed
from .training_args import TrainingArguments
from .training_args_tf import TFTrainingArguments
+from .utils import logging
-logger = logging.getLogger(__name__) # pylint: disable=invalid-name
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_sklearn_available():
......
-import logging
import math
import torch
import torch.nn.functional as F
+from .utils import logging
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
def swish(x):
......
@@ -18,13 +18,13 @@
"""
-import logging
import timeit
from typing import Callable, Optional
from ..configuration_utils import PretrainedConfig
from ..file_utils import is_py3nvml_available, is_torch_available
from ..modeling_auto import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING
+from ..utils import logging
from .benchmark_utils import (
Benchmark,
Memory,
@@ -45,7 +45,7 @@ if is_py3nvml_available():
import py3nvml.py3nvml as nvml
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
class PyTorchBenchmark(Benchmark):
......
@@ -14,11 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
from dataclasses import dataclass, field
from typing import Tuple
from ..file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
+from ..utils import logging
from .benchmark_args_utils import BenchmarkArguments
@@ -29,7 +29,7 @@ if is_torch_tpu_available():
import torch_xla.core.xla_model as xm
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
@dataclass
......
@@ -14,11 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
from dataclasses import dataclass, field
from typing import Tuple
from ..file_utils import cached_property, is_tf_available, tf_required
+from ..utils import logging
from .benchmark_args_utils import BenchmarkArguments
@@ -26,7 +26,7 @@ if is_tf_available():
import tensorflow as tf
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
@dataclass
......
@@ -16,13 +16,14 @@
import dataclasses
import json
-import logging
from dataclasses import dataclass, field
from time import time
from typing import List
+from ..utils import logging
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
def list_field(default=None, metadata=None):
......
@@ -18,7 +18,6 @@
"""
-import logging
import random
import timeit
from functools import wraps
@@ -27,6 +26,7 @@ from typing import Callable, Optional
from ..configuration_utils import PretrainedConfig
from ..file_utils import is_py3nvml_available, is_tf_available
from ..modeling_tf_auto import TF_MODEL_MAPPING, TF_MODEL_WITH_LM_HEAD_MAPPING
+from ..utils import logging
from .benchmark_utils import (
Benchmark,
Memory,
@@ -46,7 +46,7 @@ if is_tf_available():
if is_py3nvml_available():
import py3nvml.py3nvml as nvml
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
def run_with_tf_optimizations(do_eager_mode: bool, use_xla: bool):
......
@@ -7,7 +7,6 @@ Copyright by the AllenNLP authors.
import copy
import csv
import linecache
-import logging
import os
import platform
import sys
@@ -22,6 +21,7 @@ from transformers import AutoConfig, PretrainedConfig
from transformers import __version__ as version
from ..file_utils import is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available
+from ..utils import logging
from .benchmark_args_utils import BenchmarkArguments
@@ -43,7 +43,7 @@ else:
from signal import SIGKILL
-logger = logging.getLogger(__name__) # pylint: disable=invalid-name
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_is_memory_tracing_enabled = False
@@ -94,7 +94,7 @@ def separate_process_wrapper_fn(func: Callable[[], None], do_multi_processing: b
return result
if do_multi_processing:
-logging.info("fFunction {func} is executed in its own process...")
+logger.info(f"Function {func} is executed in its own process...")
return multi_process_func
else:
return func
......
from argparse import ArgumentParser, Namespace
-from logging import getLogger
from transformers.commands import BaseTransformersCLICommand
+from ..utils import logging
def convert_command_factory(args: Namespace):
"""
@@ -52,7 +53,7 @@ class ConvertCommand(BaseTransformersCLICommand):
finetuning_task_name: str,
*args
):
-self._logger = getLogger("transformers-cli/converting")
+self._logger = logging.get_logger("transformers-cli/converting")
self._logger.info("Loading model {}".format(model_type))
self._model_type = model_type
......
-import logging
from argparse import ArgumentParser
from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
+from ..utils import logging
-logger = logging.getLogger(__name__) # pylint: disable=invalid-name
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def try_infer_format_from_ext(path: str):
......
-import logging
from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional
@@ -6,6 +5,8 @@ from transformers import Pipeline
from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, pipeline
+from ..utils import logging
try:
from fastapi import Body, FastAPI, HTTPException
......
import os
from argparse import ArgumentParser, Namespace
-from logging import getLogger
from transformers import SingleSentenceClassificationProcessor as Processor
from transformers import TextClassificationPipeline, is_tf_available, is_torch_available
from transformers.commands import BaseTransformersCLICommand
+from ..utils import logging
if not is_tf_available() and not is_torch_available():
raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
@@ -76,7 +77,7 @@ class TrainCommand(BaseTransformersCLICommand):
train_parser.set_defaults(func=train_command_factory)
def __init__(self, args: Namespace):
-self.logger = getLogger("transformers-cli/training")
+self.logger = logging.get_logger("transformers-cli/training")
self.framework = "tf" if is_tf_available() else "torch"
......
@@ -15,7 +15,6 @@
""" Auto Config class. """
-import logging
from collections import OrderedDict
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
@@ -45,9 +44,6 @@ from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
-logger = logging.getLogger(__name__)
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
(key, value)
for pretrained_map in [
......
@@ -14,14 +14,12 @@
# limitations under the License.
""" BART configuration """
-import logging
from .configuration_utils import PretrainedConfig
from .file_utils import add_start_docstrings_to_callable
+from .utils import logging
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/bart-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-base/config.json",
......
@@ -15,13 +15,11 @@
# limitations under the License.
""" BERT model configuration """
-import logging
from .configuration_utils import PretrainedConfig
+from .utils import logging
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
......
@@ -15,13 +15,11 @@
# limitations under the License.
""" CamemBERT configuration """
-import logging
from .configuration_roberta import RobertaConfig
+from .utils import logging
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json",
......
@@ -14,13 +14,11 @@
# limitations under the License.
""" Salesforce CTRL configuration """
-import logging
from .configuration_utils import PretrainedConfig
+from .utils import logging
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-config.json"}
......
@@ -14,13 +14,11 @@
# limitations under the License.
""" DistilBERT model configuration """
-import logging
from .configuration_utils import PretrainedConfig
+from .utils import logging
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"distilbert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
......
@@ -14,13 +14,11 @@
# limitations under the License.
""" DPR model configuration """
-import logging
from .configuration_bert import BertConfig
+from .utils import logging
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
DPR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-ctx_encoder-single-nq-base/config.json",
......
@@ -15,13 +15,11 @@
# limitations under the License.
""" ELECTRA model configuration """
-import logging
from .configuration_utils import PretrainedConfig
+from .utils import logging
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google/electra-small-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-small-generator/config.json",
......
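
With every module on the shared `utils.logging` module, verbosity can be controlled in one place rather than by configuring each module-level logger. A minimal sketch of that payoff, under the assumption that `set_verbosity` and the re-exported level constants match the centralized module introduced here (later transformers releases also ship convenience wrappers such as `set_verbosity_info()`):

from transformers.utils import logging

# Raise the library-wide verbosity with a single call; every logger created
# via logging.get_logger(__name__) across the package is affected at once.
logging.set_verbosity(logging.INFO)

logger = logging.get_logger("transformers")
logger.info("Visible now that the library verbosity is INFO.")

# Drop back to the quieter default.
logging.set_verbosity(logging.WARNING)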