Unverified Commit 775129be authored by Vasilis Vryniotis, committed by GitHub

Remove non-functional Transforms from presets. (#4952)

parent 4b2ad55f
@@ -62,7 +62,7 @@ _COMMON_META = {
 class R3D_18Weights(Weights):
     Kinetics400_RefV1 = WeightEntry(
         url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
-        transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)),
+        transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)),
         meta={
             **_COMMON_META,
             "acc@1": 52.75,
@@ -74,7 +74,7 @@ class R3D_18Weights(Weights):
 class MC3_18Weights(Weights):
     Kinetics400_RefV1 = WeightEntry(
         url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
-        transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)),
+        transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)),
         meta={
             **_COMMON_META,
             "acc@1": 53.90,
@@ -86,7 +86,7 @@ class MC3_18Weights(Weights):
 class R2Plus1D_18Weights(Weights):
     Kinetics400_RefV1 = WeightEntry(
         url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
-        transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)),
+        transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)),
         meta={
             **_COMMON_META,
             "acc@1": 57.50,
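The `partial` reorder above is cosmetic: `crop_size` and `resize_size` are both bound by keyword, so either order builds the same transform; the new order simply mirrors the updated positional order of `Kinect400Eval.__init__` (see the presets diff below). A minimal sketch of the equivalence, using a hypothetical stand-in for the preset class:

```python
from functools import partial

# Hypothetical stand-in for the real Kinect400Eval preset (defined in the
# presets diff below); only the constructor signature matters here.
class Kinect400Eval:
    def __init__(self, crop_size, resize_size):
        self.crop_size = crop_size
        self.resize_size = resize_size

# Both partials bind the same keywords, so they build identical transforms.
old_order = partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112))
new_order = partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171))
assert old_order().crop_size == new_order().crop_size == (112, 112)
assert old_order().resize_size == new_order().resize_size == (128, 171)
```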
@@ -3,8 +3,7 @@ from typing import Dict, Optional, Tuple
 import torch
 from torch import Tensor, nn
 
-from ... import transforms as T
-from ...transforms import functional as F
+from ...transforms import functional as F, InterpolationMode
 
 __all__ = ["CocoEval", "ImageNetEval", "Kinect400Eval", "VocEval"]
@@ -26,42 +25,47 @@ class ImageNetEval(nn.Module):
         resize_size: int = 256,
         mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
         std: Tuple[float, ...] = (0.229, 0.224, 0.225),
-        interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     ) -> None:
         super().__init__()
-        self._resize = T.Resize(resize_size, interpolation=interpolation)
-        self._crop = T.CenterCrop(crop_size)
-        self._normalize = T.Normalize(mean=mean, std=std)
+        self._crop_size = [crop_size]
+        self._size = [resize_size]
+        self._mean = list(mean)
+        self._std = list(std)
+        self._interpolation = interpolation
 
     def forward(self, img: Tensor) -> Tensor:
-        img = self._crop(self._resize(img))
+        img = F.resize(img, self._size, interpolation=self._interpolation)
+        img = F.center_crop(img, self._crop_size)
         if not isinstance(img, Tensor):
             img = F.pil_to_tensor(img)
         img = F.convert_image_dtype(img, torch.float)
-        return self._normalize(img)
+        img = F.normalize(img, mean=self._mean, std=self._std)
+        return img
 
 
 class Kinect400Eval(nn.Module):
     def __init__(
         self,
-        resize_size: Tuple[int, int],
         crop_size: Tuple[int, int],
+        resize_size: Tuple[int, int],
         mean: Tuple[float, ...] = (0.43216, 0.394666, 0.37645),
         std: Tuple[float, ...] = (0.22803, 0.22145, 0.216989),
-        interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     ) -> None:
         super().__init__()
-        self._convert = T.ConvertImageDtype(torch.float)
-        self._resize = T.Resize(resize_size, interpolation=interpolation)
-        self._normalize = T.Normalize(mean=mean, std=std)
-        self._crop = T.CenterCrop(crop_size)
+        self._crop_size = list(crop_size)
+        self._size = list(resize_size)
+        self._mean = list(mean)
+        self._std = list(std)
+        self._interpolation = interpolation
 
     def forward(self, vid: Tensor) -> Tensor:
         vid = vid.permute(0, 3, 1, 2)  # (T, H, W, C) => (T, C, H, W)
-        vid = self._convert(vid)
-        vid = self._resize(vid)
-        vid = self._normalize(vid)
-        vid = self._crop(vid)
+        vid = F.resize(vid, self._size, interpolation=self._interpolation)
+        vid = F.center_crop(vid, self._crop_size)
+        vid = F.convert_image_dtype(vid, torch.float)
+        vid = F.normalize(vid, mean=self._mean, std=self._std)
         return vid.permute(1, 0, 2, 3)  # (T, C, H, W) => (C, T, H, W)
@@ -71,8 +75,8 @@ class VocEval(nn.Module):
         resize_size: int,
         mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
         std: Tuple[float, ...] = (0.229, 0.224, 0.225),
-        interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
-        interpolation_target: T.InterpolationMode = T.InterpolationMode.NEAREST,
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        interpolation_target: InterpolationMode = InterpolationMode.NEAREST,
     ) -> None:
         super().__init__()
         self._size = [resize_size]
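With the `nn.Module` transforms removed, the presets chain the stable functional API directly. A minimal sketch of the resulting `Kinect400Eval` pipeline on a dummy clip; the clip shape and uint8 input are assumptions for illustration, while the functional calls, default interpolation, and Kinetics-400 mean/std values come from the diff above:

```python
import torch
from torchvision.transforms import functional as F, InterpolationMode

# Dummy uint8 clip in (T, H, W, C) layout, e.g. 16 frames of 240x320 RGB.
vid = torch.randint(0, 256, (16, 240, 320, 3), dtype=torch.uint8)

# The same steps Kinect400Eval.forward performs after this change:
vid = vid.permute(0, 3, 1, 2)  # (T, H, W, C) => (T, C, H, W)
vid = F.resize(vid, [128, 171], interpolation=InterpolationMode.BILINEAR)
vid = F.center_crop(vid, [112, 112])
vid = F.convert_image_dtype(vid, torch.float)
vid = F.normalize(vid, mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989])
clip = vid.permute(1, 0, 2, 3)  # (T, C, H, W) => (C, T, H, W)

assert clip.shape == (3, 16, 112, 112)
```

Note that the refactor also reorders the video pipeline: the clip is now resized and cropped first, converted to float, and normalized last, whereas the old module-based forward converted to float up front and cropped after normalizing.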