Unverified Commit 69220e0c authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add ToPureTensor transform (#7823)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 3554d80e
......@@ -237,6 +237,7 @@ Conversion
v2.ConvertImageDtype
v2.ToDtype
v2.ConvertBoundingBoxFormat
v2.ToPureTensor
Auto-Augmentation
-----------------
......
......@@ -68,6 +68,9 @@ class ClassificationPresetTrain:
if random_erase_prob > 0:
transforms.append(T.RandomErasing(p=random_erase_prob))
if use_v2:
transforms.append(T.ToPureTensor())
self.transforms = T.Compose(transforms)
def __call__(self, img):
......@@ -107,6 +110,9 @@ class ClassificationPresetEval:
T.Normalize(mean=mean, std=std),
]
if use_v2:
transforms.append(T.ToPureTensor())
self.transforms = T.Compose(transforms)
def __call__(self, img):
......
......@@ -79,6 +79,7 @@ class DetectionPresetTrain:
transforms += [
T.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.XYXY),
T.SanitizeBoundingBoxes(),
T.ToPureTensor(),
]
self.transforms = T.Compose(transforms)
......@@ -103,6 +104,10 @@ class DetectionPresetEval:
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
transforms += [T.ConvertImageDtype(torch.float)]
if use_v2:
transforms += [T.ToPureTensor()]
self.transforms = T.Compose(transforms)
def __call__(self, img, target):
......
......@@ -63,6 +63,8 @@ class SegmentationPresetTrain:
transforms += [T.ConvertImageDtype(torch.float)]
transforms += [T.Normalize(mean=mean, std=std)]
if use_v2:
transforms += [T.ToPureTensor()]
self.transforms = T.Compose(transforms)
......@@ -98,6 +100,9 @@ class SegmentationPresetEval:
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]
if use_v2:
transforms += [T.ToPureTensor()]
self.transforms = T.Compose(transforms)
def __call__(self, img, target):
......
......@@ -2353,3 +2353,24 @@ class TestElastic:
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform(self, make_input, size, device):
check_transform(transforms.ElasticTransform, make_input(size, device=device))
class TestToPureTensor:
def test_correctness(self):
input = {
"img": make_image(),
"img_tensor": make_image_tensor(),
"img_pil": make_image_pil(),
"mask": make_detection_mask(),
"video": make_video(),
"bbox": make_bounding_box(),
"str": "str",
}
out = transforms.ToPureTensor()(input)
for input_value, out_value in zip(input.values(), out.values()):
if isinstance(input_value, datapoints.Datapoint):
assert isinstance(out_value, torch.Tensor) and not isinstance(out_value, datapoints.Datapoint)
else:
assert isinstance(out_value, type(input_value))
......@@ -52,7 +52,7 @@ from ._misc import (
ToDtype,
)
from ._temporal import UniformTemporalSubsample
from ._type_conversion import PILToTensor, ToImage, ToPILImage
from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor
from ._deprecated import ToTensor # usort: skip
......
......@@ -75,3 +75,17 @@ class ToPILImage(Transform):
self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any]
) -> PIL.Image.Image:
return F.to_pil_image(inpt, mode=self.mode)
class ToPureTensor(Transform):
"""[BETA] Convert all datapoints to pure tensors, removing associated metadata (if any).
.. v2betastatus:: ToPureTensor transform
This doesn't scale or change the values, only the type.
"""
_transformed_types = (datapoints.Datapoint,)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor:
return inpt.as_subclass(torch.Tensor)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment