"vscode:/vscode.git/clone" did not exist on "b2ce5906f234124cdd0a09a25fc713e1dc955a35"
Unverified Commit efb07368 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add tensor kernels for normalize and erase (#5462)

* add tensor kernels for normalize and erase

* add image tensor assertion
parent c6b447b7
......@@ -338,30 +338,9 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(normalize)
if not isinstance(tensor, torch.Tensor):
raise TypeError(f"Input tensor should be a torch tensor. Got {type(tensor)}.")
raise TypeError(f"img should be Tensor Image. Got {type(tensor)}")
if not tensor.is_floating_point():
raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.")
if tensor.ndim < 3:
raise ValueError(
f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}"
)
if not inplace:
tensor = tensor.clone()
dtype = tensor.dtype
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
if (std == 0).any():
raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
if mean.ndim == 1:
mean = mean.view(-1, 1, 1)
if std.ndim == 1:
std = std.view(-1, 1, 1)
tensor.sub_(mean).div_(std)
return tensor
return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
def resize(
......@@ -1281,11 +1260,7 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
if not isinstance(img, torch.Tensor):
raise TypeError(f"img should be Tensor Image. Got {type(img)}")
if not inplace:
img = img.clone()
img[..., i : i + h, j : j + w] = v
return img
return F_t.erase(img, i, j, h, w, v, inplace=inplace)
def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor:
......
......@@ -918,3 +918,40 @@ def equalize(img: Tensor) -> Tensor:
return _equalize_single_image(img)
return torch.stack([_equalize_single_image(x) for x in img])
def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
_assert_image_tensor(tensor)
if not tensor.is_floating_point():
raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.")
if tensor.ndim < 3:
raise ValueError(
f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}"
)
if not inplace:
tensor = tensor.clone()
dtype = tensor.dtype
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
if (std == 0).any():
raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.")
if mean.ndim == 1:
mean = mean.view(-1, 1, 1)
if std.ndim == 1:
std = std.view(-1, 1, 1)
tensor.sub_(mean).div_(std)
return tensor
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
_assert_image_tensor(img)
if not inplace:
img = img.clone()
img[..., i : i + h, j : j + w] = v
return img
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment