Unverified commit a04ebc8b, authored by Yih-Dar, committed by GitHub
Browse files

`Pix2StructImageProcessor` requires `torch>=1.11.0` (#24270)



* fix

* fix

* fix

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 8978b696
...@@ -192,7 +192,7 @@ class ClapFeatureExtractor(SequenceFeatureExtractor): ...@@ -192,7 +192,7 @@ class ClapFeatureExtractor(SequenceFeatureExtractor):
mel = torch.tensor(mel[None, None, :]) mel = torch.tensor(mel[None, None, :])
mel_shrink = torch.nn.functional.interpolate( mel_shrink = torch.nn.functional.interpolate(
mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False, antialias=False mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False
) )
mel_shrink = mel_shrink[0][0].numpy() mel_shrink = mel_shrink[0][0].numpy()
mel_fusion = np.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], axis=0) mel_fusion = np.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], axis=0)
......
...@@ -43,11 +43,23 @@ if is_vision_available(): ...@@ -43,11 +43,23 @@ if is_vision_available():
if is_torch_available(): if is_torch_available():
import torch import torch
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11
else:
is_torch_greater_or_equal_than_1_11 = False
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
DEFAULT_FONT_PATH = "ybelkada/fonts" DEFAULT_FONT_PATH = "ybelkada/fonts"
def _check_torch_version():
    """Ensure the installed torch version supports `Pix2StructImageProcessor`.

    Raises:
        ImportError: if torch is installed but older than 1.11.0 (the
            processor relies on features unavailable before that release).

    No-op when torch is not installed at all; the `requires_backends`
    checks at the call sites handle that situation separately.
    """
    # Nothing to validate when torch is absent entirely.
    if not is_torch_available():
        return
    # Recent enough — accept silently.
    if is_torch_greater_or_equal_than_1_11:
        return
    raise ImportError(
        f"You are using torch=={torch.__version__}, but torch>=1.11.0 is required to use "
        "Pix2StructImageProcessor. Please upgrade torch."
    )
# adapted from: https://discuss.pytorch.org/t/tf-image-extract-patches-in-pytorch/171409/2 # adapted from: https://discuss.pytorch.org/t/tf-image-extract-patches-in-pytorch/171409/2
def torch_extract_patches(image_tensor, patch_height, patch_width): def torch_extract_patches(image_tensor, patch_height, patch_width):
""" """
...@@ -63,6 +75,7 @@ def torch_extract_patches(image_tensor, patch_height, patch_width): ...@@ -63,6 +75,7 @@ def torch_extract_patches(image_tensor, patch_height, patch_width):
The width of the patches to extract. The width of the patches to extract.
""" """
requires_backends(torch_extract_patches, ["torch"]) requires_backends(torch_extract_patches, ["torch"])
_check_torch_version()
image_tensor = image_tensor.unsqueeze(0) image_tensor = image_tensor.unsqueeze(0)
patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width)) patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width))
...@@ -240,6 +253,7 @@ class Pix2StructImageProcessor(BaseImageProcessor): ...@@ -240,6 +253,7 @@ class Pix2StructImageProcessor(BaseImageProcessor):
A sequence of `max_patches` flattened patches. A sequence of `max_patches` flattened patches.
""" """
requires_backends(self.extract_flattened_patches, "torch") requires_backends(self.extract_flattened_patches, "torch")
_check_torch_version()
# convert to torch # convert to torch
image = to_channel_dimension_format(image, ChannelDimension.FIRST) image = to_channel_dimension_format(image, ChannelDimension.FIRST)
......
...@@ -30,6 +30,7 @@ parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_ ...@@ -30,6 +30,7 @@ parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0") is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12") is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
is_torch_greater_or_equal_than_1_11 = parsed_torch_version_base >= version.parse("1.11")
is_torch_greater_or_equal_than_1_10 = parsed_torch_version_base >= version.parse("1.10") is_torch_greater_or_equal_than_1_10 = parsed_torch_version_base >= version.parse("1.10")
is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11") is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11")
......
...@@ -28,6 +28,10 @@ from ...test_image_processing_common import ImageProcessingSavingTestMixin, prep ...@@ -28,6 +28,10 @@ from ...test_image_processing_common import ImageProcessingSavingTestMixin, prep
if is_torch_available(): if is_torch_available():
import torch import torch
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11
else:
is_torch_greater_or_equal_than_1_11 = False
if is_vision_available(): if is_vision_available():
from PIL import Image from PIL import Image
...@@ -70,6 +74,10 @@ class Pix2StructImageProcessingTester(unittest.TestCase): ...@@ -70,6 +74,10 @@ class Pix2StructImageProcessingTester(unittest.TestCase):
return raw_image return raw_image
@unittest.skipIf(
not is_torch_greater_or_equal_than_1_11,
reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.",
)
@require_torch @require_torch
@require_vision @require_vision
class Pix2StructImageProcessingTest(ImageProcessingSavingTestMixin, unittest.TestCase): class Pix2StructImageProcessingTest(ImageProcessingSavingTestMixin, unittest.TestCase):
...@@ -237,6 +245,10 @@ class Pix2StructImageProcessingTest(ImageProcessingSavingTestMixin, unittest.Tes ...@@ -237,6 +245,10 @@ class Pix2StructImageProcessingTest(ImageProcessingSavingTestMixin, unittest.Tes
) )
@unittest.skipIf(
not is_torch_greater_or_equal_than_1_11,
reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.",
)
@require_torch @require_torch
@require_vision @require_vision
class Pix2StructImageProcessingTestFourChannels(ImageProcessingSavingTestMixin, unittest.TestCase): class Pix2StructImageProcessingTestFourChannels(ImageProcessingSavingTestMixin, unittest.TestCase):
......
...@@ -48,6 +48,9 @@ if is_torch_available(): ...@@ -48,6 +48,9 @@ if is_torch_available():
Pix2StructVisionModel, Pix2StructVisionModel,
) )
from transformers.models.pix2struct.modeling_pix2struct import PIX2STRUCT_PRETRAINED_MODEL_ARCHIVE_LIST from transformers.models.pix2struct.modeling_pix2struct import PIX2STRUCT_PRETRAINED_MODEL_ARCHIVE_LIST
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11
else:
is_torch_greater_or_equal_than_1_11 = False
if is_vision_available(): if is_vision_available():
...@@ -697,6 +700,10 @@ def prepare_img(): ...@@ -697,6 +700,10 @@ def prepare_img():
return im return im
@unittest.skipIf(
not is_torch_greater_or_equal_than_1_11,
reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.",
)
@require_vision @require_vision
@require_torch @require_torch
@slow @slow
......
...@@ -19,9 +19,14 @@ import numpy as np ...@@ -19,9 +19,14 @@ import numpy as np
import pytest import pytest
from transformers.testing_utils import require_torch, require_vision from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available from transformers.utils import is_torch_available, is_vision_available
if is_torch_available():
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11
else:
is_torch_greater_or_equal_than_1_11 = False
if is_vision_available(): if is_vision_available():
from PIL import Image from PIL import Image
...@@ -34,6 +39,10 @@ if is_vision_available(): ...@@ -34,6 +39,10 @@ if is_vision_available():
) )
@unittest.skipIf(
not is_torch_greater_or_equal_than_1_11,
reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.",
)
@require_vision @require_vision
@require_torch @require_torch
class Pix2StructProcessorTest(unittest.TestCase): class Pix2StructProcessorTest(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment