Unverified Commit c6a03c76 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Replaced to_tensor() with pil_to_tensor() + convert_image_dtype() (#4452)

parent c7120163
......@@ -47,7 +47,8 @@ class RandomHorizontalFlip(T.RandomHorizontalFlip):
class ToTensor(nn.Module):
def forward(self, image: Tensor,
target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
image = F.to_tensor(image)
image = F.pil_to_tensor(image)
image = F.convert_image_dtype(image)
return image, target
......@@ -231,7 +232,8 @@ class RandomPhotometricDistort(nn.Module):
is_pil = F._is_pil_image(image)
if is_pil:
image = F.to_tensor(image)
image = F.pil_to_tensor(image)
image = F.convert_image_dtype(image)
image = image[..., permutation, :, :]
if is_pil:
image = F.to_pil_image(image)
......
......@@ -77,7 +77,8 @@ class CenterCrop(object):
class ToTensor(object):
def __call__(self, image, target):
image = F.to_tensor(image)
image = F.pil_to_tensor(image)
image = F.convert_image_dtype(image)
target = torch.as_tensor(np.array(target), dtype=torch.int64)
return image, target
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment