Unverified Commit 4975002d authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Reorganize file utils (#16264)

* Split file_utils in several submodules

* Fixes

* Add back more objects

* More fixes

* Who exactly decided to import that from there?

* Second suggestion to code with code review

* Revert wront move

* Fix imports

* Adapt all imports

* Adapt all imports everywhere

* Revert this import, will fix in a separate commit
parent 71356034
...@@ -6,7 +6,7 @@ from typing import Optional ...@@ -6,7 +6,7 @@ from typing import Optional
import torch import torch
from .file_utils import add_start_docstrings from .utils import add_start_docstrings
STOPPING_CRITERIA_INPUTS_DOCSTRING = r""" STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
......
...@@ -19,8 +19,8 @@ from typing import List ...@@ -19,8 +19,8 @@ from typing import List
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from .file_utils import add_start_docstrings
from .tf_utils import set_tensor_by_indices_to_value from .tf_utils import set_tensor_by_indices_to_value
from .utils import add_start_docstrings
from .utils.logging import get_logger from .utils.logging import get_logger
......
...@@ -21,7 +21,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union ...@@ -21,7 +21,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from .file_utils import ModelOutput
from .generation_tf_logits_process import ( from .generation_tf_logits_process import (
TFLogitsProcessorList, TFLogitsProcessorList,
TFMinLengthLogitsProcessor, TFMinLengthLogitsProcessor,
...@@ -33,7 +32,7 @@ from .generation_tf_logits_process import ( ...@@ -33,7 +32,7 @@ from .generation_tf_logits_process import (
TFTopPLogitsWarper, TFTopPLogitsWarper,
) )
from .tf_utils import set_tensor_by_indices_to_value, shape_list from .tf_utils import set_tensor_by_indices_to_value, shape_list
from .utils import logging from .utils import ModelOutput, logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -23,7 +23,6 @@ import torch ...@@ -23,7 +23,6 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from torch import nn from torch import nn
from .file_utils import ModelOutput
from .generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint from .generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .generation_logits_process import ( from .generation_logits_process import (
...@@ -52,7 +51,7 @@ from .generation_stopping_criteria import ( ...@@ -52,7 +51,7 @@ from .generation_stopping_criteria import (
validate_stopping_criteria, validate_stopping_criteria,
) )
from .pytorch_utils import torch_int_div from .pytorch_utils import torch_int_div
from .utils import logging from .utils import ModelOutput, logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -22,7 +22,8 @@ import PIL.ImageOps ...@@ -22,7 +22,8 @@ import PIL.ImageOps
import requests import requests
from .file_utils import _is_torch, is_torch_available from .utils import is_torch_available
from .utils.generic import _is_torch
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406] IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
......
...@@ -22,8 +22,7 @@ import sys ...@@ -22,8 +22,7 @@ import sys
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from .file_utils import is_datasets_available from .utils import is_datasets_available, logging
from .utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -44,9 +43,9 @@ if _has_comet: ...@@ -44,9 +43,9 @@ if _has_comet:
except (ImportError, ValueError): except (ImportError, ValueError):
_has_comet = False _has_comet = False
from .file_utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available # noqa: E402
from .trainer_callback import ProgressCallback, TrainerCallback # noqa: E402 from .trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402 from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
from .utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available # noqa: E402
# Integration functions: # Integration functions:
......
...@@ -12,8 +12,8 @@ from tensorflow.keras.callbacks import Callback ...@@ -12,8 +12,8 @@ from tensorflow.keras.callbacks import Callback
from huggingface_hub import Repository from huggingface_hub import Repository
from . import IntervalStrategy, PreTrainedTokenizerBase from . import IntervalStrategy, PreTrainedTokenizerBase
from .file_utils import get_full_repo_name
from .modelcard import TrainingSummary from .modelcard import TrainingSummary
from .utils import get_full_repo_name
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -28,20 +28,6 @@ import yaml ...@@ -28,20 +28,6 @@ import yaml
from huggingface_hub import model_info from huggingface_hub import model_info
from . import __version__ from . import __version__
from .file_utils import (
CONFIG_NAME,
MODEL_CARD_NAME,
TF2_WEIGHTS_NAME,
WEIGHTS_NAME,
cached_path,
hf_bucket_url,
is_datasets_available,
is_offline_mode,
is_remote_url,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)
from .models.auto.modeling_auto import ( from .models.auto.modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
...@@ -56,7 +42,21 @@ from .models.auto.modeling_auto import ( ...@@ -56,7 +42,21 @@ from .models.auto.modeling_auto import (
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
) )
from .training_args import ParallelMode from .training_args import ParallelMode
from .utils import logging from .utils import (
CONFIG_NAME,
MODEL_CARD_NAME,
TF2_WEIGHTS_NAME,
WEIGHTS_NAME,
cached_path,
hf_bucket_url,
is_datasets_available,
is_offline_mode,
is_remote_url,
is_tf_available,
is_tokenizers_available,
is_torch_available,
logging,
)
TASK_MAPPING = { TASK_MAPPING = {
......
...@@ -16,7 +16,7 @@ from typing import Dict, Optional, Tuple ...@@ -16,7 +16,7 @@ from typing import Dict, Optional, Tuple
import flax import flax
import jax.numpy as jnp import jax.numpy as jnp
from .file_utils import ModelOutput from .utils import ModelOutput
@flax.struct.dataclass @flax.struct.dataclass
......
...@@ -30,7 +30,9 @@ from requests import HTTPError ...@@ -30,7 +30,9 @@ from requests import HTTPError
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save from .dynamic_module_utils import custom_object_save
from .file_utils import ( from .generation_flax_utils import FlaxGenerationMixin
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from .utils import (
FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
EntryNotFoundError, EntryNotFoundError,
...@@ -45,11 +47,9 @@ from .file_utils import ( ...@@ -45,11 +47,9 @@ from .file_utils import (
hf_bucket_url, hf_bucket_url,
is_offline_mode, is_offline_mode,
is_remote_url, is_remote_url,
logging,
replace_return_docstrings, replace_return_docstrings,
) )
from .generation_flax_utils import FlaxGenerationMixin
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from .utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -17,7 +17,7 @@ from typing import Optional, Tuple ...@@ -17,7 +17,7 @@ from typing import Optional, Tuple
import torch import torch
from .file_utils import ModelOutput from .utils import ModelOutput
@dataclass @dataclass
......
...@@ -17,7 +17,7 @@ from typing import List, Optional, Tuple ...@@ -17,7 +17,7 @@ from typing import List, Optional, Tuple
import tensorflow as tf import tensorflow as tf
from .file_utils import ModelOutput from .utils import ModelOutput
@dataclass @dataclass
......
...@@ -21,8 +21,7 @@ import re ...@@ -21,8 +21,7 @@ import re
import numpy import numpy
from .file_utils import ExplicitEnum from .utils import ExplicitEnum, logging
from .utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -37,7 +37,11 @@ from requests import HTTPError ...@@ -37,7 +37,11 @@ from requests import HTTPError
from .activations_tf import get_tf_activation from .activations_tf import get_tf_activation
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save from .dynamic_module_utils import custom_object_save
from .file_utils import ( from .generation_tf_utils import TFGenerationMixin
from .modeling_tf_outputs import TFSeq2SeqLMOutput
from .tf_utils import shape_list
from .tokenization_utils_base import BatchEncoding
from .utils import (
DUMMY_INPUTS, DUMMY_INPUTS,
TF2_WEIGHTS_NAME, TF2_WEIGHTS_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
...@@ -52,12 +56,8 @@ from .file_utils import ( ...@@ -52,12 +56,8 @@ from .file_utils import (
hf_bucket_url, hf_bucket_url,
is_offline_mode, is_offline_mode,
is_remote_url, is_remote_url,
logging,
) )
from .generation_tf_utils import TFGenerationMixin
from .modeling_tf_outputs import TFSeq2SeqLMOutput
from .tf_utils import shape_list
from .tokenization_utils_base import BatchEncoding
from .utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -32,7 +32,8 @@ from .activations import get_activation ...@@ -32,7 +32,8 @@ from .activations import get_activation
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from .dynamic_module_utils import custom_object_save from .dynamic_module_utils import custom_object_save
from .file_utils import ( from .generation_utils import GenerationMixin
from .utils import (
DUMMY_INPUTS, DUMMY_INPUTS,
FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_NAME,
TF2_WEIGHTS_NAME, TF2_WEIGHTS_NAME,
...@@ -49,10 +50,9 @@ from .file_utils import ( ...@@ -49,10 +50,9 @@ from .file_utils import (
hf_bucket_url, hf_bucket_url,
is_offline_mode, is_offline_mode,
is_remote_url, is_remote_url,
logging,
replace_return_docstrings, replace_return_docstrings,
) )
from .generation_utils import GenerationMixin
from .utils import logging
from .utils.versions import require_version_core from .utils.versions import require_version_core
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...file_utils import ( from ...utils import (
_LazyModule, _LazyModule,
is_flax_available, is_flax_available,
is_sentencepiece_available, is_sentencepiece_available,
......
...@@ -25,13 +25,6 @@ from torch import nn ...@@ -25,13 +25,6 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutput, BaseModelOutput,
BaseModelOutputWithPooling, BaseModelOutputWithPooling,
...@@ -47,7 +40,14 @@ from ...modeling_utils import ( ...@@ -47,7 +40,14 @@ from ...modeling_utils import (
find_pruneable_heads_and_indices, find_pruneable_heads_and_indices,
prune_linear_layer, prune_linear_layer,
) )
from ...utils import logging from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_albert import AlbertConfig from .configuration_albert import AlbertConfig
......
...@@ -25,7 +25,6 @@ from flax.core.frozen_dict import FrozenDict ...@@ -25,7 +25,6 @@ from flax.core.frozen_dict import FrozenDict
from flax.linen.attention import dot_product_attention_weights from flax.linen.attention import dot_product_attention_weights
from jax import lax from jax import lax
from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import ( from ...modeling_flax_outputs import (
FlaxBaseModelOutput, FlaxBaseModelOutput,
FlaxBaseModelOutputWithPooling, FlaxBaseModelOutputWithPooling,
...@@ -42,7 +41,7 @@ from ...modeling_flax_utils import ( ...@@ -42,7 +41,7 @@ from ...modeling_flax_utils import (
append_replace_return_docstrings, append_replace_return_docstrings,
overwrite_call_docstring, overwrite_call_docstring,
) )
from ...utils import logging from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_albert import AlbertConfig from .configuration_albert import AlbertConfig
......
...@@ -23,14 +23,6 @@ import numpy as np ...@@ -23,14 +23,6 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
from ...activations_tf import get_tf_activation from ...activations_tf import get_tf_activation
from ...file_utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_tf_outputs import ( from ...modeling_tf_outputs import (
TFBaseModelOutput, TFBaseModelOutput,
TFBaseModelOutputWithPooling, TFBaseModelOutputWithPooling,
...@@ -53,7 +45,15 @@ from ...modeling_tf_utils import ( ...@@ -53,7 +45,15 @@ from ...modeling_tf_utils import (
unpack_inputs, unpack_inputs,
) )
from ...tf_utils import shape_list from ...tf_utils import shape_list
from ...utils import logging from ...utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_albert import AlbertConfig from .configuration_albert import AlbertConfig
......
...@@ -19,10 +19,9 @@ import os ...@@ -19,10 +19,9 @@ import os
from shutil import copyfile from shutil import copyfile
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from ...file_utils import is_sentencepiece_available
from ...tokenization_utils import AddedToken from ...tokenization_utils import AddedToken
from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging from ...utils import is_sentencepiece_available, logging
if is_sentencepiece_available(): if is_sentencepiece_available():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment