Unverified Commit 80a708f3 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

use torch.where over boolean masking (#8171)

parent 6640e494
...@@ -595,8 +595,7 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill ...@@ -595,8 +595,7 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill
fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)] # type: ignore[arg-type] fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)] # type: ignore[arg-type]
fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1) fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
if mode == "nearest": if mode == "nearest":
bool_mask = mask < 0.5 float_img = torch.where(mask < 0.5, fill_img.expand_as(float_img), float_img)
float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
else: # 'bilinear' else: # 'bilinear'
# The following is mathematically equivalent to: # The following is mathematically equivalent to:
# img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill # img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment