Unverified Commit 4975002d authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Reorganize file utils (#16264)

* Split file_utils in several submodules

* Fixes

* Add back more objects

* More fixes

* Who exactly decided to import that from there?

* Second suggestion to code with code review

* Revert wront move

* Fix imports

* Adapt all imports

* Adapt all imports everywhere

* Revert this import, will fix in a separate commit
parent 71356034
......@@ -45,9 +45,8 @@ from transformers import (
Seq2SeqTrainingArguments,
set_seed,
)
from transformers.file_utils import is_offline_mode
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils import check_min_version, is_offline_mode
from transformers.utils.versions import require_version
......
......@@ -49,7 +49,7 @@ from transformers import (
get_scheduler,
set_seed,
)
from transformers.file_utils import get_full_repo_name, is_offline_mode
from transformers.utils import get_full_repo_name, is_offline_mode
from transformers.utils.versions import require_version
......
......@@ -25,8 +25,8 @@ from unittest.mock import patch
import torch
from transformers import ViTMAEForPreTraining, Wav2Vec2ForPreTraining
from transformers.file_utils import is_apex_available
from transformers.testing_utils import CaptureLogger, TestCasePlus, get_gpu_count, slow, torch_device
from transformers.utils import is_apex_available
SRC_DIRS = [
......
......@@ -40,7 +40,7 @@ from transformers import (
get_scheduler,
set_seed,
)
from transformers.file_utils import get_full_repo_name
from transformers.utils import get_full_repo_name
from transformers.utils.versions import require_version
......
......@@ -48,7 +48,7 @@ from transformers import (
get_scheduler,
set_seed,
)
from transformers.file_utils import get_full_repo_name
from transformers.utils import get_full_repo_name
from transformers.utils.versions import require_version
......
......@@ -50,7 +50,7 @@ from transformers import (
get_scheduler,
set_seed,
)
from transformers.file_utils import get_full_repo_name
from transformers.utils import get_full_repo_name
from transformers.utils.versions import require_version
......
......@@ -43,9 +43,8 @@ from transformers import (
create_optimizer,
set_seed,
)
from transformers.file_utils import PaddingStrategy
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import check_min_version
from transformers.utils import PaddingStrategy, check_min_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
......
......@@ -41,8 +41,7 @@ from transformers import (
TFTrainingArguments,
set_seed,
)
from transformers.file_utils import CONFIG_NAME, TF2_WEIGHTS_NAME
from transformers.utils import check_min_version
from transformers.utils import CONFIG_NAME, TF2_WEIGHTS_NAME, check_min_version
from utils_qa import postprocess_qa_predictions
......
......@@ -43,9 +43,8 @@ from transformers import (
create_optimizer,
set_seed,
)
from transformers.file_utils import is_offline_mode
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils import check_min_version, is_offline_mode
from transformers.utils.versions import require_version
......
......@@ -37,7 +37,7 @@ from transformers import (
TFTrainingArguments,
set_seed,
)
from transformers.file_utils import CONFIG_NAME, TF2_WEIGHTS_NAME
from transformers.utils import CONFIG_NAME, TF2_WEIGHTS_NAME
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1" # Reduce the amount of console output from TF
......
......@@ -28,7 +28,7 @@ from typing import TYPE_CHECKING
# Check the dependencies satisfy the minimal versions required.
from . import dependency_versions_check
from .file_utils import (
from .utils import (
_LazyModule,
is_flax_available,
is_scatter_available,
......@@ -39,8 +39,8 @@ from .file_utils import (
is_tokenizers_available,
is_torch_available,
is_vision_available,
logging,
)
from .utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......@@ -94,39 +94,7 @@ _import_structure = {
"dynamic_module_utils": [],
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
"feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
"file_utils": [
"CONFIG_NAME",
"MODEL_CARD_NAME",
"PYTORCH_PRETRAINED_BERT_CACHE",
"PYTORCH_TRANSFORMERS_CACHE",
"SPIECE_UNDERLINE",
"TF2_WEIGHTS_NAME",
"TF_WEIGHTS_NAME",
"TRANSFORMERS_CACHE",
"WEIGHTS_NAME",
"TensorType",
"add_end_docstrings",
"add_start_docstrings",
"cached_path",
"is_apex_available",
"is_datasets_available",
"is_faiss_available",
"is_flax_available",
"is_phonemizer_available",
"is_psutil_available",
"is_py3nvml_available",
"is_pyctcdecode_available",
"is_scipy_available",
"is_sentencepiece_available",
"is_sklearn_available",
"is_speech_available",
"is_tf_available",
"is_timm_available",
"is_tokenizers_available",
"is_torch_available",
"is_torch_tpu_available",
"is_vision_available",
],
"file_utils": [],
"hf_argparser": ["HfArgumentParser"],
"integrations": [
"is_comet_available",
......@@ -147,8 +115,8 @@ _import_structure = {
"load_tf2_model_in_pytorch_model",
"load_tf2_weights_in_pytorch_model",
],
# Models
"models": [],
# Models
"models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"],
"models.auto": [
"ALL_PRETRAINED_CONFIG_ARCHIVE_MAP",
......@@ -396,7 +364,40 @@ _import_structure = {
"training_args": ["TrainingArguments"],
"training_args_seq2seq": ["Seq2SeqTrainingArguments"],
"training_args_tf": ["TFTrainingArguments"],
"utils": ["logging"],
"utils": [
"CONFIG_NAME",
"MODEL_CARD_NAME",
"PYTORCH_PRETRAINED_BERT_CACHE",
"PYTORCH_TRANSFORMERS_CACHE",
"SPIECE_UNDERLINE",
"TF2_WEIGHTS_NAME",
"TF_WEIGHTS_NAME",
"TRANSFORMERS_CACHE",
"WEIGHTS_NAME",
"TensorType",
"add_end_docstrings",
"add_start_docstrings",
"cached_path",
"is_apex_available",
"is_datasets_available",
"is_faiss_available",
"is_flax_available",
"is_phonemizer_available",
"is_psutil_available",
"is_py3nvml_available",
"is_pyctcdecode_available",
"is_scipy_available",
"is_sentencepiece_available",
"is_sklearn_available",
"is_speech_available",
"is_tf_available",
"is_timm_available",
"is_tokenizers_available",
"is_torch_available",
"is_torch_tpu_available",
"is_vision_available",
"logging",
],
}
# sentencepiece-backed objects
......@@ -2432,41 +2433,6 @@ if TYPE_CHECKING:
# Feature Extractor
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
# Files and general utilities
from .file_utils import (
CONFIG_NAME,
MODEL_CARD_NAME,
PYTORCH_PRETRAINED_BERT_CACHE,
PYTORCH_TRANSFORMERS_CACHE,
SPIECE_UNDERLINE,
TF2_WEIGHTS_NAME,
TF_WEIGHTS_NAME,
TRANSFORMERS_CACHE,
WEIGHTS_NAME,
TensorType,
add_end_docstrings,
add_start_docstrings,
cached_path,
is_apex_available,
is_datasets_available,
is_faiss_available,
is_flax_available,
is_phonemizer_available,
is_psutil_available,
is_py3nvml_available,
is_pyctcdecode_available,
is_scipy_available,
is_sentencepiece_available,
is_sklearn_available,
is_speech_available,
is_tf_available,
is_timm_available,
is_tokenizers_available,
is_torch_available,
is_torch_tpu_available,
is_vision_available,
)
from .hf_argparser import HfArgumentParser
# Integrations
......@@ -2714,7 +2680,42 @@ if TYPE_CHECKING:
from .training_args import TrainingArguments
from .training_args_seq2seq import Seq2SeqTrainingArguments
from .training_args_tf import TFTrainingArguments
from .utils import logging
# Files and general utilities
from .utils import (
CONFIG_NAME,
MODEL_CARD_NAME,
PYTORCH_PRETRAINED_BERT_CACHE,
PYTORCH_TRANSFORMERS_CACHE,
SPIECE_UNDERLINE,
TF2_WEIGHTS_NAME,
TF_WEIGHTS_NAME,
TRANSFORMERS_CACHE,
WEIGHTS_NAME,
TensorType,
add_end_docstrings,
add_start_docstrings,
cached_path,
is_apex_available,
is_datasets_available,
is_faiss_available,
is_flax_available,
is_phonemizer_available,
is_psutil_available,
is_py3nvml_available,
is_pyctcdecode_available,
is_scipy_available,
is_sentencepiece_available,
is_sklearn_available,
is_speech_available,
is_tf_available,
is_timm_available,
is_tokenizers_available,
is_torch_available,
is_torch_tpu_available,
is_vision_available,
logging,
)
if is_sentencepiece_available():
from .models.albert import AlbertTokenizer
......
......@@ -22,9 +22,8 @@ import timeit
from typing import Callable, Optional
from ..configuration_utils import PretrainedConfig
from ..file_utils import is_py3nvml_available, is_torch_available
from ..models.auto.modeling_auto import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING
from ..utils import logging
from ..utils import is_py3nvml_available, is_torch_available, logging
from .benchmark_utils import (
Benchmark,
Memory,
......
......@@ -17,8 +17,7 @@
from dataclasses import dataclass, field
from typing import Tuple
from ..file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
from ..utils import logging
from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, torch_required
from .benchmark_args_utils import BenchmarkArguments
......
......@@ -17,8 +17,7 @@
from dataclasses import dataclass, field
from typing import Tuple
from ..file_utils import cached_property, is_tf_available, tf_required
from ..utils import logging
from ..utils import cached_property, is_tf_available, logging, tf_required
from .benchmark_args_utils import BenchmarkArguments
......
......@@ -24,9 +24,8 @@ from functools import wraps
from typing import Callable, Optional
from ..configuration_utils import PretrainedConfig
from ..file_utils import is_py3nvml_available, is_tf_available
from ..models.auto.modeling_tf_auto import TF_MODEL_MAPPING, TF_MODEL_WITH_LM_HEAD_MAPPING
from ..utils import logging
from ..utils import is_py3nvml_available, is_tf_available, logging
from .benchmark_utils import (
Benchmark,
Memory,
......
......@@ -33,8 +33,7 @@ from typing import Callable, Iterable, List, NamedTuple, Optional, Union
from .. import AutoConfig, PretrainedConfig
from .. import __version__ as version
from ..file_utils import is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available
from ..utils import logging
from ..utils import is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available, logging
from .benchmark_args_utils import BenchmarkArguments
......
......@@ -25,8 +25,7 @@ from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Union
import transformers.models.auto as auto_module
from transformers.models.auto.configuration_auto import model_type_to_module_name
from ..file_utils import is_flax_available, is_tf_available, is_torch_available
from ..utils import logging
from ..utils import is_flax_available, is_tf_available, is_torch_available, logging
from . import BaseTransformersCLICommand
......
......@@ -18,7 +18,7 @@ from argparse import ArgumentParser
import huggingface_hub
from .. import __version__ as version
from ..file_utils import is_flax_available, is_tf_available, is_torch_available
from ..utils import is_flax_available, is_tf_available, is_torch_available
from . import BaseTransformersCLICommand
......
......@@ -16,9 +16,8 @@ import os
from argparse import ArgumentParser, Namespace
from ..data import SingleSentenceClassificationProcessor as Processor
from ..file_utils import is_tf_available, is_torch_available
from ..pipelines import TextClassificationPipeline
from ..utils import logging
from ..utils import is_tf_available, is_torch_available, logging
from . import BaseTransformersCLICommand
......
......@@ -29,7 +29,7 @@ from requests import HTTPError
from . import __version__
from .dynamic_module_utils import custom_object_save
from .file_utils import (
from .utils import (
CONFIG_NAME,
EntryNotFoundError,
PushToHubMixin,
......@@ -41,8 +41,8 @@ from .file_utils import (
is_offline_mode,
is_remote_url,
is_torch_available,
logging,
)
from .utils import logging
logger = logging.get_logger(__name__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment