Unverified Commit ef0f85cd authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[Vision] `.to` function for ImageProcessors (#20536)



* add v1 with tests

* add checker

* simplified version

* update docstring

* better version

* fix docstring + change order

* make style

* tests + change conditions

* final tests

* modify docstring

* Update src/transformers/feature_extraction_utils.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* replace by `ValueError`

* fix logic

* apply suggestions

* `dtype` is not needed

* adapt suggestions

* remove `_parse_args_to_device`
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 67d32f46
......@@ -40,6 +40,7 @@ from .utils import (
is_tf_available,
is_torch_available,
is_torch_device,
is_torch_dtype,
logging,
torch_required,
)
......@@ -47,7 +48,7 @@ from .utils import (
if TYPE_CHECKING:
if is_torch_available():
import torch
import torch # noqa
logger = logging.get_logger(__name__)
......@@ -138,7 +139,7 @@ class BatchFeature(UserDict):
elif tensor_type == TensorType.PYTORCH:
if not is_torch_available():
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
import torch
import torch # noqa
def as_tensor(value):
if isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], np.ndarray):
......@@ -175,25 +176,47 @@ class BatchFeature(UserDict):
return self
@torch_required
# Copied from transformers.tokenization_utils_base.BatchEncoding.to with BatchEncoding->BatchFeature
def to(self, device: Union[str, "torch.device"]) -> "BatchFeature":
def to(self, *args, **kwargs) -> "BatchFeature":
"""
Send all values to device by calling `v.to(device)` (PyTorch only).
Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
different `dtypes` and sending the `BatchFeature` to a different `device`.
Args:
device (`str` or `torch.device`): The device to put the tensors on.
args (`Tuple`):
Will be passed to the `to(...)` function of the tensors.
kwargs (`Dict`, *optional*):
Will be passed to the `to(...)` function of the tensors.
Returns:
[`BatchFeature`]: The same instance after modification.
"""
# This check catches things like APEX blindly calling "to" on all inputs to a module
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
# into a HalfTensor
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
self.data = {k: v.to(device=device) for k, v in self.data.items()}
import torch # noqa
new_data = {}
device = kwargs.get("device")
# Check if the args are a device or a dtype
if device is None and len(args) > 0:
# device should be always the first argument
arg = args[0]
if is_torch_dtype(arg):
# The first argument is a dtype
pass
elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
device = arg
else:
# it's something else
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
for k, v in self.items():
# check if v is a floating point
if torch.is_floating_point(v):
# cast and send to device
new_data[k] = v.to(*args, **kwargs)
elif device is not None:
new_data[k] = v.to(device=device)
else:
logger.warning(f"Attempting to cast a BatchFeature to type {str(device)}. This is not supported.")
new_data[k] = v
self.data = new_data
return self
......
......@@ -47,6 +47,7 @@ from .generic import (
is_tensor,
is_tf_tensor,
is_torch_device,
is_torch_dtype,
is_torch_tensor,
reshape,
squeeze,
......
......@@ -123,6 +123,24 @@ def is_torch_device(x):
return False if not is_torch_available() else _is_torch_device(x)
def _is_torch_dtype(x):
import torch
if isinstance(x, str):
if hasattr(torch, x):
x = getattr(torch, x)
else:
return False
return isinstance(x, torch.dtype)
def is_torch_dtype(x):
"""
Tests if `x` is a torch dtype or not. Safe to call even if torch is not installed.
"""
return False if not is_torch_available() else _is_torch_dtype(x)
def _is_tensorflow(x):
import tensorflow as tf
......
......@@ -84,6 +84,7 @@ class DeiTFeatureExtractionTester(unittest.TestCase):
class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
feature_extraction_class = DeiTFeatureExtractor if is_vision_available() else None
test_cast_dtype = True
def setUp(self):
self.feature_extract_tester = DeiTFeatureExtractionTester(self)
......
......@@ -25,7 +25,15 @@ from pathlib import Path
from huggingface_hub import HfFolder, delete_repo, set_access_token
from requests.exceptions import HTTPError
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
from transformers.testing_utils import TOKEN, USER, check_json_file_has_correct_format, get_tests_dir, is_staging_test
from transformers.testing_utils import (
TOKEN,
USER,
check_json_file_has_correct_format,
get_tests_dir,
is_staging_test,
require_torch,
require_vision,
)
from transformers.utils import is_torch_available, is_vision_available
......@@ -134,6 +142,8 @@ def prepare_video_inputs(feature_extract_tester, equal_resolution=False, numpify
class FeatureExtractionSavingTestMixin:
test_cast_dtype = None
def test_feat_extract_to_json_string(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
obj = json.loads(feat_extract.to_json_string())
......@@ -164,6 +174,41 @@ class FeatureExtractionSavingTestMixin:
feat_extract = self.feature_extraction_class()
self.assertIsNotNone(feat_extract)
@require_torch
@require_vision
def test_cast_dtype_device(self):
if self.test_cast_dtype is not None:
# Initialize feature_extractor
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
# create random PyTorch tensors
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
encoding = feature_extractor(image_inputs, return_tensors="pt")
# for layoutLM compatiblity
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
self.assertEqual(encoding.pixel_values.dtype, torch.float32)
encoding = feature_extractor(image_inputs, return_tensors="pt").to(torch.float16)
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
self.assertEqual(encoding.pixel_values.dtype, torch.float16)
encoding = feature_extractor(image_inputs, return_tensors="pt").to("cpu", torch.bfloat16)
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
self.assertEqual(encoding.pixel_values.dtype, torch.bfloat16)
with self.assertRaises(TypeError):
_ = feature_extractor(image_inputs, return_tensors="pt").to(torch.bfloat16, "cpu")
# Try with text + image feature
encoding = feature_extractor(image_inputs, return_tensors="pt")
encoding.update({"input_ids": torch.LongTensor([[1, 2, 3], [4, 5, 6]])})
encoding = encoding.to(torch.float16)
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
self.assertEqual(encoding.pixel_values.dtype, torch.float16)
self.assertEqual(encoding.input_ids.dtype, torch.long)
class FeatureExtractorUtilTester(unittest.TestCase):
def test_cached_files_are_used_when_internet_is_down(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment