Unverified Commit 4975002d authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Reorganize file utils (#16264)

* Split file_utils in several submodules

* Fixes

* Add back more objects

* More fixes

* Who exactly decided to import that from there?

* Second suggestion to code with code review

* Revert wront move

* Fix imports

* Adapt all imports

* Adapt all imports everywhere

* Revert this import, will fix in a separate commit
parent 71356034
......@@ -20,9 +20,9 @@ from typing import Dict, List, Optional, Tuple
from packaging.version import Version, parse
from transformers.file_utils import ModelOutput, is_tf_available, is_torch_available
from transformers.pipelines import Pipeline, pipeline
from transformers.tokenization_utils import BatchEncoding
from transformers.utils import ModelOutput, is_tf_available, is_torch_available
# This is the minimal required version to
......
......@@ -95,8 +95,7 @@ from . import (
is_torch_available,
load_pytorch_checkpoint_in_tf2_model,
)
from .file_utils import hf_bucket_url
from .utils import logging
from .utils import hf_bucket_url, logging
if is_torch_available():
......
......@@ -24,7 +24,7 @@ from typing import Dict, List, Tuple
from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE, Unigram, WordPiece
from .file_utils import requires_backends
from .utils import requires_backends
class SentencePieceExtractor:
......
......@@ -17,9 +17,9 @@ import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from ..file_utils import PaddingStrategy
from ..models.bert import BertTokenizer, BertTokenizerFast
from ..tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
from ..utils import PaddingStrategy
InputDataClass = NewType("InputDataClass", Any)
......
......@@ -16,7 +16,7 @@
import warnings
from ...file_utils import is_sklearn_available, requires_backends
from ...utils import is_sklearn_available, requires_backends
if is_sklearn_available():
......
......@@ -21,9 +21,8 @@ from dataclasses import asdict
from enum import Enum
from typing import List, Optional, Union
from ...file_utils import is_tf_available
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
from ...utils import is_tf_available, logging
from .utils import DataProcessor, InputExample, InputFeatures
......
......@@ -20,10 +20,9 @@ from multiprocessing import Pool, cpu_count
import numpy as np
from tqdm import tqdm
from ...file_utils import is_tf_available, is_torch_available
from ...models.bert.tokenization_bert import whitespace_tokenize
from ...tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase, TruncationStrategy
from ...utils import logging
from ...utils import is_tf_available, is_torch_available, logging
from .utils import DataProcessor
......
......@@ -20,8 +20,7 @@ import json
from dataclasses import dataclass
from typing import List, Optional, Union
from ...file_utils import is_tf_available, is_torch_available
from ...utils import logging
from ...utils import is_tf_available, is_torch_available, logging
logger = logging.get_logger(__name__)
......
......@@ -17,8 +17,8 @@ import unittest
import timeout_decorator
from ..file_utils import cached_property, is_torch_available
from ..testing_utils import require_torch
from ..utils import cached_property, is_torch_available
if is_torch_available():
......
......@@ -14,8 +14,7 @@
import collections
from .file_utils import ExplicitEnum, is_torch_available
from .utils import logging
from .utils import ExplicitEnum, is_torch_available, logging
if is_torch_available():
......
......@@ -23,8 +23,7 @@ from copy import deepcopy
from functools import partialmethod
from .dependency_versions_check import dep_version_check
from .file_utils import is_torch_available
from .utils import logging
from .utils import is_torch_available, logging
if is_torch_available():
......
......@@ -33,7 +33,7 @@ for pkg in pkgs_to_check_at_runtime:
if pkg in deps:
if pkg == "tokenizers":
# must be loaded here, or else tqdm check may fail
from .file_utils import is_tokenizers_available
from .utils import is_tokenizers_available
if not is_tokenizers_available():
continue # not required, check version only if installed
......
......@@ -24,8 +24,14 @@ from typing import Dict, Optional, Union
from huggingface_hub import HfFolder, model_info
from .file_utils import HF_MODULES_CACHE, TRANSFORMERS_DYNAMIC_MODULE_NAME, cached_path, hf_bucket_url, is_offline_mode
from .utils import logging
from .utils import (
HF_MODULES_CACHE,
TRANSFORMERS_DYNAMIC_MODULE_NAME,
cached_path,
hf_bucket_url,
is_offline_mode,
logging,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
......@@ -20,16 +20,8 @@ from typing import Dict, List, Optional, Union
import numpy as np
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from .file_utils import (
PaddingStrategy,
TensorType,
_is_tensorflow,
_is_torch,
is_tf_available,
is_torch_available,
to_numpy,
)
from .utils import logging
from .utils import PaddingStrategy, TensorType, is_tf_available, is_torch_available, logging, to_numpy
from .utils.generic import _is_tensorflow, _is_torch
logger = logging.get_logger(__name__)
......
......@@ -27,16 +27,13 @@ import numpy as np
from requests import HTTPError
from .dynamic_module_utils import custom_object_save
from .file_utils import (
from .utils import (
FEATURE_EXTRACTOR_NAME,
EntryNotFoundError,
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
TensorType,
_is_jax,
_is_numpy,
_is_torch_device,
cached_path,
copy_func,
hf_bucket_url,
......@@ -45,9 +42,10 @@ from .file_utils import (
is_remote_url,
is_tf_available,
is_torch_available,
logging,
torch_required,
)
from .utils import logging
from .utils.generic import _is_jax, _is_numpy, _is_torch_device
if TYPE_CHECKING:
......
This diff is collapsed.
......@@ -21,8 +21,8 @@ from typing import List, Optional, Tuple
import numpy as np
import torch
from .file_utils import add_start_docstrings
from .generation_beam_constraints import Constraint, ConstraintListState
from .utils import add_start_docstrings
PROCESS_INPUTS_DOCSTRING = r"""
......
......@@ -19,7 +19,7 @@ import jax
import jax.lax as lax
import jax.numpy as jnp
from .file_utils import add_start_docstrings
from .utils import add_start_docstrings
from .utils.logging import get_logger
......
......@@ -25,7 +25,6 @@ import jax
import jax.numpy as jnp
from jax import lax
from .file_utils import ModelOutput
from .generation_flax_logits_process import (
FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor,
......@@ -35,7 +34,7 @@ from .generation_flax_logits_process import (
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
)
from .utils import logging
from .utils import ModelOutput, logging
logger = logging.get_logger(__name__)
......
......@@ -20,7 +20,7 @@ from typing import Callable, Iterable, List, Optional, Tuple
import numpy as np
import torch
from .file_utils import add_start_docstrings
from .utils import add_start_docstrings
from .utils.logging import get_logger
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment