Unverified Commit 9f0afd55 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Replaced ConvertImageDtype by ToDtype in reference scripts (#7862)


Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
parent 4491ca2e
...@@ -61,7 +61,7 @@ class ClassificationPresetTrain: ...@@ -61,7 +61,7 @@ class ClassificationPresetTrain:
transforms.extend( transforms.extend(
[ [
T.ConvertImageDtype(torch.float), T.ToDtype(torch.float, scale=True) if use_v2 else T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std), T.Normalize(mean=mean, std=std),
] ]
) )
...@@ -106,7 +106,7 @@ class ClassificationPresetEval: ...@@ -106,7 +106,7 @@ class ClassificationPresetEval:
transforms.append(T.PILToTensor()) transforms.append(T.PILToTensor())
transforms += [ transforms += [
T.ConvertImageDtype(torch.float), T.ToDtype(torch.float, scale=True) if use_v2 else T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std), T.Normalize(mean=mean, std=std),
] ]
......
...@@ -73,7 +73,7 @@ class DetectionPresetTrain: ...@@ -73,7 +73,7 @@ class DetectionPresetTrain:
# Note: we could just convert to pure tensors even in v2. # Note: we could just convert to pure tensors even in v2.
transforms += [T.ToImage() if use_v2 else T.PILToTensor()] transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
transforms += [T.ConvertImageDtype(torch.float)] transforms += [T.ToDtype(torch.float, scale=True)]
if use_v2: if use_v2:
transforms += [ transforms += [
...@@ -103,7 +103,7 @@ class DetectionPresetEval: ...@@ -103,7 +103,7 @@ class DetectionPresetEval:
else: else:
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}") raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
transforms += [T.ConvertImageDtype(torch.float)] transforms += [T.ToDtype(torch.float, scale=True)]
if use_v2: if use_v2:
transforms += [T.ToPureTensor()] transforms += [T.ToPureTensor()]
......
...@@ -53,14 +53,17 @@ class PILToTensor(nn.Module): ...@@ -53,14 +53,17 @@ class PILToTensor(nn.Module):
return image, target return image, target
class ToDtype(nn.Module):
    """Convert an image tensor to the given dtype, optionally rescaling values.

    Replacement for the former ``ConvertImageDtype`` transform. With
    ``scale=True`` the pixel values are rescaled to the target dtype's value
    range via ``F.convert_image_dtype`` (e.g. uint8 [0, 255] -> float [0, 1]);
    with ``scale=False`` (default) only a raw ``.to(dtype)`` cast is done.

    Args:
        dtype: target ``torch.dtype`` for the image tensor.
        scale: when ``True``, rescale values to the new dtype's range.
    """

    def __init__(self, dtype: torch.dtype, scale: bool = False) -> None:
        super().__init__()
        self.dtype = dtype
        self.scale = scale

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        if not self.scale:
            # Plain dtype cast: values are preserved, not rescaled.
            return image.to(dtype=self.dtype), target
        # Rescale values into the target dtype's nominal range.
        image = F.convert_image_dtype(image, self.dtype)
        return image, target
......
...@@ -60,7 +60,7 @@ class SegmentationPresetTrain: ...@@ -60,7 +60,7 @@ class SegmentationPresetTrain:
] ]
else: else:
# No need to explicitly convert masks as they're magically int64 already # No need to explicitly convert masks as they're magically int64 already
transforms += [T.ConvertImageDtype(torch.float)] transforms += [T.ToDtype(torch.float, scale=True)]
transforms += [T.Normalize(mean=mean, std=std)] transforms += [T.Normalize(mean=mean, std=std)]
if use_v2: if use_v2:
...@@ -97,7 +97,7 @@ class SegmentationPresetEval: ...@@ -97,7 +97,7 @@ class SegmentationPresetEval:
transforms += [T.ToImage() if use_v2 else T.PILToTensor()] transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
transforms += [ transforms += [
T.ConvertImageDtype(torch.float), T.ToDtype(torch.float, scale=True),
T.Normalize(mean=mean, std=std), T.Normalize(mean=mean, std=std),
] ]
if use_v2: if use_v2:
......
...@@ -81,11 +81,14 @@ class PILToTensor: ...@@ -81,11 +81,14 @@ class PILToTensor:
return image, target return image, target
class ToDtype:
    """Convert an image tensor to the given dtype, optionally rescaling values.

    Callable replacement for the former ``ConvertImageDtype`` transform used
    in the reference scripts. When ``scale`` is ``True`` the values are
    rescaled to the target dtype's range with ``F.convert_image_dtype``;
    otherwise only a plain ``.to(dtype)`` cast is performed.

    Args:
        dtype: target ``torch.dtype`` for the image tensor.
        scale: when ``True``, rescale values to the new dtype's range
            (default ``False``).
    """

    def __init__(self, dtype, scale=False):
        self.dtype = dtype
        self.scale = scale

    def __call__(self, image, target):
        if not self.scale:
            # Plain dtype cast: values are preserved, not rescaled.
            return image.to(dtype=self.dtype), target
        # Rescale values into the target dtype's nominal range.
        image = F.convert_image_dtype(image, self.dtype)
        return image, target
......
...@@ -78,6 +78,6 @@ class CocoDetectionToVOCSegmentation(v2.Transform): ...@@ -78,6 +78,6 @@ class CocoDetectionToVOCSegmentation(v2.Transform):
def forward(self, image, target):
    """Convert a COCO detection target into a VOC-style segmentation mask.

    Builds the segmentation mask from the COCO detection annotations via the
    class helper; when no mask can be derived, substitutes an all-zero
    (background) uint8 mask matching the image's spatial size. Uses
    ``v2.functional.get_size`` (the post-rename spelling of the former
    ``get_spatial_size``).

    Returns:
        ``(image, mask)`` where ``mask`` is a ``datapoints.Mask``.
    """
    segmentation_mask = self._coco_detection_masks_to_voc_segmentation_mask(target)
    if segmentation_mask is None:
        # No annotations -> all-background mask with the image's (H, W) size.
        segmentation_mask = torch.zeros(v2.functional.get_size(image), dtype=torch.uint8)
    return image, datapoints.Mask(segmentation_mask)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.