Unverified Commit 633e5e89 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[Refactor] Relative imports wherever we can (#21880)

* initial commit

* update

* second batch

* style

* fix imports

* fix relative import on pipeline
parent 43299c63
......@@ -24,9 +24,8 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.deepspeed import is_deepspeed_zero3_enabled
from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import torch_int_div
......
......@@ -25,9 +25,8 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm
from transformers.deepspeed import is_deepspeed_zero3_enabled
from ...activations import ACT2FN
from ...deepspeed import is_deepspeed_zero3_enabled
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import softmax_backward_data, torch_int_div
......
......@@ -18,12 +18,10 @@ from typing import Optional, Union
import numpy as np
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_transforms import get_image_size, pad, rescale, to_channel_dimension_format
from ...image_utils import ChannelDimension, ImageInput, make_list_of_images, to_numpy_array, valid_images
from ...utils import logging
from ...utils import TensorType, logging
logger = logging.get_logger(__name__)
......
......@@ -20,8 +20,8 @@ from typing import List, Optional, Union
import numpy as np
from numpy.fft import fft
from transformers.feature_extraction_sequence_utils import BatchFeature, SequenceFeatureExtractor
from transformers.utils import TensorType, logging
from ...feature_extraction_sequence_utils import BatchFeature, SequenceFeatureExtractor
from ...utils import TensorType, logging
logger = logging.get_logger(__name__)
......
......@@ -17,8 +17,6 @@ from typing import Dict, List, Optional, Union
import numpy as np
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
center_crop,
......@@ -38,7 +36,7 @@ from ...image_utils import (
to_numpy_array,
valid_images,
)
from ...utils import logging
from ...utils import TensorType, logging
logger = logging.get_logger(__name__)
......
......@@ -20,8 +20,7 @@ import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoBackbone
from ... import AutoBackbone
from ...modeling_outputs import SemanticSegmenterOutput
from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
......
......@@ -18,9 +18,6 @@ from typing import Dict, List, Optional, Union
import numpy as np
from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
center_crop,
......@@ -40,7 +37,7 @@ from ...image_utils import (
to_numpy_array,
valid_images,
)
from ...utils import logging
from ...utils import TensorType, is_vision_available, logging
if is_vision_available():
......
......@@ -19,9 +19,6 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import PaddingMode, normalize, pad, rescale, resize, to_channel_dimension_format
from ...image_utils import (
......@@ -36,7 +33,7 @@ from ...image_utils import (
to_numpy_array,
valid_images,
)
from ...utils import logging
from ...utils import TensorType, is_vision_available, logging
if is_vision_available():
......
......@@ -18,8 +18,6 @@ from typing import Dict, List, Optional, Union
import numpy as np
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import (
......@@ -32,7 +30,7 @@ from ...image_utils import (
to_numpy_array,
valid_images,
)
from ...utils import logging
from ...utils import TensorType, logging
logger = logging.get_logger(__name__)
......
......@@ -18,8 +18,6 @@ from typing import Dict, List, Optional, Union
import numpy as np
from transformers.utils.generic import TensorType
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
center_crop,
......@@ -40,8 +38,7 @@ from ...image_utils import (
to_numpy_array,
valid_images,
)
from ...utils import logging
from ...utils.import_utils import is_vision_available
from ...utils import TensorType, is_vision_available, logging
logger = logging.get_logger(__name__)
......
......@@ -20,9 +20,9 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Un
import numpy as np
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
from transformers.image_transforms import (
from ...feature_extraction_utils import BatchFeature
from ...image_processing_utils import BaseImageProcessor, get_size_dict
from ...image_transforms import (
PaddingMode,
center_to_corners_format,
corners_to_center_format,
......@@ -34,7 +34,7 @@ from transformers.image_transforms import (
rgb_to_id,
to_channel_dimension_format,
)
from transformers.image_utils import (
from ...image_utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
ChannelDimension,
......@@ -48,7 +48,9 @@ from transformers.image_utils import (
valid_coco_panoptic_annotations,
valid_images,
)
from transformers.utils import (
from ...utils import (
ExplicitEnum,
TensorType,
is_flax_available,
is_jax_tensor,
is_scipy_available,
......@@ -57,8 +59,8 @@ from transformers.utils import (
is_torch_available,
is_torch_tensor,
is_vision_available,
logging,
)
from transformers.utils.generic import ExplicitEnum, TensorType
if is_torch_available():
......@@ -74,6 +76,7 @@ if is_scipy_available():
import scipy.special
import scipy.stats
logger = logging.get_logger(__name__)
AnnotationType = Dict[str, Union[int, str, List[Dict]]]
......
......@@ -2,7 +2,7 @@ import numpy as np
import torch
from torch.utils.data import Dataset, IterableDataset
from transformers.utils.generic import ModelOutput
from ..utils.generic import ModelOutput
class PipelineDataset(Dataset):
......
import enum
import warnings
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING
from .. import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING
from ..utils import add_end_docstrings, is_tf_available
from .base import PIPELINE_INIT_ARGS, Pipeline
......
......@@ -20,8 +20,8 @@ from dataclasses import dataclass, field
import torch
from transformers.training_args import TrainingArguments
from transformers.utils import cached_property, is_sagemaker_dp_enabled, logging
from ..training_args import TrainingArguments
from ..utils import cached_property, is_sagemaker_dp_enabled, logging
logger = logging.get_logger(__name__)
......
from copy import deepcopy
from transformers.utils import is_accelerate_available, is_bitsandbytes_available
from .import_utils import is_accelerate_available, is_bitsandbytes_available
if is_bitsandbytes_available():
......
......@@ -49,8 +49,6 @@ from huggingface_hub.utils import (
)
from requests.exceptions import HTTPError
from transformers.utils.logging import tqdm
from . import __version__, logging
from .generic import working_or_temp_dir
from .import_utils import (
......@@ -61,6 +59,7 @@ from .import_utils import (
is_torch_available,
is_training_run_on_sagemaker,
)
from .logging import tqdm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
......@@ -29,9 +29,8 @@ from typing import Any
from packaging import version
from transformers.utils.versions import importlib_metadata
from . import logging
from .versions import importlib_metadata
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment