Unverified commit 4975002d, authored by Sylvain Gugger, committed by GitHub

Reorganize file utils (#16264)

* Split file_utils into several submodules

* Fixes

* Add back more objects

* More fixes

* Who exactly decided to import that from there?

* Second suggestion from code review

* Revert wrong move

* Fix imports

* Adapt all imports

* Adapt all imports everywhere

* Revert this import, will fix in a separate commit
parent 71356034
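
Every hunk below follows the same mechanical pattern: symbols that used to live in transformers.file_utils now come from transformers.utils (or, for private helpers, from one of its submodules), and separate "from .file_utils import ..." / "from .utils import logging" lines collapse into a single import. A minimal before/after sketch for downstream code, using symbols that actually appear in the hunks below:

    # Before this PR (old location):
    # from transformers.file_utils import ModelOutput, is_torch_available
    # After this PR (new location):
    from transformers.utils import ModelOutput, is_torch_available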
@@ -20,9 +20,9 @@ from typing import Dict, List, Optional, Tuple
 from packaging.version import Version, parse
 
-from transformers.file_utils import ModelOutput, is_tf_available, is_torch_available
 from transformers.pipelines import Pipeline, pipeline
 from transformers.tokenization_utils import BatchEncoding
+from transformers.utils import ModelOutput, is_tf_available, is_torch_available
 
 # This is the minimal required version to
...
@@ -95,8 +95,7 @@ from . import (
     is_torch_available,
     load_pytorch_checkpoint_in_tf2_model,
 )
-from .file_utils import hf_bucket_url
-from .utils import logging
+from .utils import hf_bucket_url, logging
 
 if is_torch_available():
...
@@ -24,7 +24,7 @@ from typing import Dict, List, Tuple
 from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import BPE, Unigram, WordPiece
 
-from .file_utils import requires_backends
+from .utils import requires_backends
 
 class SentencePieceExtractor:
...
@@ -17,9 +17,9 @@ import warnings
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
 
-from ..file_utils import PaddingStrategy
 from ..models.bert import BertTokenizer, BertTokenizerFast
 from ..tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
+from ..utils import PaddingStrategy
 
 InputDataClass = NewType("InputDataClass", Any)
...
@@ -16,7 +16,7 @@
 import warnings
 
-from ...file_utils import is_sklearn_available, requires_backends
+from ...utils import is_sklearn_available, requires_backends
 
 if is_sklearn_available():
...
@@ -21,9 +21,8 @@ from dataclasses import asdict
 from enum import Enum
 from typing import List, Optional, Union
 
-from ...file_utils import is_tf_available
 from ...tokenization_utils import PreTrainedTokenizer
-from ...utils import logging
+from ...utils import is_tf_available, logging
 from .utils import DataProcessor, InputExample, InputFeatures
...
@@ -20,10 +20,9 @@ from multiprocessing import Pool, cpu_count
 import numpy as np
 from tqdm import tqdm
 
-from ...file_utils import is_tf_available, is_torch_available
 from ...models.bert.tokenization_bert import whitespace_tokenize
 from ...tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase, TruncationStrategy
-from ...utils import logging
+from ...utils import is_tf_available, is_torch_available, logging
 from .utils import DataProcessor
...
@@ -20,8 +20,7 @@ import json
 from dataclasses import dataclass
 from typing import List, Optional, Union
 
-from ...file_utils import is_tf_available, is_torch_available
-from ...utils import logging
+from ...utils import is_tf_available, is_torch_available, logging
 
 logger = logging.get_logger(__name__)
...
@@ -17,8 +17,8 @@ import unittest
 import timeout_decorator
 
-from ..file_utils import cached_property, is_torch_available
 from ..testing_utils import require_torch
+from ..utils import cached_property, is_torch_available
 
 if is_torch_available():
...
@@ -14,8 +14,7 @@
 import collections
 
-from .file_utils import ExplicitEnum, is_torch_available
-from .utils import logging
+from .utils import ExplicitEnum, is_torch_available, logging
 
 if is_torch_available():
...
@@ -23,8 +23,7 @@ from copy import deepcopy
 from functools import partialmethod
 
 from .dependency_versions_check import dep_version_check
-from .file_utils import is_torch_available
-from .utils import logging
+from .utils import is_torch_available, logging
 
 if is_torch_available():
...
@@ -33,7 +33,7 @@ for pkg in pkgs_to_check_at_runtime:
     if pkg in deps:
         if pkg == "tokenizers":
             # must be loaded here, or else tqdm check may fail
-            from .file_utils import is_tokenizers_available
+            from .utils import is_tokenizers_available
 
             if not is_tokenizers_available():
                 continue  # not required, check version only if installed
...
@@ -24,8 +24,14 @@ from typing import Dict, Optional, Union
 from huggingface_hub import HfFolder, model_info
 
-from .file_utils import HF_MODULES_CACHE, TRANSFORMERS_DYNAMIC_MODULE_NAME, cached_path, hf_bucket_url, is_offline_mode
-from .utils import logging
+from .utils import (
+    HF_MODULES_CACHE,
+    TRANSFORMERS_DYNAMIC_MODULE_NAME,
+    cached_path,
+    hf_bucket_url,
+    is_offline_mode,
+    logging,
+)
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
...
@@ -20,16 +20,8 @@ from typing import Dict, List, Optional, Union
 import numpy as np
 
 from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from .file_utils import (
-    PaddingStrategy,
-    TensorType,
-    _is_tensorflow,
-    _is_torch,
-    is_tf_available,
-    is_torch_available,
-    to_numpy,
-)
-from .utils import logging
+from .utils import PaddingStrategy, TensorType, is_tf_available, is_torch_available, logging, to_numpy
+from .utils.generic import _is_tensorflow, _is_torch
 
 logger = logging.get_logger(__name__)
...
@@ -27,16 +27,13 @@ import numpy as np
 from requests import HTTPError
 
 from .dynamic_module_utils import custom_object_save
-from .file_utils import (
+from .utils import (
     FEATURE_EXTRACTOR_NAME,
     EntryNotFoundError,
     PushToHubMixin,
     RepositoryNotFoundError,
     RevisionNotFoundError,
     TensorType,
-    _is_jax,
-    _is_numpy,
-    _is_torch_device,
     cached_path,
     copy_func,
     hf_bucket_url,
@@ -45,9 +42,10 @@ from .file_utils import (
     is_remote_url,
     is_tf_available,
     is_torch_available,
+    logging,
     torch_required,
 )
-from .utils import logging
+from .utils.generic import _is_jax, _is_numpy, _is_torch_device
 
 if TYPE_CHECKING:
...
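
In the two hunks above, public names (FEATURE_EXTRACTOR_NAME, TensorType, logging, ...) are imported from the .utils package itself, while private helpers (_is_jax, _is_numpy, _is_torch_device, and earlier _is_tensorflow, _is_torch) are taken from the .utils.generic submodule directly. That implies a re-export layer in utils/__init__.py; a hypothetical sketch of what it might look like, noting that only the generic submodule is visible in this diff and the other submodule names are assumptions:

    # utils/__init__.py -- hypothetical re-export layer.
    # `generic` is confirmed by the hunks above; `import_utils` and `hub`
    # are assumed submodule names, for illustration only.
    from . import logging
    from .generic import ModelOutput, PaddingStrategy, TensorType, to_numpy
    from .import_utils import is_tf_available, is_tokenizers_available, is_torch_available
    from .hub import cached_path, hf_bucket_url, is_offline_mode, is_remote_url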
(One file's diff is collapsed and not shown here.)
@@ -21,8 +21,8 @@ from typing import List, Optional, Tuple
 import numpy as np
 import torch
 
-from .file_utils import add_start_docstrings
 from .generation_beam_constraints import Constraint, ConstraintListState
+from .utils import add_start_docstrings
 
 PROCESS_INPUTS_DOCSTRING = r"""
...
@@ -19,7 +19,7 @@ import jax
 import jax.lax as lax
 import jax.numpy as jnp
 
-from .file_utils import add_start_docstrings
+from .utils import add_start_docstrings
 from .utils.logging import get_logger
...
@@ -25,7 +25,6 @@ import jax
 import jax.numpy as jnp
 from jax import lax
 
-from .file_utils import ModelOutput
 from .generation_flax_logits_process import (
     FlaxForcedBOSTokenLogitsProcessor,
     FlaxForcedEOSTokenLogitsProcessor,
@@ -35,7 +34,7 @@ from .generation_flax_logits_process import (
     FlaxTopKLogitsWarper,
     FlaxTopPLogitsWarper,
 )
-from .utils import logging
+from .utils import ModelOutput, logging
 
 logger = logging.get_logger(__name__)
...
@@ -20,7 +20,7 @@ from typing import Callable, Iterable, List, Optional, Tuple
 import numpy as np
 import torch
 
-from .file_utils import add_start_docstrings
+from .utils import add_start_docstrings
 from .utils.logging import get_logger
...
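
A reorganization like this is usually made non-breaking by keeping the old module as a thin deprecation shim that re-exports from the new location. This diff does not show how file_utils.py itself was handled, so the following is only a hypothetical sketch of such a shim:

    # file_utils.py -- hypothetical backward-compatibility shim,
    # for illustration only; not taken from this diff.
    import warnings

    from .utils import ModelOutput, is_tf_available, is_torch_available  # noqa: F401

    warnings.warn(
        "transformers.file_utils is deprecated; import from transformers.utils instead.",
        FutureWarning,
    )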