Unverified Commit 4975002d authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Reorganize file utils (#16264)

* Split file_utils in several submodules

* Fixes

* Add back more objects

* More fixes

* Who exactly decided to import that from there?

* Apply second suggestion from code review

* Revert wrong move

* Fix imports

* Adapt all imports

* Adapt all imports everywhere

* Revert this import, will fix in a separate commit
parent 71356034
......@@ -6,7 +6,7 @@ from typing import Optional
import torch
from .file_utils import add_start_docstrings
from .utils import add_start_docstrings
STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
......
......@@ -19,8 +19,8 @@ from typing import List
import numpy as np
import tensorflow as tf
from .file_utils import add_start_docstrings
from .tf_utils import set_tensor_by_indices_to_value
from .utils import add_start_docstrings
from .utils.logging import get_logger
......
......@@ -21,7 +21,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from .file_utils import ModelOutput
from .generation_tf_logits_process import (
TFLogitsProcessorList,
TFMinLengthLogitsProcessor,
......@@ -33,7 +32,7 @@ from .generation_tf_logits_process import (
TFTopPLogitsWarper,
)
from .tf_utils import set_tensor_by_indices_to_value, shape_list
from .utils import logging
from .utils import ModelOutput, logging
logger = logging.get_logger(__name__)
......
......@@ -23,7 +23,6 @@ import torch
import torch.distributed as dist
from torch import nn
from .file_utils import ModelOutput
from .generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .generation_logits_process import (
......@@ -52,7 +51,7 @@ from .generation_stopping_criteria import (
validate_stopping_criteria,
)
from .pytorch_utils import torch_int_div
from .utils import logging
from .utils import ModelOutput, logging
logger = logging.get_logger(__name__)
......
......@@ -22,7 +22,8 @@ import PIL.ImageOps
import requests
from .file_utils import _is_torch, is_torch_available
from .utils import is_torch_available
from .utils.generic import _is_torch
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
......
......@@ -22,8 +22,7 @@ import sys
import tempfile
from pathlib import Path
from .file_utils import is_datasets_available
from .utils import logging
from .utils import is_datasets_available, logging
logger = logging.get_logger(__name__)
......@@ -44,9 +43,9 @@ if _has_comet:
except (ImportError, ValueError):
_has_comet = False
from .file_utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available # noqa: E402
from .trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
from .utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available # noqa: E402
# Integration functions:
......
......@@ -12,8 +12,8 @@ from tensorflow.keras.callbacks import Callback
from huggingface_hub import Repository
from . import IntervalStrategy, PreTrainedTokenizerBase
from .file_utils import get_full_repo_name
from .modelcard import TrainingSummary
from .utils import get_full_repo_name
logger = logging.getLogger(__name__)
......
......@@ -28,20 +28,6 @@ import yaml
from huggingface_hub import model_info
from . import __version__
from .file_utils import (
CONFIG_NAME,
MODEL_CARD_NAME,
TF2_WEIGHTS_NAME,
WEIGHTS_NAME,
cached_path,
hf_bucket_url,
is_datasets_available,
is_offline_mode,
is_remote_url,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)
from .models.auto.modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
......@@ -56,7 +42,21 @@ from .models.auto.modeling_auto import (
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
)
from .training_args import ParallelMode
from .utils import logging
from .utils import (
CONFIG_NAME,
MODEL_CARD_NAME,
TF2_WEIGHTS_NAME,
WEIGHTS_NAME,
cached_path,
hf_bucket_url,
is_datasets_available,
is_offline_mode,
is_remote_url,
is_tf_available,
is_tokenizers_available,
is_torch_available,
logging,
)
TASK_MAPPING = {
......
......@@ -16,7 +16,7 @@ from typing import Dict, Optional, Tuple
import flax
import jax.numpy as jnp
from .file_utils import ModelOutput
from .utils import ModelOutput
@flax.struct.dataclass
......
......@@ -30,7 +30,9 @@ from requests import HTTPError
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .file_utils import (
from .generation_flax_utils import FlaxGenerationMixin
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from .utils import (
FLAX_WEIGHTS_NAME,
WEIGHTS_NAME,
EntryNotFoundError,
......@@ -45,11 +47,9 @@ from .file_utils import (
hf_bucket_url,
is_offline_mode,
is_remote_url,
logging,
replace_return_docstrings,
)
from .generation_flax_utils import FlaxGenerationMixin
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from .utils import logging
logger = logging.get_logger(__name__)
......
......@@ -17,7 +17,7 @@ from typing import Optional, Tuple
import torch
from .file_utils import ModelOutput
from .utils import ModelOutput
@dataclass
......
......@@ -17,7 +17,7 @@ from typing import List, Optional, Tuple
import tensorflow as tf
from .file_utils import ModelOutput
from .utils import ModelOutput
@dataclass
......
......@@ -21,8 +21,7 @@ import re
import numpy
from .file_utils import ExplicitEnum
from .utils import logging
from .utils import ExplicitEnum, logging
logger = logging.get_logger(__name__)
......
......@@ -37,7 +37,11 @@ from requests import HTTPError
from .activations_tf import get_tf_activation
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .file_utils import (
from .generation_tf_utils import TFGenerationMixin
from .modeling_tf_outputs import TFSeq2SeqLMOutput
from .tf_utils import shape_list
from .tokenization_utils_base import BatchEncoding
from .utils import (
DUMMY_INPUTS,
TF2_WEIGHTS_NAME,
WEIGHTS_NAME,
......@@ -52,12 +56,8 @@ from .file_utils import (
hf_bucket_url,
is_offline_mode,
is_remote_url,
logging,
)
from .generation_tf_utils import TFGenerationMixin
from .modeling_tf_outputs import TFSeq2SeqLMOutput
from .tf_utils import shape_list
from .tokenization_utils_base import BatchEncoding
from .utils import logging
logger = logging.get_logger(__name__)
......
......@@ -32,7 +32,8 @@ from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from .dynamic_module_utils import custom_object_save
from .file_utils import (
from .generation_utils import GenerationMixin
from .utils import (
DUMMY_INPUTS,
FLAX_WEIGHTS_NAME,
TF2_WEIGHTS_NAME,
......@@ -49,10 +50,9 @@ from .file_utils import (
hf_bucket_url,
is_offline_mode,
is_remote_url,
logging,
replace_return_docstrings,
)
from .generation_utils import GenerationMixin
from .utils import logging
from .utils.versions import require_version_core
......
......@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import (
from ...utils import (
_LazyModule,
is_flax_available,
is_sentencepiece_available,
......
......@@ -25,13 +25,6 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
......@@ -47,7 +40,14 @@ from ...modeling_utils import (
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...utils import logging
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_albert import AlbertConfig
......
......@@ -25,7 +25,6 @@ from flax.core.frozen_dict import FrozenDict
from flax.linen.attention import dot_product_attention_weights
from jax import lax
from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPooling,
......@@ -42,7 +41,7 @@ from ...modeling_flax_utils import (
append_replace_return_docstrings,
overwrite_call_docstring,
)
from ...utils import logging
from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_albert import AlbertConfig
......
......@@ -23,14 +23,6 @@ import numpy as np
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...file_utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPooling,
......@@ -53,7 +45,15 @@ from ...modeling_tf_utils import (
unpack_inputs,
)
from ...tf_utils import shape_list
from ...utils import logging
from ...utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_albert import AlbertConfig
......
......@@ -19,10 +19,9 @@ import os
from shutil import copyfile
from typing import List, Optional, Tuple
from ...file_utils import is_sentencepiece_available
from ...tokenization_utils import AddedToken
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import logging
from ...utils import is_sentencepiece_available, logging
if is_sentencepiece_available():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment