Unverified Commit 775129be authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Remove non-functional Transforms from presets. (#4952)

parent 4b2ad55f
......@@ -62,7 +62,7 @@ _COMMON_META = {
class R3D_18Weights(Weights):
Kinetics400_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)),
transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)),
meta={
**_COMMON_META,
"acc@1": 52.75,
......@@ -74,7 +74,7 @@ class R3D_18Weights(Weights):
class MC3_18Weights(Weights):
Kinetics400_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)),
transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)),
meta={
**_COMMON_META,
"acc@1": 53.90,
......@@ -86,7 +86,7 @@ class MC3_18Weights(Weights):
class R2Plus1D_18Weights(Weights):
Kinetics400_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)),
transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)),
meta={
**_COMMON_META,
"acc@1": 57.50,
......
......@@ -3,8 +3,7 @@ from typing import Dict, Optional, Tuple
import torch
from torch import Tensor, nn
from ... import transforms as T
from ...transforms import functional as F
from ...transforms import functional as F, InterpolationMode
__all__ = ["CocoEval", "ImageNetEval", "Kinect400Eval", "VocEval"]
......@@ -26,42 +25,47 @@ class ImageNetEval(nn.Module):
resize_size: int = 256,
mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
std: Tuple[float, ...] = (0.229, 0.224, 0.225),
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
self._resize = T.Resize(resize_size, interpolation=interpolation)
self._crop = T.CenterCrop(crop_size)
self._normalize = T.Normalize(mean=mean, std=std)
self._crop_size = [crop_size]
self._size = [resize_size]
self._mean = list(mean)
self._std = list(std)
self._interpolation = interpolation
def forward(self, img: Tensor) -> Tensor:
img = self._crop(self._resize(img))
img = F.resize(img, self._size, interpolation=self._interpolation)
img = F.center_crop(img, self._crop_size)
if not isinstance(img, Tensor):
img = F.pil_to_tensor(img)
img = F.convert_image_dtype(img, torch.float)
return self._normalize(img)
img = F.normalize(img, mean=self._mean, std=self._std)
return img
class Kinect400Eval(nn.Module):
def __init__(
self,
resize_size: Tuple[int, int],
crop_size: Tuple[int, int],
resize_size: Tuple[int, int],
mean: Tuple[float, ...] = (0.43216, 0.394666, 0.37645),
std: Tuple[float, ...] = (0.22803, 0.22145, 0.216989),
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> None:
super().__init__()
self._convert = T.ConvertImageDtype(torch.float)
self._resize = T.Resize(resize_size, interpolation=interpolation)
self._normalize = T.Normalize(mean=mean, std=std)
self._crop = T.CenterCrop(crop_size)
self._crop_size = list(crop_size)
self._size = list(resize_size)
self._mean = list(mean)
self._std = list(std)
self._interpolation = interpolation
def forward(self, vid: Tensor) -> Tensor:
vid = vid.permute(0, 3, 1, 2) # (T, H, W, C) => (T, C, H, W)
vid = self._convert(vid)
vid = self._resize(vid)
vid = self._normalize(vid)
vid = self._crop(vid)
vid = F.resize(vid, self._size, interpolation=self._interpolation)
vid = F.center_crop(vid, self._crop_size)
vid = F.convert_image_dtype(vid, torch.float)
vid = F.normalize(vid, mean=self._mean, std=self._std)
return vid.permute(1, 0, 2, 3) # (T, C, H, W) => (C, T, H, W)
......@@ -71,8 +75,8 @@ class VocEval(nn.Module):
resize_size: int,
mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
std: Tuple[float, ...] = (0.229, 0.224, 0.225),
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
interpolation_target: T.InterpolationMode = T.InterpolationMode.NEAREST,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation_target: InterpolationMode = InterpolationMode.NEAREST,
) -> None:
super().__init__()
self._size = [resize_size]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment