Unverified Commit 9528d963 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add multi-weight support for VideoResNet (#4770)

* Add multi-weight support for VideoResNet.

* Fix linter.

* Minor refactoring.

* Update comments.
parent d367a01a
...@@ -25,6 +25,11 @@ def _build_model(fn, **kwargs): ...@@ -25,6 +25,11 @@ def _build_model(fn, **kwargs):
return model.eval() return model.eval()
def get_models_with_module_names(module):
    """Pair each model builder found in ``module`` with the module's short name.

    The short name is the last dot-separated component of ``module.__name__``.
    Tests use it as a lookup key for per-module defaults.

    Args:
        module: Module object to inspect; it is passed unchanged to
            ``TM.get_models_from_module``.

    Returns:
        A list of ``(fn, short_name)`` tuples, one for every item yielded by
        ``TM.get_models_from_module(module)``.
    """
    # rpartition yields the text after the final "." (or the whole name when
    # there is no dot), which is the same result as split(".")[-1].
    short_name = module.__name__.rpartition(".")[2]
    builders = TM.get_models_from_module(module)
    return [(builder, short_name) for builder in builders]
def test_get_weight(): def test_get_weight():
fn = models.resnet50 fn = models.resnet50
weight_name = "ImageNet1K_RefV2" weight_name = "ImageNet1K_RefV2"
...@@ -45,16 +50,35 @@ def test_segmentation_model(model_fn, dev): ...@@ -45,16 +50,35 @@ def test_segmentation_model(model_fn, dev):
TM.test_segmentation_model(model_fn, dev) TM.test_segmentation_model(model_fn, dev)
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models) + TM.get_models_from_module(models.segmentation)) @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.video))
@pytest.mark.parametrize("dev", cpu_and_gpu())
@pytest.mark.skipif(
    os.environ.get("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0",
    reason="Prototype code tests are disabled",
)
def test_video_model(model_fn, dev):
    """Run the shared video-model check from ``TM`` for one builder on one device.

    The test is skipped unless the ``PYTORCH_TEST_WITH_PROTOTYPE`` environment
    variable is set to something other than ``"0"``.

    Args:
        model_fn: Model builder under test (supplied through parametrization).
        dev: Device identifier produced by ``cpu_and_gpu()``.
    """
    TM.test_video_model(model_fn, dev)
@pytest.mark.parametrize(
"model_fn, module_name",
get_models_with_module_names(models)
+ get_models_with_module_names(models.segmentation)
+ get_models_with_module_names(models.video),
)
@pytest.mark.parametrize("dev", cpu_and_gpu()) @pytest.mark.parametrize("dev", cpu_and_gpu())
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled") @pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
def test_old_vs_new_factory(model_fn, dev): def test_old_vs_new_factory(model_fn, module_name, dev):
defaults = { defaults = {
"pretrained": True, "models": {
"input_shape": (1, 3, 224, 224), "input_shape": (1, 3, 224, 224),
},
"segmentation": {
"input_shape": (1, 3, 520, 520),
},
"video": {
"input_shape": (1, 3, 4, 112, 112),
},
} }
model_name = model_fn.__name__ model_name = model_fn.__name__
kwargs = {**defaults, **TM._model_params.get(model_name, {})} kwargs = {"pretrained": True, **defaults[module_name], **TM._model_params.get(model_name, {})}
input_shape = kwargs.pop("input_shape") input_shape = kwargs.pop("input_shape")
x = torch.rand(input_shape).to(device=dev) x = torch.rand(input_shape).to(device=dev)
......
from typing import Tuple, Optional, Callable, List, Type, Any, Union from typing import Tuple, Optional, Callable, List, Sequence, Type, Any, Union
import torch.nn as nn import torch.nn as nn
from torch import Tensor from torch import Tensor
...@@ -191,7 +191,7 @@ class VideoResNet(nn.Module): ...@@ -191,7 +191,7 @@ class VideoResNet(nn.Module):
def __init__( def __init__(
self, self,
block: Type[Union[BasicBlock, Bottleneck]], block: Type[Union[BasicBlock, Bottleneck]],
conv_makers: List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]], conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
layers: List[int], layers: List[int],
stem: Callable[..., nn.Module], stem: Callable[..., nn.Module],
num_classes: int = 400, num_classes: int = 400,
......
...@@ -8,3 +8,4 @@ from .mnasnet import * ...@@ -8,3 +8,4 @@ from .mnasnet import *
from . import detection from . import detection
from . import quantization from . import quantization
from . import segmentation from . import segmentation
from . import video
...@@ -1126,3 +1126,407 @@ _VOC_CATEGORIES = [ ...@@ -1126,3 +1126,407 @@ _VOC_CATEGORIES = [
"train", "train",
"tvmonitor", "tvmonitor",
] ]
# To be replaced with torchvision.datasets.find("kinetics400").info.categories
_KINETICS400_CATEGORIES = [
"abseiling",
"air drumming",
"answering questions",
"applauding",
"applying cream",
"archery",
"arm wrestling",
"arranging flowers",
"assembling computer",
"auctioning",
"baby waking up",
"baking cookies",
"balloon blowing",
"bandaging",
"barbequing",
"bartending",
"beatboxing",
"bee keeping",
"belly dancing",
"bench pressing",
"bending back",
"bending metal",
"biking through snow",
"blasting sand",
"blowing glass",
"blowing leaves",
"blowing nose",
"blowing out candles",
"bobsledding",
"bookbinding",
"bouncing on trampoline",
"bowling",
"braiding hair",
"breading or breadcrumbing",
"breakdancing",
"brush painting",
"brushing hair",
"brushing teeth",
"building cabinet",
"building shed",
"bungee jumping",
"busking",
"canoeing or kayaking",
"capoeira",
"carrying baby",
"cartwheeling",
"carving pumpkin",
"catching fish",
"catching or throwing baseball",
"catching or throwing frisbee",
"catching or throwing softball",
"celebrating",
"changing oil",
"changing wheel",
"checking tires",
"cheerleading",
"chopping wood",
"clapping",
"clay pottery making",
"clean and jerk",
"cleaning floor",
"cleaning gutters",
"cleaning pool",
"cleaning shoes",
"cleaning toilet",
"cleaning windows",
"climbing a rope",
"climbing ladder",
"climbing tree",
"contact juggling",
"cooking chicken",
"cooking egg",
"cooking on campfire",
"cooking sausages",
"counting money",
"country line dancing",
"cracking neck",
"crawling baby",
"crossing river",
"crying",
"curling hair",
"cutting nails",
"cutting pineapple",
"cutting watermelon",
"dancing ballet",
"dancing charleston",
"dancing gangnam style",
"dancing macarena",
"deadlifting",
"decorating the christmas tree",
"digging",
"dining",
"disc golfing",
"diving cliff",
"dodgeball",
"doing aerobics",
"doing laundry",
"doing nails",
"drawing",
"dribbling basketball",
"drinking",
"drinking beer",
"drinking shots",
"driving car",
"driving tractor",
"drop kicking",
"drumming fingers",
"dunking basketball",
"dying hair",
"eating burger",
"eating cake",
"eating carrots",
"eating chips",
"eating doughnuts",
"eating hotdog",
"eating ice cream",
"eating spaghetti",
"eating watermelon",
"egg hunting",
"exercising arm",
"exercising with an exercise ball",
"extinguishing fire",
"faceplanting",
"feeding birds",
"feeding fish",
"feeding goats",
"filling eyebrows",
"finger snapping",
"fixing hair",
"flipping pancake",
"flying kite",
"folding clothes",
"folding napkins",
"folding paper",
"front raises",
"frying vegetables",
"garbage collecting",
"gargling",
"getting a haircut",
"getting a tattoo",
"giving or receiving award",
"golf chipping",
"golf driving",
"golf putting",
"grinding meat",
"grooming dog",
"grooming horse",
"gymnastics tumbling",
"hammer throw",
"headbanging",
"headbutting",
"high jump",
"high kick",
"hitting baseball",
"hockey stop",
"holding snake",
"hopscotch",
"hoverboarding",
"hugging",
"hula hooping",
"hurdling",
"hurling (sport)",
"ice climbing",
"ice fishing",
"ice skating",
"ironing",
"javelin throw",
"jetskiing",
"jogging",
"juggling balls",
"juggling fire",
"juggling soccer ball",
"jumping into pool",
"jumpstyle dancing",
"kicking field goal",
"kicking soccer ball",
"kissing",
"kitesurfing",
"knitting",
"krumping",
"laughing",
"laying bricks",
"long jump",
"lunge",
"making a cake",
"making a sandwich",
"making bed",
"making jewelry",
"making pizza",
"making snowman",
"making sushi",
"making tea",
"marching",
"massaging back",
"massaging feet",
"massaging legs",
"massaging person's head",
"milking cow",
"mopping floor",
"motorcycling",
"moving furniture",
"mowing lawn",
"news anchoring",
"opening bottle",
"opening present",
"paragliding",
"parasailing",
"parkour",
"passing American football (in game)",
"passing American football (not in game)",
"peeling apples",
"peeling potatoes",
"petting animal (not cat)",
"petting cat",
"picking fruit",
"planting trees",
"plastering",
"playing accordion",
"playing badminton",
"playing bagpipes",
"playing basketball",
"playing bass guitar",
"playing cards",
"playing cello",
"playing chess",
"playing clarinet",
"playing controller",
"playing cricket",
"playing cymbals",
"playing didgeridoo",
"playing drums",
"playing flute",
"playing guitar",
"playing harmonica",
"playing harp",
"playing ice hockey",
"playing keyboard",
"playing kickball",
"playing monopoly",
"playing organ",
"playing paintball",
"playing piano",
"playing poker",
"playing recorder",
"playing saxophone",
"playing squash or racquetball",
"playing tennis",
"playing trombone",
"playing trumpet",
"playing ukulele",
"playing violin",
"playing volleyball",
"playing xylophone",
"pole vault",
"presenting weather forecast",
"pull ups",
"pumping fist",
"pumping gas",
"punching bag",
"punching person (boxing)",
"push up",
"pushing car",
"pushing cart",
"pushing wheelchair",
"reading book",
"reading newspaper",
"recording music",
"riding a bike",
"riding camel",
"riding elephant",
"riding mechanical bull",
"riding mountain bike",
"riding mule",
"riding or walking with horse",
"riding scooter",
"riding unicycle",
"ripping paper",
"robot dancing",
"rock climbing",
"rock scissors paper",
"roller skating",
"running on treadmill",
"sailing",
"salsa dancing",
"sanding floor",
"scrambling eggs",
"scuba diving",
"setting table",
"shaking hands",
"shaking head",
"sharpening knives",
"sharpening pencil",
"shaving head",
"shaving legs",
"shearing sheep",
"shining shoes",
"shooting basketball",
"shooting goal (soccer)",
"shot put",
"shoveling snow",
"shredding paper",
"shuffling cards",
"side kick",
"sign language interpreting",
"singing",
"situp",
"skateboarding",
"ski jumping",
"skiing (not slalom or crosscountry)",
"skiing crosscountry",
"skiing slalom",
"skipping rope",
"skydiving",
"slacklining",
"slapping",
"sled dog racing",
"smoking",
"smoking hookah",
"snatch weight lifting",
"sneezing",
"sniffing",
"snorkeling",
"snowboarding",
"snowkiting",
"snowmobiling",
"somersaulting",
"spinning poi",
"spray painting",
"spraying",
"springboard diving",
"squat",
"sticking tongue out",
"stomping grapes",
"stretching arm",
"stretching leg",
"strumming guitar",
"surfing crowd",
"surfing water",
"sweeping floor",
"swimming backstroke",
"swimming breast stroke",
"swimming butterfly stroke",
"swing dancing",
"swinging legs",
"swinging on something",
"sword fighting",
"tai chi",
"taking a shower",
"tango dancing",
"tap dancing",
"tapping guitar",
"tapping pen",
"tasting beer",
"tasting food",
"testifying",
"texting",
"throwing axe",
"throwing ball",
"throwing discus",
"tickling",
"tobogganing",
"tossing coin",
"tossing salad",
"training dog",
"trapezing",
"trimming or shaving beard",
"trimming trees",
"triple jump",
"tying bow tie",
"tying knot (not on a tie)",
"tying tie",
"unboxing",
"unloading truck",
"using computer",
"using remote controller (not gaming)",
"using segway",
"vault",
"waiting in line",
"walking the dog",
"washing dishes",
"washing feet",
"washing hair",
"washing hands",
"water skiing",
"water sliding",
"watering plants",
"waxing back",
"waxing chest",
"waxing eyebrows",
"waxing legs",
"weaving basket",
"welding",
"whistling",
"windsurfing",
"wrapping present",
"wrestling",
"writing",
"yawning",
"yoga",
"zumba",
]
import warnings
from functools import partial
from typing import Any, Callable, List, Optional, Sequence, Type, Union
from torch import nn
from torchvision.transforms.functional import InterpolationMode
from ....models.video.resnet import (
BasicBlock,
BasicStem,
Bottleneck,
Conv2Plus1D,
Conv3DSimple,
Conv3DNoTemporal,
R2Plus1dStem,
VideoResNet,
)
from ...transforms.presets import Kinect400Eval
from .._api import Weights, WeightEntry
from .._meta import _KINETICS400_CATEGORIES
# Names exported by this module: the base architecture class, the three
# per-model weight collections, and the three model builder functions.
__all__ = [
"VideoResNet",
"R3D_18Weights",
"MC3_18Weights",
"R2Plus1D_18Weights",
"r3d_18",
"mc3_18",
"r2plus1d_18",
]
def _video_resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
    layers: List[int],
    stem: Callable[..., nn.Module],
    weights: Optional[Weights],
    progress: bool,
    **kwargs: Any,
) -> VideoResNet:
    """Construct a ``VideoResNet`` and optionally load a checkpoint into it.

    Args:
        block: Residual block class handed to ``VideoResNet``.
        conv_makers: Convolution builder classes handed to ``VideoResNet``.
        layers: Per-stage layer counts handed to ``VideoResNet``.
        stem: Callable producing the stem module, handed to ``VideoResNet``.
        weights: Weight entry to load, or ``None`` to leave the model as
            constructed.
        progress: Forwarded to ``weights.state_dict`` when weights are given.
        **kwargs: Extra keyword arguments for ``VideoResNet``. When ``weights``
            is given, ``num_classes`` is overwritten with the number of
            categories recorded in the weights' metadata.

    Returns:
        The constructed model, with the state dict loaded when ``weights`` is
        not ``None``.
    """
    # Without weights there is nothing to reconcile or load.
    if weights is None:
        return VideoResNet(block, conv_makers, layers, stem, **kwargs)

    # The classifier size must match the checkpoint, so the metadata wins over
    # any caller-supplied value.
    kwargs["num_classes"] = len(weights.meta["categories"])
    net = VideoResNet(block, conv_makers, layers, stem, **kwargs)
    checkpoint = weights.state_dict(progress=progress)
    net.load_state_dict(checkpoint)
    return net
# Metadata shared by the weight entries in this module; each entry spreads this
# dict into its own `meta` and adds entry-specific keys on top.
_common_meta = {"size": (112, 112), "categories": _KINETICS400_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
# Weight entries accepted by the `r3d_18` builder.
class R3D_18Weights(Weights):
# Entry used by `r3d_18` when the legacy `pretrained` flag is truthy.
Kinetics400_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
# `partial` pre-binds the preset's resize and crop sizes; the preset itself is
# only instantiated when the resulting callable is invoked.
transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification",
# NOTE(review): presumably top-1 / top-5 accuracy in percent — confirm against the recipe.
"acc@1": 52.75,
"acc@5": 75.45,
},
)
# Weight entries accepted by the `mc3_18` builder.
class MC3_18Weights(Weights):
# Entry used by `mc3_18` when the legacy `pretrained` flag is truthy.
Kinetics400_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
# `partial` pre-binds the preset's resize and crop sizes; the preset itself is
# only instantiated when the resulting callable is invoked.
transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification",
# NOTE(review): presumably top-1 / top-5 accuracy in percent — confirm against the recipe.
"acc@1": 53.90,
"acc@5": 76.29,
},
)
# Weight entries accepted by the `r2plus1d_18` builder.
class R2Plus1D_18Weights(Weights):
# Entry used by `r2plus1d_18` when the legacy `pretrained` flag is truthy.
Kinetics400_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
# `partial` pre-binds the preset's resize and crop sizes; the preset itself is
# only instantiated when the resulting callable is invoked.
transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)),
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification",
# NOTE(review): presumably top-1 / top-5 accuracy in percent — confirm against the recipe.
"acc@1": 57.50,
"acc@5": 78.81,
},
)
def r3d_18(weights: Optional[R3D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
    """Build the ``r3d_18`` model: basic blocks with ``Conv3DSimple`` in all four stages.

    Args:
        weights: Entry of ``R3D_18Weights`` to load, or ``None`` to load nothing.
        progress: Forwarded to ``_video_resnet`` together with ``weights``.
        **kwargs: Extra keyword arguments forwarded to ``_video_resnet``. A
            legacy ``pretrained`` key is still accepted: it emits a warning and
            replaces ``weights`` with ``R3D_18Weights.Kinetics400_RefV1`` when
            truthy, or with ``None`` otherwise.

    Returns:
        The model produced by ``_video_resnet``.
    """
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        # The legacy flag takes precedence over any explicit `weights` value.
        legacy_flag = kwargs.pop("pretrained")
        weights = R3D_18Weights.Kinetics400_RefV1 if legacy_flag else None

    checked_weights = R3D_18Weights.verify(weights)
    return _video_resnet(
        BasicBlock,
        [Conv3DSimple] * 4,
        [2, 2, 2, 2],
        BasicStem,
        checked_weights,
        progress,
        **kwargs,
    )
def mc3_18(weights: Optional[MC3_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
    """Build the ``mc3_18`` model: ``Conv3DSimple`` first, then ``Conv3DNoTemporal`` for three stages.

    Args:
        weights: Entry of ``MC3_18Weights`` to load, or ``None`` to load nothing.
        progress: Forwarded to ``_video_resnet`` together with ``weights``.
        **kwargs: Extra keyword arguments forwarded to ``_video_resnet``. A
            legacy ``pretrained`` key is still accepted: it emits a warning and
            replaces ``weights`` with ``MC3_18Weights.Kinetics400_RefV1`` when
            truthy, or with ``None`` otherwise.

    Returns:
        The model produced by ``_video_resnet``.
    """
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        # The legacy flag takes precedence over any explicit `weights` value.
        legacy_flag = kwargs.pop("pretrained")
        weights = MC3_18Weights.Kinetics400_RefV1 if legacy_flag else None

    checked_weights = MC3_18Weights.verify(weights)
    return _video_resnet(
        BasicBlock,
        [Conv3DSimple] + [Conv3DNoTemporal] * 3,  # type: ignore[list-item]
        [2, 2, 2, 2],
        BasicStem,
        checked_weights,
        progress,
        **kwargs,
    )
def r2plus1d_18(weights: Optional[R2Plus1D_18Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
    """Build the ``r2plus1d_18`` model: ``Conv2Plus1D`` in all four stages with the ``R2Plus1dStem``.

    Args:
        weights: Entry of ``R2Plus1D_18Weights`` to load, or ``None`` to load nothing.
        progress: Forwarded to ``_video_resnet`` together with ``weights``.
        **kwargs: Extra keyword arguments forwarded to ``_video_resnet``. A
            legacy ``pretrained`` key is still accepted: it emits a warning and
            replaces ``weights`` with ``R2Plus1D_18Weights.Kinetics400_RefV1``
            when truthy, or with ``None`` otherwise.

    Returns:
        The model produced by ``_video_resnet``.
    """
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        # The legacy flag takes precedence over any explicit `weights` value.
        legacy_flag = kwargs.pop("pretrained")
        weights = R2Plus1D_18Weights.Kinetics400_RefV1 if legacy_flag else None

    checked_weights = R2Plus1D_18Weights.verify(weights)
    return _video_resnet(
        BasicBlock,
        [Conv2Plus1D] * 4,
        [2, 2, 2, 2],
        R2Plus1dStem,
        checked_weights,
        progress,
        **kwargs,
    )
...@@ -7,7 +7,7 @@ from ... import transforms as T ...@@ -7,7 +7,7 @@ from ... import transforms as T
from ...transforms import functional as F from ...transforms import functional as F
__all__ = ["CocoEval", "ImageNetEval", "VocEval"] __all__ = ["CocoEval", "ImageNetEval", "Kinect400Eval", "VocEval"]
class CocoEval(nn.Module): class CocoEval(nn.Module):
...@@ -41,6 +41,30 @@ class ImageNetEval(nn.Module): ...@@ -41,6 +41,30 @@ class ImageNetEval(nn.Module):
return self._normalize(img) return self._normalize(img)
class Kinect400Eval(nn.Module):
    """Evaluation-time preprocessing preset for a single video clip.

    The clip is converted to float, resized, normalized and center-cropped,
    in that order. The input is expected as ``(T, H, W, C)`` and the output
    is returned as ``(C, T, H, W)``.

    Args:
        resize_size: Target size given to ``T.Resize``.
        crop_size: Target size given to ``T.CenterCrop``.
        mean: Per-channel mean given to ``T.Normalize``.
        std: Per-channel standard deviation given to ``T.Normalize``.
        interpolation: Interpolation mode given to ``T.Resize``.
    """

    def __init__(
        self,
        resize_size: Tuple[int, int],
        crop_size: Tuple[int, int],
        mean: Tuple[float, ...] = (0.43216, 0.394666, 0.37645),
        std: Tuple[float, ...] = (0.22803, 0.22145, 0.216989),
        interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
    ) -> None:
        super().__init__()
        # Assignment order is kept as-is: it fixes the order in which the
        # submodules are registered on this module.
        self._convert = T.ConvertImageDtype(torch.float)
        self._resize = T.Resize(resize_size, interpolation=interpolation)
        self._normalize = T.Normalize(mean=mean, std=std)
        self._crop = T.CenterCrop(crop_size)

    def forward(self, vid: Tensor) -> Tensor:
        """Apply the preset to ``vid`` and return the channel-first clip."""
        # (T, H, W, C) => (T, C, H, W) so the image transforms see channel-first frames.
        frames = vid.permute(0, 3, 1, 2)
        for stage in (self._convert, self._resize, self._normalize, self._crop):
            frames = stage(frames)
        # (T, C, H, W) => (C, T, H, W)
        return frames.permute(1, 0, 2, 3)
class VocEval(nn.Module): class VocEval(nn.Module):
def __init__( def __init__(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment