Unverified Commit d4575e5b authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Let LinearTransformation return datapoints instead of tensors (#7244)

parent 3a0e028f
......@@ -76,12 +76,7 @@ class LinearTransformation(Transform):
if has_any(sample, PIL.Image.Image):
raise TypeError("LinearTransformation does not work on PIL Images")
def _transform(
self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
) -> torch.Tensor:
# Image instance after linear transformation is not Image anymore due to unknown data range
# Thus we will return Tensor for input Image
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
shape = inpt.shape
n = shape[-3] * shape[-2] * shape[-1]
if n != self.transformation_matrix.shape[0]:
......@@ -97,11 +92,15 @@ class LinearTransformation(Transform):
f"Got {inpt.device} vs {self.mean_vector.device}"
)
flat_tensor = inpt.reshape(-1, n) - self.mean_vector
flat_inpt = inpt.reshape(-1, n) - self.mean_vector
transformation_matrix = self.transformation_matrix.to(flat_inpt.dtype)
output = torch.mm(flat_inpt, transformation_matrix)
output = output.reshape(shape)
transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype)
transformed_tensor = torch.mm(flat_tensor, transformation_matrix)
return transformed_tensor.reshape(shape)
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type]
return output
class Normalize(Transform):
......@@ -120,7 +119,7 @@ class Normalize(Transform):
def _transform(
self, inpt: Union[datapoints.TensorImageType, datapoints.TensorVideoType], params: Dict[str, Any]
) -> torch.Tensor:
) -> Any:
return F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment