"vscode:/vscode.git/clone" did not exist on "68b55885edb1a24f4f48d0d1d947048f74234e06"
Unverified Commit a04ebc8b authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

`Pix2StructImageProcessor` requires `torch>=1.11.0` (#24270)



* fix

* fix

* fix

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 8978b696
......@@ -192,7 +192,7 @@ class ClapFeatureExtractor(SequenceFeatureExtractor):
mel = torch.tensor(mel[None, None, :])
mel_shrink = torch.nn.functional.interpolate(
mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False, antialias=False
mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False
)
mel_shrink = mel_shrink[0][0].numpy()
mel_fusion = np.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], axis=0)
......
......@@ -43,11 +43,23 @@ if is_vision_available():
if is_torch_available():
import torch
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11
else:
is_torch_greater_or_equal_than_1_11 = False
logger = logging.get_logger(__name__)
DEFAULT_FONT_PATH = "ybelkada/fonts"
def _check_torch_version():
if is_torch_available() and not is_torch_greater_or_equal_than_1_11:
raise ImportError(
f"You are using torch=={torch.__version__}, but torch>=1.11.0 is required to use "
"Pix2StructImageProcessor. Please upgrade torch."
)
# adapted from: https://discuss.pytorch.org/t/tf-image-extract-patches-in-pytorch/171409/2
def torch_extract_patches(image_tensor, patch_height, patch_width):
"""
......@@ -63,6 +75,7 @@ def torch_extract_patches(image_tensor, patch_height, patch_width):
The width of the patches to extract.
"""
requires_backends(torch_extract_patches, ["torch"])
_check_torch_version()
image_tensor = image_tensor.unsqueeze(0)
patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width))
......@@ -240,6 +253,7 @@ class Pix2StructImageProcessor(BaseImageProcessor):
A sequence of `max_patches` flattened patches.
"""
requires_backends(self.extract_flattened_patches, "torch")
_check_torch_version()
# convert to torch
image = to_channel_dimension_format(image, ChannelDimension.FIRST)
......
......@@ -30,6 +30,7 @@ parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
is_torch_greater_or_equal_than_1_11 = parsed_torch_version_base >= version.parse("1.11")
is_torch_greater_or_equal_than_1_10 = parsed_torch_version_base >= version.parse("1.10")
is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11")
......
......@@ -28,6 +28,10 @@ from ...test_image_processing_common import ImageProcessingSavingTestMixin, prep
if is_torch_available():
import torch
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11
else:
is_torch_greater_or_equal_than_1_11 = False
if is_vision_available():
from PIL import Image
......@@ -70,6 +74,10 @@ class Pix2StructImageProcessingTester(unittest.TestCase):
return raw_image
@unittest.skipIf(
not is_torch_greater_or_equal_than_1_11,
reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.",
)
@require_torch
@require_vision
class Pix2StructImageProcessingTest(ImageProcessingSavingTestMixin, unittest.TestCase):
......@@ -237,6 +245,10 @@ class Pix2StructImageProcessingTest(ImageProcessingSavingTestMixin, unittest.Tes
)
@unittest.skipIf(
not is_torch_greater_or_equal_than_1_11,
reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.",
)
@require_torch
@require_vision
class Pix2StructImageProcessingTestFourChannels(ImageProcessingSavingTestMixin, unittest.TestCase):
......
......@@ -48,6 +48,9 @@ if is_torch_available():
Pix2StructVisionModel,
)
from transformers.models.pix2struct.modeling_pix2struct import PIX2STRUCT_PRETRAINED_MODEL_ARCHIVE_LIST
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11
else:
is_torch_greater_or_equal_than_1_11 = False
if is_vision_available():
......@@ -697,6 +700,10 @@ def prepare_img():
return im
@unittest.skipIf(
not is_torch_greater_or_equal_than_1_11,
reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.",
)
@require_vision
@require_torch
@slow
......
......@@ -19,9 +19,14 @@ import numpy as np
import pytest
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from transformers.utils import is_torch_available, is_vision_available
if is_torch_available():
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_11
else:
is_torch_greater_or_equal_than_1_11 = False
if is_vision_available():
from PIL import Image
......@@ -34,6 +39,10 @@ if is_vision_available():
)
@unittest.skipIf(
not is_torch_greater_or_equal_than_1_11,
reason="`Pix2StructImageProcessor` requires `torch>=1.11.0`.",
)
@require_vision
@require_torch
class Pix2StructProcessorTest(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment