Unverified Commit 633e5e89 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[Refactor] Relative imports wherever we can (#21880)

* initial commit

* update

* second batch

* style

* fix imports

* fix relative import on pipeline
parent 43299c63
...@@ -24,9 +24,8 @@ import torch.utils.checkpoint ...@@ -24,9 +24,8 @@ import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from transformers.deepspeed import is_deepspeed_zero3_enabled
from ...activations import ACT2FN from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div from ...pytorch_utils import torch_int_div
......
...@@ -25,9 +25,8 @@ import torch.utils.checkpoint ...@@ -25,9 +25,8 @@ import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm from torch.nn import CrossEntropyLoss, LayerNorm
from transformers.deepspeed import is_deepspeed_zero3_enabled
from ...activations import ACT2FN from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import softmax_backward_data, torch_int_div from ...pytorch_utils import softmax_backward_data, torch_int_div
......
...@@ -18,12 +18,10 @@ from typing import Optional, Union ...@@ -18,12 +18,10 @@ from typing import Optional, Union
import numpy as np import numpy as np
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_transforms import get_image_size, pad, rescale, to_channel_dimension_format from ...image_transforms import get_image_size, pad, rescale, to_channel_dimension_format
from ...image_utils import ChannelDimension, ImageInput, make_list_of_images, to_numpy_array, valid_images from ...image_utils import ChannelDimension, ImageInput, make_list_of_images, to_numpy_array, valid_images
from ...utils import logging from ...utils import TensorType, logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -20,8 +20,8 @@ from typing import List, Optional, Union ...@@ -20,8 +20,8 @@ from typing import List, Optional, Union
import numpy as np import numpy as np
from numpy.fft import fft from numpy.fft import fft
from transformers.feature_extraction_sequence_utils import BatchFeature, SequenceFeatureExtractor from ...feature_extraction_sequence_utils import BatchFeature, SequenceFeatureExtractor
from transformers.utils import TensorType, logging from ...utils import TensorType, logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -17,8 +17,6 @@ from typing import Dict, List, Optional, Union ...@@ -17,8 +17,6 @@ from typing import Dict, List, Optional, Union
import numpy as np import numpy as np
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import ( from ...image_transforms import (
center_crop, center_crop,
...@@ -38,7 +36,7 @@ from ...image_utils import ( ...@@ -38,7 +36,7 @@ from ...image_utils import (
to_numpy_array, to_numpy_array,
valid_images, valid_images,
) )
from ...utils import logging from ...utils import TensorType, logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -20,8 +20,7 @@ import torch ...@@ -20,8 +20,7 @@ import torch
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from transformers import AutoBackbone from ... import AutoBackbone
from ...modeling_outputs import SemanticSegmenterOutput from ...modeling_outputs import SemanticSegmenterOutput
from ...modeling_utils import BackboneMixin, PreTrainedModel from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
......
...@@ -18,9 +18,6 @@ from typing import Dict, List, Optional, Union ...@@ -18,9 +18,6 @@ from typing import Dict, List, Optional, Union
import numpy as np import numpy as np
from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import ( from ...image_transforms import (
center_crop, center_crop,
...@@ -40,7 +37,7 @@ from ...image_utils import ( ...@@ -40,7 +37,7 @@ from ...image_utils import (
to_numpy_array, to_numpy_array,
valid_images, valid_images,
) )
from ...utils import logging from ...utils import TensorType, is_vision_available, logging
if is_vision_available(): if is_vision_available():
......
...@@ -19,9 +19,6 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union ...@@ -19,9 +19,6 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import PaddingMode, normalize, pad, rescale, resize, to_channel_dimension_format from ...image_transforms import PaddingMode, normalize, pad, rescale, resize, to_channel_dimension_format
from ...image_utils import ( from ...image_utils import (
...@@ -36,7 +33,7 @@ from ...image_utils import ( ...@@ -36,7 +33,7 @@ from ...image_utils import (
to_numpy_array, to_numpy_array,
valid_images, valid_images,
) )
from ...utils import logging from ...utils import TensorType, is_vision_available, logging
if is_vision_available(): if is_vision_available():
......
...@@ -18,8 +18,6 @@ from typing import Dict, List, Optional, Union ...@@ -18,8 +18,6 @@ from typing import Dict, List, Optional, Union
import numpy as np import numpy as np
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import normalize, rescale, resize, to_channel_dimension_format from ...image_transforms import normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import ( from ...image_utils import (
...@@ -32,7 +30,7 @@ from ...image_utils import ( ...@@ -32,7 +30,7 @@ from ...image_utils import (
to_numpy_array, to_numpy_array,
valid_images, valid_images,
) )
from ...utils import logging from ...utils import TensorType, logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -18,8 +18,6 @@ from typing import Dict, List, Optional, Union ...@@ -18,8 +18,6 @@ from typing import Dict, List, Optional, Union
import numpy as np import numpy as np
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import ( from ...image_transforms import (
center_crop, center_crop,
...@@ -40,8 +38,7 @@ from ...image_utils import ( ...@@ -40,8 +38,7 @@ from ...image_utils import (
to_numpy_array, to_numpy_array,
valid_images, valid_images,
) )
from ...utils import logging from ...utils import TensorType, is_vision_available, logging
from ...utils.import_utils import is_vision_available
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -20,9 +20,9 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Un ...@@ -20,9 +20,9 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Un
import numpy as np import numpy as np
from transformers.feature_extraction_utils import BatchFeature from ...feature_extraction_utils import BatchFeature
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict from ...image_processing_utils import BaseImageProcessor, get_size_dict
from transformers.image_transforms import ( from ...image_transforms import (
PaddingMode, PaddingMode,
center_to_corners_format, center_to_corners_format,
corners_to_center_format, corners_to_center_format,
...@@ -34,7 +34,7 @@ from transformers.image_transforms import ( ...@@ -34,7 +34,7 @@ from transformers.image_transforms import (
rgb_to_id, rgb_to_id,
to_channel_dimension_format, to_channel_dimension_format,
) )
from transformers.image_utils import ( from ...image_utils import (
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_STD,
ChannelDimension, ChannelDimension,
...@@ -48,7 +48,9 @@ from transformers.image_utils import ( ...@@ -48,7 +48,9 @@ from transformers.image_utils import (
valid_coco_panoptic_annotations, valid_coco_panoptic_annotations,
valid_images, valid_images,
) )
from transformers.utils import ( from ...utils import (
ExplicitEnum,
TensorType,
is_flax_available, is_flax_available,
is_jax_tensor, is_jax_tensor,
is_scipy_available, is_scipy_available,
...@@ -57,8 +59,8 @@ from transformers.utils import ( ...@@ -57,8 +59,8 @@ from transformers.utils import (
is_torch_available, is_torch_available,
is_torch_tensor, is_torch_tensor,
is_vision_available, is_vision_available,
logging,
) )
from transformers.utils.generic import ExplicitEnum, TensorType
if is_torch_available(): if is_torch_available():
...@@ -74,6 +76,7 @@ if is_scipy_available(): ...@@ -74,6 +76,7 @@ if is_scipy_available():
import scipy.special import scipy.special
import scipy.stats import scipy.stats
logger = logging.get_logger(__name__)
AnnotationType = Dict[str, Union[int, str, List[Dict]]] AnnotationType = Dict[str, Union[int, str, List[Dict]]]
......
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
import torch import torch
from torch.utils.data import Dataset, IterableDataset from torch.utils.data import Dataset, IterableDataset
from transformers.utils.generic import ModelOutput from ..utils.generic import ModelOutput
class PipelineDataset(Dataset): class PipelineDataset(Dataset):
......
import enum import enum
import warnings import warnings
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING from .. import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING
from ..utils import add_end_docstrings, is_tf_available from ..utils import add_end_docstrings, is_tf_available
from .base import PIPELINE_INIT_ARGS, Pipeline from .base import PIPELINE_INIT_ARGS, Pipeline
......
...@@ -20,8 +20,8 @@ from dataclasses import dataclass, field ...@@ -20,8 +20,8 @@ from dataclasses import dataclass, field
import torch import torch
from transformers.training_args import TrainingArguments from ..training_args import TrainingArguments
from transformers.utils import cached_property, is_sagemaker_dp_enabled, logging from ..utils import cached_property, is_sagemaker_dp_enabled, logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
from copy import deepcopy from copy import deepcopy
from transformers.utils import is_accelerate_available, is_bitsandbytes_available from .import_utils import is_accelerate_available, is_bitsandbytes_available
if is_bitsandbytes_available(): if is_bitsandbytes_available():
......
...@@ -49,8 +49,6 @@ from huggingface_hub.utils import ( ...@@ -49,8 +49,6 @@ from huggingface_hub.utils import (
) )
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers.utils.logging import tqdm
from . import __version__, logging from . import __version__, logging
from .generic import working_or_temp_dir from .generic import working_or_temp_dir
from .import_utils import ( from .import_utils import (
...@@ -61,6 +59,7 @@ from .import_utils import ( ...@@ -61,6 +59,7 @@ from .import_utils import (
is_torch_available, is_torch_available,
is_training_run_on_sagemaker, is_training_run_on_sagemaker,
) )
from .logging import tqdm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
...@@ -29,9 +29,8 @@ from typing import Any ...@@ -29,9 +29,8 @@ from typing import Any
from packaging import version from packaging import version
from transformers.utils.versions import importlib_metadata
from . import logging from . import logging
from .versions import importlib_metadata
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment