"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "01a56927f1603f1e89d1e5ada74d2aa75da2d46b"
Unverified commit aea748b3, authored by vfdev and committed by GitHub

[proto] Ported LinearTransformation (#6458)

* WIP

* Fixed dtype correction and tests

* Removed PIL Image support and output always Tensor
parent a409c3fe
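For context before the diff: LinearTransformation flattens each image to a vector of length D = C * H * W, subtracts a mean vector, multiplies by a square D x D transformation matrix, and reshapes the result back to the image shape. A minimal reference sketch of that computation (the helper name is hypothetical and not part of this PR):

```python
import torch

# Hypothetical reference helper mirroring what the ported transform computes
# for a C x H x W image: flatten, centre, project, restore the original shape.
def linear_transformation_reference(img: torch.Tensor, matrix: torch.Tensor, mean: torch.Tensor) -> torch.Tensor:
    d = matrix.shape[0]                          # D = C * H * W
    flat = img.reshape(-1, d) - mean             # flatten and subtract the mean vector
    return (flat @ matrix).reshape(img.shape)    # apply the D x D matrix, reshape back
```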
@@ -1487,3 +1487,36 @@ class TestFixedSizeCrop:
        transform(bounding_box)
        mock.assert_called_once()


class TestLinearTransformation:
    def test_assertions(self):
        with pytest.raises(ValueError, match="transformation_matrix should be square"):
            transforms.LinearTransformation(torch.rand(2, 3), torch.rand(5))

        with pytest.raises(ValueError, match="mean_vector should have the same length"):
            transforms.LinearTransformation(torch.rand(3, 3), torch.rand(5))

    @pytest.mark.parametrize(
        "inpt",
        [
            122 * torch.ones(1, 3, 8, 8),
            122.0 * torch.ones(1, 3, 8, 8),
            features.Image(122 * torch.ones(1, 3, 8, 8)),
            PIL.Image.new("RGB", (8, 8), (122, 122, 122)),
        ],
    )
    def test__transform(self, inpt):
        v = 121 * torch.ones(3 * 8 * 8)
        m = torch.ones(3 * 8 * 8, 3 * 8 * 8)
        transform = transforms.LinearTransformation(m, v)

        if isinstance(inpt, PIL.Image.Image):
            with pytest.raises(TypeError, match="Unsupported input type"):
                transform(inpt)
        else:
            output = transform(inpt)
            assert isinstance(output, torch.Tensor)
            assert output.unique() == 3 * 8 * 8
            assert output.dtype == inpt.dtype
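The constant in the second-to-last assertion follows from the chosen inputs: every flattened entry becomes 122 - 121 = 1, and a row of 192 ones multiplied by an all-ones 192 x 192 matrix yields 192 = 3 * 8 * 8 in every position, so the output has a single unique value. A standalone check of that arithmetic, outside the test suite:

```python
import torch

# Reproduce the expected constant from the test: (122 - 121) summed over
# 3 * 8 * 8 = 192 positions gives 192 in every output entry.
n = 3 * 8 * 8
flat = 122 * torch.ones(1, n) - 121 * torch.ones(n)   # all ones
out = flat @ torch.ones(n, n)                         # every entry == n
assert out.unique().item() == n == 192
```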
@@ -37,7 +37,7 @@ from ._geometry import (
    TenCrop,
)
from ._meta import ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype
-from ._misc import GaussianBlur, Identity, Lambda, Normalize, ToDtype
+from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, ToDtype
from ._type_conversion import DecodeImage, LabelToOneHot, ToImagePIL, ToImageTensor

from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor  # usort: skip
import functools
from typing import Any, Callable, Dict, List, Sequence, Type, Union

import PIL.Image
import torch

from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.transforms import _setup_size
@@ -32,6 +35,59 @@ class Lambda(Transform):
return ", ".join(extras) return ", ".join(extras)


class LinearTransformation(Transform):
    def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor):
        super().__init__()
        if transformation_matrix.size(0) != transformation_matrix.size(1):
            raise ValueError(
                "transformation_matrix should be square. Got "
                f"{tuple(transformation_matrix.size())} rectangular matrix."
            )

        if mean_vector.size(0) != transformation_matrix.size(0):
            raise ValueError(
                f"mean_vector should have the same length {mean_vector.size(0)}"
                f" as any one of the dimensions of the transformation_matrix [{tuple(transformation_matrix.size())}]"
            )

        if transformation_matrix.device != mean_vector.device:
            raise ValueError(
                f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}"
            )

        self.transformation_matrix = transformation_matrix
        self.mean_vector = mean_vector

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if isinstance(inpt, features._Feature) and not isinstance(inpt, features.Image):
            return inpt
        elif isinstance(inpt, PIL.Image.Image):
            raise TypeError("Unsupported input type")

        # An Image instance after a linear transformation is no longer an Image, since the data range is unknown.
        # Thus we return a plain Tensor for an Image input.

        shape = inpt.shape
        n = shape[-3] * shape[-2] * shape[-1]
        if n != self.transformation_matrix.shape[0]:
            raise ValueError(
                "Input tensor and transformation matrix have incompatible shape."
                + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != "
                + f"{self.transformation_matrix.shape[0]}"
            )

        if inpt.device.type != self.mean_vector.device.type:
            raise ValueError(
                "Input tensor should be on the same device as transformation matrix and mean vector. "
                f"Got {inpt.device} vs {self.mean_vector.device}"
            )

        flat_tensor = inpt.view(-1, n) - self.mean_vector
        transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix)
        return transformed_tensor.view(shape)


class Normalize(Transform):
    def __init__(self, mean: List[float], std: List[float]):
        super().__init__()
...
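As a usage sketch of the ported transform (assuming the prototype namespace touched by this PR; the ZCA-style whitening-matrix construction below is illustrative and not part of the change):

```python
import torch
from torchvision.prototype import features, transforms

# Illustrative only: build a ZCA-style whitening matrix from a stand-in batch
# and feed it to the ported transform. D = 3 * 8 * 8 flattened pixels.
data = torch.rand(64, 3, 8, 8)
flat = data.reshape(len(data), -1)
mean = flat.mean(dim=0)
cov = torch.cov((flat - mean).T)                       # (D, D) covariance
eigvals, eigvecs = torch.linalg.eigh(cov)
whitening = eigvecs @ torch.diag((eigvals + 1e-5).rsqrt()) @ eigvecs.T

transform = transforms.LinearTransformation(whitening, mean)
out = transform(features.Image(data[:1]))              # plain Tensor, per this PR
```

Per the PR description, the output is a plain Tensor even for a features.Image input, since the value range after an arbitrary linear map is unknown.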