"tests/vscode:/vscode.git/clone" did not exist on "c4f0fbe6cae34b7248b7965b750a49d716f3ca3f"
Unverified commit 11bd2eaa, authored by Vasilis Vryniotis and committed by GitHub

Port Multi-weight support from prototype to main (#5618)



* Moving base files outside of prototype and porting AlexNet, ConvNeXt, DenseNet and EfficientNet.

* Porting googlenet

* Porting inception

* Porting mnasnet

* Porting mobilenetv2

* Porting mobilenetv3

* Porting regnet

* Porting resnet

* Porting shufflenetv2

* Porting squeezenet

* Porting vgg

* Porting vit

* Fix docstrings

* Fixing imports

* Adding missing import

* Fix mobilenet imports

* Fix tests

* Fix prototype tests

* Exclude get_weight from models on test

* Fix init files

* Porting googlenet

* Porting inception

* porting mobilenetv2

* porting mobilenetv3

* porting resnet

* porting shufflenetv2

* Fix test and linter

* Fixing docs.

* Porting Detection models (#5617)

* fix inits

* fix docs

* Port faster_rcnn

* Port fcos

* Port keypoint_rcnn

* Port mask_rcnn

* Port retinanet

* Port ssd

* Port ssdlite

* Fix linter

* Fixing tests

* Fixing tests

* Fixing vgg test

* Porting Optical Flow, Segmentation, Video models (#5619)

* Porting raft

* Porting video resnet

* Porting deeplabv3

* Porting fcn and lraspp

* Fixing the tests and linter

* Porting docs, examples, tutorials and galleries (#5620)

* Fix examples, tutorials and gallery

* Update gallery/plot_optical_flow.py
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

* Fix import

* Revert hardcoded normalization

* fix uncommitted changes

* Fix bug

* Fix more bugs

* Making resize optional for segmentation

* Fixing preset

* Fix mypy

* Fixing documentation strings

* Fix flake8

* minor refactoring
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>

* Resolve conflict

* Porting model tests (#5622)

* Porting tests

* Remove unnecessary variable

* Fix linter

* Move prototype to extended tests

* Fix download models job

* Update CI on Multiweight branch to use the new weight download approach (#5628)

* port Pad to prototype transforms (#5621)

* port Pad to prototype transforms

* use literal

* Bump up LibTorchvision version number for Podspec to release Cocoapods (#5624)
Co-authored-by: Anton Thomma <anton@pri.co.nz>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* pre-download model weights in CI docs build (#5625)

* pre-download model weights in CI docs build

* move changes into template

* change docs image

* Regenerated config.yml
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Anton Thomma <11010310+thommaa@users.noreply.github.com>
Co-authored-by: Anton Thomma <anton@pri.co.nz>

* Porting reference scripts and updating presets (#5629)

* Making _preset.py classes

* Remove support of targets on presets.

* Rewriting the video preset

* Adding tests to check that the bundled transforms are JIT scriptable

* Rename all presets from *Eval to *Inference

* Minor refactoring

* Remove --prototype and --pretrained from reference scripts

* Remove `pretrained_backbone` refs

* Corrections and simplifications

* Fixing bug

* Fixing linter

* Fix flake8

* restore documentation example

* minor fixes

* fix optical flow missing param

* Fixing commands

* Adding weights_backbone support in detection and segmentation

* Updating the commands for InceptionV3

* Setting `weights_backbone` to its fully BC value (#5653)

* Replace default `weights_backbone=None` with its BC values.

* Fixing tests

* Fix linter

* Update docs.

* Update preprocessing on reference scripts.

* Change qat/ptq to their full values.

* Refactoring preprocessing

* Fix video preset

* No initialization on VGG if pretrained

* Fix warning messages for backbone utils.

* Adding star to all preset constructors.

* Fix mypy.
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Anton Thomma <11010310+thommaa@users.noreply.github.com>
Co-authored-by: Anton Thomma <anton@pri.co.nz>
parent 375e4ab2
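The commit message above amounts to moving the multi-weight API out of `torchvision.prototype.models` and into `torchvision.models`, together with the renamed inference presets. A rough usage sketch of the resulting API follows; it assumes a torchvision build that contains this commit and uses the `ResNet50_Weights` / `FasterRCNN_ResNet50_FPN_Weights` entries registered elsewhere in the tree, which are not part of the files shown below:

```python
import torch
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights

# Classification: pick a weight entry, build the model, and reuse its bundled preset.
weights = ResNet50_Weights.IMAGENET1K_V1
model = resnet50(weights=weights).eval()
preprocess = weights.transforms()  # the inference preset ported in this commit

img = torch.randint(0, 256, (3, 300, 400), dtype=torch.uint8)  # stand-in for a decoded image
batch = preprocess(img).unsqueeze(0)
with torch.no_grad():
    class_id = model(batch).argmax(dim=1).item()
print(weights.meta["categories"][class_id])

# Detection: `weights` and `weights_backbone` are now separate knobs,
# with the backwards-compatible defaults discussed in #5653 above.
det_model = fasterrcnn_resnet50_fpn(
    weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1,
    weights_backbone=ResNet50_Weights.IMAGENET1K_V1,
).eval()
```

Each weight entry carries its checkpoint URL, its preprocessing preset and its metadata (`categories`, accuracies, recipe link), which is what the per-model enums in the files below register.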
from functools import partial
from typing import Any, Callable, List, Optional, Sequence, Type, Union
from torch import nn
from torchvision.prototype.transforms import VideoClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ....models.video.resnet import (
BasicBlock,
BasicStem,
Bottleneck,
Conv2Plus1D,
Conv3DSimple,
Conv3DNoTemporal,
R2Plus1dStem,
VideoResNet,
)
from .._api import WeightsEnum, Weights
from .._meta import _KINETICS400_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
"VideoResNet",
"R3D_18_Weights",
"MC3_18_Weights",
"R2Plus1D_18_Weights",
"r3d_18",
"mc3_18",
"r2plus1d_18",
]
def _video_resnet(
block: Type[Union[BasicBlock, Bottleneck]],
conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
layers: List[int],
stem: Callable[..., nn.Module],
weights: Optional[WeightsEnum],
progress: bool,
**kwargs: Any,
) -> VideoResNet:
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = VideoResNet(block, conv_makers, layers, stem, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
_COMMON_META = {
"task": "video_classification",
"publication_year": 2017,
"size": (112, 112),
"min_size": (1, 1),
"categories": _KINETICS400_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification",
}
class R3D_18_Weights(WeightsEnum):
KINETICS400_V1 = Weights(
url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)),
meta={
**_COMMON_META,
"architecture": "R3D",
"num_params": 33371472,
"acc@1": 52.75,
"acc@5": 75.45,
},
)
DEFAULT = KINETICS400_V1
class MC3_18_Weights(WeightsEnum):
KINETICS400_V1 = Weights(
url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)),
meta={
**_COMMON_META,
"architecture": "MC3",
"num_params": 11695440,
"acc@1": 53.90,
"acc@5": 76.29,
},
)
DEFAULT = KINETICS400_V1
class R2Plus1D_18_Weights(WeightsEnum):
KINETICS400_V1 = Weights(
url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)),
meta={
**_COMMON_META,
"architecture": "R(2+1)D",
"num_params": 31505325,
"acc@1": 57.50,
"acc@5": 78.81,
},
)
DEFAULT = KINETICS400_V1
@handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1))
def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
weights = R3D_18_Weights.verify(weights)
return _video_resnet(
BasicBlock,
[Conv3DSimple] * 4,
[2, 2, 2, 2],
BasicStem,
weights,
progress,
**kwargs,
)
@handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1))
def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
weights = MC3_18_Weights.verify(weights)
return _video_resnet(
BasicBlock,
[Conv3DSimple] + [Conv3DNoTemporal] * 3, # type: ignore[list-item]
[2, 2, 2, 2],
BasicStem,
weights,
progress,
**kwargs,
)
@handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1))
def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
weights = R2Plus1D_18_Weights.verify(weights)
return _video_resnet(
BasicBlock,
[Conv2Plus1D] * 4,
[2, 2, 2, 2],
R2Plus1dStem,
weights,
progress,
**kwargs,
)
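The video builders above wire `VideoClassificationEval` (renamed to `VideoClassification` once the presets left the prototype namespace; see the preset diff near the end) into each weight entry. A sketch of how the ported builders are intended to be called; the `torchvision.models.video` import path reflects the post-port layout and the `(T, H, W, C)` clip layout follows the preset's permute comments, so treat both as assumptions rather than verbatim quotes from this diff:

```python
import torch
from torchvision.models.video import r3d_18, R3D_18_Weights

weights = R3D_18_Weights.KINETICS400_V1
model = r3d_18(weights=weights).eval()

# The bundled preset resizes to (128, 171), center-crops to (112, 112),
# rescales to float and normalizes with the Kinetics-400 statistics.
preprocess = weights.transforms()

clip = torch.randint(0, 256, (16, 240, 320, 3), dtype=torch.uint8)  # dummy (T, H, W, C) clip
batch = preprocess(clip).unsqueeze(0)  # -> (1, C, T, 112, 112)
with torch.no_grad():
    label = model(batch).argmax(dim=1).item()
print(weights.meta["categories"][label])
```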
# References:
# https://github.com/google-research/vision_transformer
# https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/vision_transformer.py
from functools import partial
from typing import Any, Optional
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.vision_transformer import VisionTransformer, interpolate_embeddings # noqa: F401
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = [
"VisionTransformer",
"ViT_B_16_Weights",
"ViT_B_32_Weights",
"ViT_L_16_Weights",
"ViT_L_32_Weights",
"vit_b_16",
"vit_b_32",
"vit_l_16",
"vit_l_32",
]
_COMMON_META = {
"task": "image_classification",
"architecture": "ViT",
"publication_year": 2020,
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}
class ViT_B_16_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vit_b_16-c867db91.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 86567656,
"size": (224, 224),
"min_size": (224, 224),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_16",
"acc@1": 81.072,
"acc@5": 95.318,
},
)
DEFAULT = IMAGENET1K_V1
class ViT_B_32_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 88224232,
"size": (224, 224),
"min_size": (224, 224),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_b_32",
"acc@1": 75.912,
"acc@5": 92.466,
},
)
DEFAULT = IMAGENET1K_V1
class ViT_L_16_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=242),
meta={
**_COMMON_META,
"num_params": 304326632,
"size": (224, 224),
"min_size": (224, 224),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_16",
"acc@1": 79.662,
"acc@5": 94.638,
},
)
DEFAULT = IMAGENET1K_V1
class ViT_L_32_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vit_l_32-c7638314.pth",
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 306535400,
"size": (224, 224),
"min_size": (224, 224),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#vit_l_32",
"acc@1": 76.972,
"acc@5": 93.07,
},
)
DEFAULT = IMAGENET1K_V1
def _vision_transformer(
patch_size: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
weights: Optional[WeightsEnum],
progress: bool,
**kwargs: Any,
) -> VisionTransformer:
image_size = kwargs.pop("image_size", 224)
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = VisionTransformer(
image_size=image_size,
patch_size=patch_size,
num_layers=num_layers,
num_heads=num_heads,
hidden_dim=hidden_dim,
mlp_dim=mlp_dim,
**kwargs,
)
if weights:
model.load_state_dict(weights.get_state_dict(progress=progress))
return model
@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1))
def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
weights = ViT_B_16_Weights.verify(weights)
return _vision_transformer(
patch_size=16,
num_layers=12,
num_heads=12,
hidden_dim=768,
mlp_dim=3072,
weights=weights,
progress=progress,
**kwargs,
)
@handle_legacy_interface(weights=("pretrained", ViT_B_32_Weights.IMAGENET1K_V1))
def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
weights = ViT_B_32_Weights.verify(weights)
return _vision_transformer(
patch_size=32,
num_layers=12,
num_heads=12,
hidden_dim=768,
mlp_dim=3072,
weights=weights,
progress=progress,
**kwargs,
)
@handle_legacy_interface(weights=("pretrained", ViT_L_16_Weights.IMAGENET1K_V1))
def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
weights = ViT_L_16_Weights.verify(weights)
return _vision_transformer(
patch_size=16,
num_layers=24,
num_heads=16,
hidden_dim=1024,
mlp_dim=4096,
weights=weights,
progress=progress,
**kwargs,
)
@handle_legacy_interface(weights=("pretrained", ViT_L_32_Weights.IMAGENET1K_V1))
def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
weights = ViT_L_32_Weights.verify(weights)
return _vision_transformer(
patch_size=32,
num_layers=24,
num_heads=16,
hidden_dim=1024,
mlp_dim=4096,
weights=weights,
progress=progress,
**kwargs,
)
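All of these builders go through `handle_legacy_interface`, which maps the old `pretrained=True` flag onto the default weight enum and emits a deprecation warning. A small sketch of the intended call patterns; the warning text is paraphrased, and `get_weight` is the string-based lookup referenced in the commit message above, assumed to be exposed under `torchvision.models` after the port:

```python
import warnings
from torchvision.models import get_weight, vit_b_16, ViT_B_16_Weights

# New-style call: pass the weight enum explicitly (weights=None gives random init).
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)

# Legacy call: handle_legacy_interface remaps pretrained=True to the enum above
# and warns about the deprecation (both calls download the checkpoint).
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy_model = vit_b_16(pretrained=True)
print([str(w.message) for w in caught])

# String-based lookup, handy for config-driven code:
assert get_weight("ViT_B_16_Weights.IMAGENET1K_V1") is ViT_B_16_Weights.IMAGENET1K_V1
```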
-from torchvision.transforms import InterpolationMode, AutoAugmentPolicy  # usort: skip
 from . import functional  # usort: skip
 from ._transform import Transform  # usort: skip
@@ -21,11 +19,4 @@ from ._geometry import (
 )
 from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
 from ._misc import Identity, Normalize, ToDtype, Lambda
-from ._presets import (
-    ObjectDetectionEval,
-    ImageClassificationEval,
-    SemanticSegmentationEval,
-    VideoClassificationEval,
-    OpticalFlowEval,
-)
 from ._type_conversion import DecodeImage, LabelToOneHot
@@ -4,9 +4,10 @@ from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar, Un
 import PIL.Image
 import torch
 from torchvision.prototype import features
-from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F
+from torchvision.prototype.transforms import Transform, functional as F
 from torchvision.prototype.utils._internal import query_recursively
-from torchvision.transforms.functional import pil_to_tensor, to_pil_image
+from torchvision.transforms.autoaugment import AutoAugmentPolicy
+from torchvision.transforms.functional import pil_to_tensor, to_pil_image, InterpolationMode
 from ._utils import get_image_dimensions, is_simple_tensor
@@ -7,8 +7,8 @@ from typing import Any, Dict, List, Union, Sequence, Tuple, cast
 import PIL.Image
 import torch
 from torchvision.prototype import features
-from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
-from torchvision.transforms.functional import pil_to_tensor
+from torchvision.prototype.transforms import Transform, functional as F
+from torchvision.transforms.functional import pil_to_tensor, InterpolationMode
 from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
 from typing_extensions import Literal
@@ -4,9 +4,8 @@ from typing import Tuple, List, Optional, Sequence, Union
 import PIL.Image
 import torch
 from torchvision.prototype import features
-from torchvision.prototype.transforms import InterpolationMode
 from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
-from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix
+from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix, InterpolationMode
 from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil
 import collections.abc
 import difflib
-import functools
-import inspect
 import io
 import mmap
 import os
 import os.path
 import platform
 import textwrap
-import warnings
 from typing import (
     Any,
     BinaryIO,
@@ -36,7 +33,6 @@ __all__ = [
     "FrozenMapping",
     "make_repr",
     "FrozenBunch",
-    "kwonly_to_pos_or_kw",
     "fromfile",
     "ReadOnlyTensorBuffer",
     "apply_recursively",
@@ -140,57 +136,6 @@ class FrozenBunch(FrozenMapping):
         return make_repr(type(self).__name__, self.items())

-def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]:
-    """Decorates a function that uses keyword only parameters to also allow them being passed as positionals.
-
-    For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``:
-
-    .. code::
-
-        def old_fn(foo, bar, baz=None):
-            ...
-
-        def new_fn(foo, *, bar, baz=None):
-            ...
-
-    Calling ``old_fn("foo", "bar, "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep BC
-    and at the same time warn the user of the deprecation, this decorator can be used:
-
-    .. code::
-
-        @kwonly_to_pos_or_kw
-        def new_fn(foo, *, bar, baz=None):
-            ...
-
-        new_fn("foo", "bar, "baz")
-    """
-    params = inspect.signature(fn).parameters
-
-    try:
-        keyword_only_start_idx = next(
-            idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY
-        )
-    except StopIteration:
-        raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None
-
-    keyword_only_params = tuple(inspect.signature(fn).parameters)[keyword_only_start_idx:]
-
-    @functools.wraps(fn)
-    def wrapper(*args: Any, **kwargs: Any) -> D:
-        args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:]
-        if keyword_only_args:
-            keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args))
-            warnings.warn(
-                f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional "
-                f"parameter(s) is deprecated. Please use keyword parameter(s) instead."
-            )
-            kwargs.update(keyword_only_kwargs)
-
-        return fn(*args, **kwargs)
-
-    return wrapper
-
 def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray:
     # A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable
     return bytearray(file.read(-1 if count == -1 else count * item_size))
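The removed `kwonly_to_pos_or_kw` helper is the shim that `handle_legacy_interface` builds on: arguments that became keyword-only can still be passed positionally, at the cost of a deprecation warning. A sketch of that behaviour; the `torchvision.models._utils` import path is an assumption about where the helper lives after the port, since this hunk only shows the prototype copy being deleted:

```python
import warnings

# Assumed post-port location of the helper; not shown in the diff above.
from torchvision.models._utils import kwonly_to_pos_or_kw


@kwonly_to_pos_or_kw
def build(name, *, pretrained=False, progress=True):
    return name, pretrained, progress


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Positional use of the now keyword-only parameter still works, but warns.
    result = build("resnet18", True)

assert result == ("resnet18", True, True)
print(str(caught[0].message))  # e.g. "Using 'pretrained' as positional parameter(s) is deprecated..."
```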
-from typing import Dict, Optional, Tuple
+"""
+This file is part of the private API. Please do not use directly these classes as they will be modified on
+future versions without warning. The classes should be accessed only via the transforms argument of Weights.
+"""
+from typing import Optional, Tuple

 import torch
 from torch import Tensor, nn

-from ...transforms import functional as F, InterpolationMode
+from . import functional as F, InterpolationMode

 __all__ = [
-    "ObjectDetectionEval",
-    "ImageClassificationEval",
-    "VideoClassificationEval",
-    "SemanticSegmentationEval",
-    "OpticalFlowEval",
+    "ObjectDetection",
+    "ImageClassification",
+    "VideoClassification",
+    "SemanticSegmentation",
+    "OpticalFlow",
 ]

-class ObjectDetectionEval(nn.Module):
-    def forward(
-        self, img: Tensor, target: Optional[Dict[str, Tensor]] = None
-    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+class ObjectDetection(nn.Module):
+    def forward(self, img: Tensor) -> Tensor:
         if not isinstance(img, Tensor):
             img = F.pil_to_tensor(img)
-        return F.convert_image_dtype(img, torch.float), target
+        return F.convert_image_dtype(img, torch.float)

-class ImageClassificationEval(nn.Module):
+class ImageClassification(nn.Module):
     def __init__(
         self,
+        *,
         crop_size: int,
         resize_size: int = 256,
         mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
@@ -50,9 +53,10 @@ class ImageClassificationEval(nn.Module):
         return img

-class VideoClassificationEval(nn.Module):
+class VideoClassification(nn.Module):
     def __init__(
         self,
+        *,
         crop_size: Tuple[int, int],
         resize_size: Tuple[int, int],
         mean: Tuple[float, ...] = (0.43216, 0.394666, 0.37645),
@@ -67,53 +71,61 @@ class VideoClassificationEval(nn.Module):
         self._interpolation = interpolation

     def forward(self, vid: Tensor) -> Tensor:
-        vid = vid.permute(0, 3, 1, 2)  # (T, H, W, C) => (T, C, H, W)
+        need_squeeze = False
+        if vid.ndim < 5:
+            vid = vid.unsqueeze(dim=0)
+            need_squeeze = True
+
+        vid = vid.permute(0, 1, 4, 2, 3)  # (N, T, H, W, C) => (N, T, C, H, W)
+        N, T, C, H, W = vid.shape
+        vid = vid.view(-1, C, H, W)
         vid = F.resize(vid, self._size, interpolation=self._interpolation)
         vid = F.center_crop(vid, self._crop_size)
         vid = F.convert_image_dtype(vid, torch.float)
         vid = F.normalize(vid, mean=self._mean, std=self._std)
-        return vid.permute(1, 0, 2, 3)  # (T, C, H, W) => (C, T, H, W)
+        H, W = self._crop_size
+        vid = vid.view(N, T, C, H, W)
+        vid = vid.permute(0, 2, 1, 3, 4)  # (N, T, C, H, W) => (N, C, T, H, W)
+
+        if need_squeeze:
+            vid = vid.squeeze(dim=0)
+        return vid

-class SemanticSegmentationEval(nn.Module):
+class SemanticSegmentation(nn.Module):
     def __init__(
         self,
-        resize_size: int,
+        *,
+        resize_size: Optional[int],
         mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
         std: Tuple[float, ...] = (0.229, 0.224, 0.225),
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-        interpolation_target: InterpolationMode = InterpolationMode.NEAREST,
     ) -> None:
         super().__init__()
-        self._size = [resize_size]
+        self._size = [resize_size] if resize_size is not None else None
         self._mean = list(mean)
         self._std = list(std)
         self._interpolation = interpolation
-        self._interpolation_target = interpolation_target

-    def forward(self, img: Tensor, target: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
-        img = F.resize(img, self._size, interpolation=self._interpolation)
+    def forward(self, img: Tensor) -> Tensor:
+        if isinstance(self._size, list):
+            img = F.resize(img, self._size, interpolation=self._interpolation)
         if not isinstance(img, Tensor):
             img = F.pil_to_tensor(img)
         img = F.convert_image_dtype(img, torch.float)
         img = F.normalize(img, mean=self._mean, std=self._std)
-        if target:
-            target = F.resize(target, self._size, interpolation=self._interpolation_target)
-            if not isinstance(target, Tensor):
-                target = F.pil_to_tensor(target)
-            target = target.squeeze(0).to(torch.int64)
-        return img, target
+        return img

-class OpticalFlowEval(nn.Module):
-    def forward(
-        self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor]
-    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
-        img1, img2, flow, valid_flow_mask = self._pil_or_numpy_to_tensor(img1, img2, flow, valid_flow_mask)
+class OpticalFlow(nn.Module):
+    def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]:
+        if not isinstance(img1, Tensor):
+            img1 = F.pil_to_tensor(img1)
+        if not isinstance(img2, Tensor):
+            img2 = F.pil_to_tensor(img2)

-        img1 = F.convert_image_dtype(img1, torch.float32)
-        img2 = F.convert_image_dtype(img2, torch.float32)
+        img1 = F.convert_image_dtype(img1, torch.float)
+        img2 = F.convert_image_dtype(img2, torch.float)

         # map [0, 1] into [-1, 1]
         img1 = F.normalize(img1, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
@@ -122,19 +134,4 @@ class OpticalFlowEval(nn.Module):
         img1 = img1.contiguous()
         img2 = img2.contiguous()

-        return img1, img2, flow, valid_flow_mask
-
-    def _pil_or_numpy_to_tensor(
-        self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor]
-    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
-        if not isinstance(img1, Tensor):
-            img1 = F.pil_to_tensor(img1)
-        if not isinstance(img2, Tensor):
-            img2 = F.pil_to_tensor(img2)
-
-        if flow is not None and not isinstance(flow, Tensor):
-            flow = torch.from_numpy(flow)
-        if valid_flow_mask is not None and not isinstance(valid_flow_mask, Tensor):
-            valid_flow_mask = torch.from_numpy(valid_flow_mask)
-
-        return img1, img2, flow, valid_flow_mask
+        return img1, img2
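The renamed presets above are plain `nn.Module`s that are normally obtained through a weight entry's `.transforms()`. Instantiating them directly from the private `torchvision.transforms._presets` module (the path implied by the relative imports above) makes the new shape handling easy to check, and the commit adds tests asserting that the bundled presets stay JIT-scriptable. A rough sketch under those assumptions:

```python
import torch

# Private module: these classes are normally reached via a weight entry's .transforms().
from torchvision.transforms._presets import ImageClassification, VideoClassification

img_preset = ImageClassification(crop_size=224)
vid_preset = VideoClassification(crop_size=(112, 112), resize_size=(128, 171))

img = torch.randint(0, 256, (3, 300, 400), dtype=torch.uint8)
clip = torch.randint(0, 256, (16, 240, 320, 3), dtype=torch.uint8)  # (T, H, W, C)

print(img_preset(img).shape)   # torch.Size([3, 224, 224])
print(vid_preset(clip).shape)  # torch.Size([3, 16, 112, 112]) -> (C, T, H, W)

# Scripting should succeed per the JIT-scriptability tests added in this commit.
scripted = torch.jit.script(img_preset)
print(scripted(img).shape)
```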