Unverified Commit 9f0afd55 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Replaced ConvertImageDtype by ToDtype in reference scripts (#7862)


Co-authored-by: default avatarNicolas Hug <nh.nicolas.hug@gmail.com>
parent 4491ca2e
......@@ -61,7 +61,7 @@ class ClassificationPresetTrain:
transforms.extend(
[
T.ConvertImageDtype(torch.float),
T.ToDtype(torch.float, scale=True) if use_v2 else T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]
)
......@@ -106,7 +106,7 @@ class ClassificationPresetEval:
transforms.append(T.PILToTensor())
transforms += [
T.ConvertImageDtype(torch.float),
T.ToDtype(torch.float, scale=True) if use_v2 else T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]
......
......@@ -73,7 +73,7 @@ class DetectionPresetTrain:
# Note: we could just convert to pure tensors even in v2.
transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
transforms += [T.ConvertImageDtype(torch.float)]
transforms += [T.ToDtype(torch.float, scale=True)]
if use_v2:
transforms += [
......@@ -103,7 +103,7 @@ class DetectionPresetEval:
else:
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
transforms += [T.ConvertImageDtype(torch.float)]
transforms += [T.ToDtype(torch.float, scale=True)]
if use_v2:
transforms += [T.ToPureTensor()]
......
......@@ -53,14 +53,17 @@ class PILToTensor(nn.Module):
return image, target
class ConvertImageDtype(nn.Module):
def __init__(self, dtype: torch.dtype) -> None:
class ToDtype(nn.Module):
def __init__(self, dtype: torch.dtype, scale: bool = False) -> None:
super().__init__()
self.dtype = dtype
self.scale = scale
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if not self.scale:
return image.to(dtype=self.dtype), target
image = F.convert_image_dtype(image, self.dtype)
return image, target
......
......@@ -60,7 +60,7 @@ class SegmentationPresetTrain:
]
else:
# No need to explicitly convert masks as they're magically int64 already
transforms += [T.ConvertImageDtype(torch.float)]
transforms += [T.ToDtype(torch.float, scale=True)]
transforms += [T.Normalize(mean=mean, std=std)]
if use_v2:
......@@ -97,7 +97,7 @@ class SegmentationPresetEval:
transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
transforms += [
T.ConvertImageDtype(torch.float),
T.ToDtype(torch.float, scale=True),
T.Normalize(mean=mean, std=std),
]
if use_v2:
......
......@@ -81,11 +81,14 @@ class PILToTensor:
return image, target
class ConvertImageDtype:
def __init__(self, dtype):
class ToDtype:
def __init__(self, dtype, scale=False):
self.dtype = dtype
self.scale = scale
def __call__(self, image, target):
if not self.scale:
return image.to(dtype=self.dtype), target
image = F.convert_image_dtype(image, self.dtype)
return image, target
......
......@@ -78,6 +78,6 @@ class CocoDetectionToVOCSegmentation(v2.Transform):
def forward(self, image, target):
segmentation_mask = self._coco_detection_masks_to_voc_segmentation_mask(target)
if segmentation_mask is None:
segmentation_mask = torch.zeros(v2.functional.get_spatial_size(image), dtype=torch.uint8)
segmentation_mask = torch.zeros(v2.functional.get_size(image), dtype=torch.uint8)
return image, datapoints.Mask(segmentation_mask)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment