Unverified commit fdca3073, authored by vfdev and committed by GitHub
Browse files

Added CPU/CUDA and batch input for dtype conversion op (#2755)



* make convert_image_dtype scriptable

* move convert dtype to functional_tensor since only works on tensors

* retain availability of convert_image_dtype in functional.py

* Update code and tests

* Replaced int by torch.dtype

* int -> torch.dtype and use F instead of F_t

* Update functional_tensor.py

* Added CPU/CUDA+batch tests

* Fixed tests according to review
Co-authored-by: Brian <nairbv@yahoo.com>
parent 3d0c7794
......@@ -369,3 +369,16 @@ class TransformsTester(unittest.TestCase):
err < tol,
msg="{}: err={}, tol={}: \n{}\nvs\n{}".format(msg, err, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])
)
def cycle_over(objs):
    """Yield each element of ``objs`` paired with a list of all the others.

    Args:
        objs: any iterable (materialized internally, so generators work too).

    Yields:
        Tuples ``(obj, rest)`` where ``rest`` is a list of every other element,
        in original order.
    """
    # Materialize first: slicing below requires a sequence, so plain
    # iterators/generators would otherwise raise TypeError.
    objs = list(objs)
    for idx, obj in enumerate(objs):
        yield obj, objs[:idx] + objs[idx + 1:]
def int_dtypes():
    """Return the integer torch dtypes exercised by the transform tests.

    Returns:
        tuple[torch.dtype, ...]: the standard integral dtypes. Spelled out
        explicitly instead of calling ``torch.testing.integral_types()``,
        which is a private helper that has been deprecated/removed in
        recent torch releases.
    """
    return (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
def float_dtypes():
    """Return the floating-point torch dtypes exercised by the transform tests.

    Returns:
        tuple[torch.dtype, ...]: the standard float dtypes. Spelled out
        explicitly instead of calling ``torch.testing.floating_types()``,
        which is a private helper that has been deprecated/removed in
        recent torch releases.
    """
    return (torch.float32, torch.float64)
......@@ -20,24 +20,11 @@ try:
except ImportError:
stats = None
GRACE_HOPPER = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
def cycle_over(objs):
    """Yield each element together with a list of all the remaining elements."""
    pool = list(objs)
    for position, item in enumerate(pool):
        remainder = pool[:position] + pool[position + 1:]
        yield item, remainder
from common_utils import cycle_over, int_dtypes, float_dtypes
def int_dtypes():
    """Yield the integer torch dtypes used in the tests (alias names included)."""
    dtypes = (
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.short,
        torch.int32,
        torch.int,
        torch.int64,
        torch.long,
    )
    for dtype in dtypes:
        yield dtype
def float_dtypes():
    """Yield the floating-point torch dtypes used in the tests (alias names included)."""
    for dtype in (torch.float32, torch.float, torch.float64, torch.double):
        yield dtype
GRACE_HOPPER = get_file_path_2(
os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
class Tester(unittest.TestCase):
......
......@@ -9,7 +9,7 @@ import numpy as np
import unittest
from common_utils import TransformsTester, get_tmp_dir
from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes
class Tester(TransformsTester):
......@@ -27,14 +27,14 @@ class Tester(TransformsTester):
transformed_pil_img = f(pil_img, **fn_kwargs)
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
def _test_transform_vs_scripted(self, transform, s_transform, tensor, msg=None):
    """Assert that an eager transform and its scripted counterpart agree.

    Both calls are run under the same manual seed so random transforms
    produce the same draw, then the outputs are compared element-wise.

    Args:
        transform: eager transform callable.
        s_transform: ``torch.jit.script``-ed version of the same transform.
        tensor: input image tensor.
        msg: optional message forwarded to the assertion on failure.
    """
    # NOTE(review): the flattened diff duplicated the old/new variants of the
    # `def` line and the assertion; collapsed to the final (msg-aware) version.
    torch.manual_seed(12)
    out1 = transform(tensor)
    torch.manual_seed(12)
    out2 = s_transform(tensor)
    self.assertTrue(out1.equal(out2), msg=msg)
def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_tensors, msg=None):
    """Assert batch/per-image consistency and eager/scripted agreement.

    Checks two properties, each under a fixed manual seed:
    1. Applying the transform to the whole batch equals applying it to each
       image of the batch individually.
    2. The eager and scripted transforms agree on the full batch.

    Args:
        transform: eager transform callable.
        s_transform: ``torch.jit.script``-ed version of the same transform.
        batch_tensors: batched image tensor; iterated over its leading dim.
        msg: optional message forwarded to the assertions on failure.
    """
    torch.manual_seed(12)
    transformed_batch = transform(batch_tensors)

    # NOTE(review): the loop header was hidden by a diff hunk boundary in the
    # source; iterating the leading (batch) dimension matches the per-image
    # indexing `batch_tensors[i, ...]` used below — confirm against upstream.
    for i in range(batch_tensors.shape[0]):
        img_tensor = batch_tensors[i, ...]
        torch.manual_seed(12)
        transformed_img = transform(img_tensor)
        self.assertTrue(transformed_img.equal(transformed_batch[i, ...]), msg=msg)

    torch.manual_seed(12)
    s_transformed_batch = s_transform(batch_tensors)
    self.assertTrue(transformed_batch.equal(s_transformed_batch), msg=msg)
def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs):
if meth_kwargs is None:
......@@ -492,6 +492,32 @@ class Tester(TransformsTester):
self._test_transform_vs_scripted(fn, scripted_fn, tensor)
self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)
def test_convert_image_dtype(self):
    # Check that ConvertImageDtype (eager and scripted) agree on single
    # images and on batches, for every (input dtype, output dtype) pair of
    # integer and float dtypes, on this tester's device (CPU here; the
    # CUDA subclass re-runs the same test on GPU).
    tensor, _ = self._create_data(26, 34, device=self.device)
    batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)

    for in_dtype in int_dtypes() + float_dtypes():
        in_tensor = tensor.to(in_dtype)
        in_batch_tensors = batch_tensors.to(in_dtype)
        for out_dtype in int_dtypes() + float_dtypes():
            fn = T.ConvertImageDtype(dtype=out_dtype)
            scripted_fn = torch.jit.script(fn)

            # These float->int pairs are refused by the op — presumably
            # because the target integer range cannot hold the float values
            # safely (TODO confirm against convert_image_dtype). Both the
            # eager and scripted paths must raise the same error.
            if (in_dtype == torch.float32 and out_dtype in (torch.int32, torch.int64)) or \
                    (in_dtype == torch.float64 and out_dtype == torch.int64):
                with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"):
                    self._test_transform_vs_scripted(fn, scripted_fn, in_tensor)
                with self.assertRaisesRegex(RuntimeError, r"cannot be performed safely"):
                    self._test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)
                continue

            self._test_transform_vs_scripted(fn, scripted_fn, in_tensor)
            self._test_transform_vs_scripted_on_batch(fn, scripted_fn, in_batch_tensors)

    # The scripted module must also round-trip through serialization.
    with get_tmp_dir() as tmp_dir:
        scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt"))
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment