Unverified commit 6518372e, authored by Prabhat Roy and committed by GitHub
Browse files

Added PILToTensor and ConvertImageDtype classes in reference scripts (#4495)

* Added PILToTensor and ConvertImageDtype classes in reference scripts

* Addressed review comments

* Fixed TypeError

* Addressed review comment
parent 055708d2
import torch
import transforms as T import transforms as T
...@@ -6,7 +8,8 @@ class DetectionPresetTrain: ...@@ -6,7 +8,8 @@ class DetectionPresetTrain:
if data_augmentation == 'hflip': if data_augmentation == 'hflip':
self.transforms = T.Compose([ self.transforms = T.Compose([
T.RandomHorizontalFlip(p=hflip_prob), T.RandomHorizontalFlip(p=hflip_prob),
T.ToTensor(), T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]) ])
elif data_augmentation == 'ssd': elif data_augmentation == 'ssd':
self.transforms = T.Compose([ self.transforms = T.Compose([
...@@ -14,13 +17,15 @@ class DetectionPresetTrain: ...@@ -14,13 +17,15 @@ class DetectionPresetTrain:
T.RandomZoomOut(fill=list(mean)), T.RandomZoomOut(fill=list(mean)),
T.RandomIoUCrop(), T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob), T.RandomHorizontalFlip(p=hflip_prob),
T.ToTensor(), T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]) ])
elif data_augmentation == 'ssdlite': elif data_augmentation == 'ssdlite':
self.transforms = T.Compose([ self.transforms = T.Compose([
T.RandomIoUCrop(), T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob), T.RandomHorizontalFlip(p=hflip_prob),
T.ToTensor(), T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]) ])
else: else:
raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"') raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
......
from typing import List, Tuple, Dict, Optional
import torch import torch
import torchvision import torchvision
from torch import nn, Tensor from torch import nn, Tensor
from torchvision.transforms import functional as F from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T from torchvision.transforms import transforms as T
from typing import List, Tuple, Dict, Optional
def _flip_coco_person_keypoints(kps, width): def _flip_coco_person_keypoints(kps, width):
...@@ -52,6 +52,24 @@ class ToTensor(nn.Module): ...@@ -52,6 +52,24 @@ class ToTensor(nn.Module):
return image, target return image, target
class PILToTensor(nn.Module):
    """Detection transform that converts the input image with ``F.pil_to_tensor``.

    The optional ``target`` is handed back exactly as it was received.
    """

    def forward(
        self,
        image: Tensor,
        target: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        """Return ``(F.pil_to_tensor(image), target)``.

        Args:
            image: The image to convert.
                NOTE(review): annotated as ``Tensor``, but ``F.pil_to_tensor``
                presumably expects a PIL image here -- confirm against callers.
            target: Optional annotation dict; passed through unchanged.
        """
        return F.pil_to_tensor(image), target
class ConvertImageDtype(nn.Module):
    """Detection transform that converts the image to a fixed ``torch.dtype``.

    The conversion is delegated to ``F.convert_image_dtype``; the optional
    ``target`` is handed back exactly as it was received.
    """

    def __init__(self, dtype: torch.dtype) -> None:
        """Remember the dtype every image will be converted to."""
        super().__init__()
        self.dtype = dtype

    def forward(
        self,
        image: Tensor,
        target: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        """Return ``(F.convert_image_dtype(image, self.dtype), target)``.

        Args:
            image: Image tensor to convert.
            target: Optional annotation dict; passed through unchanged.
        """
        return F.convert_image_dtype(image, self.dtype), target
class RandomIoUCrop(nn.Module): class RandomIoUCrop(nn.Module):
def __init__(self, min_scale: float = 0.3, max_scale: float = 1.0, min_aspect_ratio: float = 0.5, def __init__(self, min_scale: float = 0.3, max_scale: float = 1.0, min_aspect_ratio: float = 0.5,
max_aspect_ratio: float = 2.0, sampler_options: Optional[List[float]] = None, trials: int = 40): max_aspect_ratio: float = 2.0, sampler_options: Optional[List[float]] = None, trials: int = 40):
......
import torch
import transforms as T import transforms as T
...@@ -11,7 +13,8 @@ class SegmentationPresetTrain: ...@@ -11,7 +13,8 @@ class SegmentationPresetTrain:
trans.append(T.RandomHorizontalFlip(hflip_prob)) trans.append(T.RandomHorizontalFlip(hflip_prob))
trans.extend([ trans.extend([
T.RandomCrop(crop_size), T.RandomCrop(crop_size),
T.ToTensor(), T.PILToTensor(),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std), T.Normalize(mean=mean, std=std),
]) ])
self.transforms = T.Compose(trans) self.transforms = T.Compose(trans)
...@@ -24,7 +27,8 @@ class SegmentationPresetEval: ...@@ -24,7 +27,8 @@ class SegmentationPresetEval:
def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
self.transforms = T.Compose([ self.transforms = T.Compose([
T.RandomResize(base_size, base_size), T.RandomResize(base_size, base_size),
T.ToTensor(), T.PILToTensor(),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std), T.Normalize(mean=mean, std=std),
]) ])
......
import numpy as np
from PIL import Image
import random import random
import numpy as np
import torch import torch
from torchvision import transforms as T from torchvision import transforms as T
from torchvision.transforms import functional as F from torchvision.transforms import functional as F
...@@ -75,14 +74,22 @@ class CenterCrop(object): ...@@ -75,14 +74,22 @@ class CenterCrop(object):
return image, target return image, target
class ToTensor(object): class PILToTensor:
def __call__(self, image, target): def __call__(self, image, target):
image = F.pil_to_tensor(image) image = F.pil_to_tensor(image)
image = F.convert_image_dtype(image)
target = torch.as_tensor(np.array(target), dtype=torch.int64) target = torch.as_tensor(np.array(target), dtype=torch.int64)
return image, target return image, target
class ConvertImageDtype:
    """Segmentation transform that converts the image to a fixed dtype.

    Only the image is touched; ``target`` is returned as received.
    """

    def __init__(self, dtype):
        # Target dtype forwarded to F.convert_image_dtype on every call.
        self.dtype = dtype

    def __call__(self, image, target):
        """Return ``(converted_image, target)`` for one sample."""
        converted = F.convert_image_dtype(image, self.dtype)
        return converted, target
class Normalize(object): class Normalize(object):
def __init__(self, mean, std): def __init__(self, mean, std):
self.mean = mean self.mean = mean
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment