"git@developer.sourcefind.cn:OpenDAS/torchani.git" did not exist on "14a62dc466ed503cfb16cb4c180de17541f17aed"
Unverified Commit 272e080c authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add initial chunk of prototype transforms (#4861)

* add initial chunk of prototype transforms

* fix tests

* add error message

* fix more imports

* add explicit no-ops

* add test for no-ops

* cleanup
parent 57e6e302
...@@ -2,10 +2,10 @@ import warnings ...@@ -2,10 +2,10 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.mnasnet import MNASNet from ...models.mnasnet import MNASNet
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
......
...@@ -2,10 +2,10 @@ import warnings ...@@ -2,10 +2,10 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.mobilenetv2 import MobileNetV2 from ...models.mobilenetv2 import MobileNetV2
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
......
...@@ -2,10 +2,10 @@ import warnings ...@@ -2,10 +2,10 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Optional, List from typing import Any, Optional, List
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
......
...@@ -2,6 +2,7 @@ import warnings ...@@ -2,6 +2,7 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Optional, Union from typing import Any, Optional, Union
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.googlenet import ( from ....models.quantization.googlenet import (
...@@ -9,7 +10,6 @@ from ....models.quantization.googlenet import ( ...@@ -9,7 +10,6 @@ from ....models.quantization.googlenet import (
_replace_relu, _replace_relu,
quantize_model, quantize_model,
) )
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES from .._meta import _IMAGENET_CATEGORIES
from ..googlenet import GoogLeNetWeights from ..googlenet import GoogLeNetWeights
......
...@@ -2,6 +2,7 @@ import warnings ...@@ -2,6 +2,7 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Optional, Union from typing import Any, Optional, Union
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.inception import ( from ....models.quantization.inception import (
...@@ -9,7 +10,6 @@ from ....models.quantization.inception import ( ...@@ -9,7 +10,6 @@ from ....models.quantization.inception import (
_replace_relu, _replace_relu,
quantize_model, quantize_model,
) )
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES from .._meta import _IMAGENET_CATEGORIES
from ..inception import InceptionV3Weights from ..inception import InceptionV3Weights
......
...@@ -2,6 +2,7 @@ import warnings ...@@ -2,6 +2,7 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Optional, Union from typing import Any, Optional, Union
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.mobilenetv2 import ( from ....models.quantization.mobilenetv2 import (
...@@ -10,7 +11,6 @@ from ....models.quantization.mobilenetv2 import ( ...@@ -10,7 +11,6 @@ from ....models.quantization.mobilenetv2 import (
_replace_relu, _replace_relu,
quantize_model, quantize_model,
) )
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES from .._meta import _IMAGENET_CATEGORIES
from ..mobilenetv2 import MobileNetV2Weights from ..mobilenetv2 import MobileNetV2Weights
......
...@@ -3,6 +3,7 @@ from functools import partial ...@@ -3,6 +3,7 @@ from functools import partial
from typing import Any, List, Optional, Union from typing import Any, List, Optional, Union
import torch import torch
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.mobilenetv3 import ( from ....models.quantization.mobilenetv3 import (
...@@ -11,7 +12,6 @@ from ....models.quantization.mobilenetv3 import ( ...@@ -11,7 +12,6 @@ from ....models.quantization.mobilenetv3 import (
QuantizableMobileNetV3, QuantizableMobileNetV3,
_replace_relu, _replace_relu,
) )
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES from .._meta import _IMAGENET_CATEGORIES
from ..mobilenetv3 import MobileNetV3LargeWeights, _mobilenet_v3_conf from ..mobilenetv3 import MobileNetV3LargeWeights, _mobilenet_v3_conf
......
...@@ -2,6 +2,7 @@ import warnings ...@@ -2,6 +2,7 @@ import warnings
from functools import partial from functools import partial
from typing import Any, List, Optional, Type, Union from typing import Any, List, Optional, Type, Union
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.resnet import ( from ....models.quantization.resnet import (
...@@ -11,7 +12,6 @@ from ....models.quantization.resnet import ( ...@@ -11,7 +12,6 @@ from ....models.quantization.resnet import (
_replace_relu, _replace_relu,
quantize_model, quantize_model,
) )
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES from .._meta import _IMAGENET_CATEGORIES
from ..resnet import ResNet18Weights, ResNet50Weights, ResNeXt101_32x8dWeights from ..resnet import ResNet18Weights, ResNet50Weights, ResNeXt101_32x8dWeights
......
...@@ -2,6 +2,7 @@ import warnings ...@@ -2,6 +2,7 @@ import warnings
from functools import partial from functools import partial
from typing import Any, List, Optional, Union from typing import Any, List, Optional, Union
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.shufflenetv2 import ( from ....models.quantization.shufflenetv2 import (
...@@ -9,7 +10,6 @@ from ....models.quantization.shufflenetv2 import ( ...@@ -9,7 +10,6 @@ from ....models.quantization.shufflenetv2 import (
_replace_relu, _replace_relu,
quantize_model, quantize_model,
) )
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES from .._meta import _IMAGENET_CATEGORIES
from ..shufflenetv2 import ShuffleNetV2_x0_5Weights, ShuffleNetV2_x1_0Weights from ..shufflenetv2 import ShuffleNetV2_x0_5Weights, ShuffleNetV2_x1_0Weights
......
...@@ -3,10 +3,10 @@ from functools import partial ...@@ -3,10 +3,10 @@ from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torch import nn from torch import nn
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.regnet import RegNet, BlockParams from ...models.regnet import RegNet, BlockParams
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
......
...@@ -2,10 +2,10 @@ import warnings ...@@ -2,10 +2,10 @@ import warnings
from functools import partial from functools import partial
from typing import Any, List, Optional, Type, Union from typing import Any, List, Optional, Type, Union
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.resnet import BasicBlock, Bottleneck, ResNet from ...models.resnet import BasicBlock, Bottleneck, ResNet
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
......
...@@ -2,10 +2,10 @@ import warnings ...@@ -2,10 +2,10 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import VocEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet
from ...transforms.presets import VocEval
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _VOC_CATEGORIES from .._meta import _VOC_CATEGORIES
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
......
...@@ -2,10 +2,10 @@ import warnings ...@@ -2,10 +2,10 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import VocEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.fcn import FCN, _fcn_resnet from ....models.segmentation.fcn import FCN, _fcn_resnet
from ...transforms.presets import VocEval
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _VOC_CATEGORIES from .._meta import _VOC_CATEGORIES
from ..resnet import ResNet50Weights, ResNet101Weights, resnet50, resnet101 from ..resnet import ResNet50Weights, ResNet101Weights, resnet50, resnet101
......
...@@ -2,10 +2,10 @@ import warnings ...@@ -2,10 +2,10 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import VocEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3 from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3
from ...transforms.presets import VocEval
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _VOC_CATEGORIES from .._meta import _VOC_CATEGORIES
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
......
...@@ -2,10 +2,10 @@ import warnings ...@@ -2,10 +2,10 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.shufflenetv2 import ShuffleNetV2 from ...models.shufflenetv2 import ShuffleNetV2
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
......
...@@ -2,10 +2,10 @@ import warnings ...@@ -2,10 +2,10 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.squeezenet import SqueezeNet from ...models.squeezenet import SqueezeNet
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
......
...@@ -2,10 +2,10 @@ import warnings ...@@ -2,10 +2,10 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.vgg import VGG, make_layers, cfgs from ...models.vgg import VGG, make_layers, cfgs
from ..transforms.presets import ImageNetEval
from ._api import Weights, WeightEntry from ._api import Weights, WeightEntry
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
......
...@@ -3,6 +3,7 @@ from functools import partial ...@@ -3,6 +3,7 @@ from functools import partial
from typing import Any, Callable, List, Optional, Sequence, Type, Union from typing import Any, Callable, List, Optional, Sequence, Type, Union
from torch import nn from torch import nn
from torchvision.prototype.transforms import Kinect400Eval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.video.resnet import ( from ....models.video.resnet import (
...@@ -15,7 +16,6 @@ from ....models.video.resnet import ( ...@@ -15,7 +16,6 @@ from ....models.video.resnet import (
R2Plus1dStem, R2Plus1dStem,
VideoResNet, VideoResNet,
) )
from ...transforms.presets import Kinect400Eval
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _KINETICS400_CATEGORIES from .._meta import _KINETICS400_CATEGORIES
......
from .presets import * from ._transform import Transform
from ._container import Compose, RandomApply, RandomChoice, RandomOrder # usort: skip
from ._geometry import Resize, RandomResize, HorizontalFlip, Crop, CenterCrop, RandomCrop
from ._misc import Identity, Normalize
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval
from typing import Any, List
import torch
from torch import nn
from torchvision.prototype.transforms import Transform
class ContainerTransform(nn.Module):
def supports(self, obj: Any) -> bool:
raise NotImplementedError()
def forward(self, *inputs: Any) -> Any:
raise NotImplementedError()
def _make_repr(self, lines: List[str]) -> str:
extra_repr = self.extra_repr()
if extra_repr:
lines = [self.extra_repr(), *lines]
head = f"{type(self).__name__}("
tail = ")"
body = [f" {line.rstrip()}" for line in lines]
return "\n".join([head, *body, tail])
class WrapperTransform(ContainerTransform):
def __init__(self, transform: Transform):
super().__init__()
self._transform = transform
def supports(self, obj: Any) -> bool:
return self._transform.supports(obj)
def __repr__(self) -> str:
return self._make_repr(repr(self._transform).splitlines())
class MultiTransform(ContainerTransform):
def __init__(self, *transforms: Transform) -> None:
super().__init__()
self._transforms = transforms
def supports(self, obj: Any) -> bool:
return all(transform.supports(obj) for transform in self._transforms)
def __repr__(self) -> str:
lines = []
for idx, transform in enumerate(self._transforms):
partial_lines = repr(transform).splitlines()
lines.append(f"({idx:d}): {partial_lines[0]}")
lines.extend(partial_lines[1:])
return self._make_repr(lines)
class Compose(MultiTransform):
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
for transform in self._transforms:
sample = transform(sample)
return sample
class RandomApply(WrapperTransform):
def __init__(self, transform: Transform, *, p: float = 0.5) -> None:
super().__init__(transform)
self._p = p
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if float(torch.rand(())) < self._p:
return sample
return self._transform(sample)
def extra_repr(self) -> str:
return f"p={self._p}"
class RandomChoice(MultiTransform):
def forward(self, *inputs: Any) -> Any:
idx = int(torch.randint(len(self._transforms), size=()))
transform = self._transforms[idx]
return transform(*inputs)
class RandomOrder(MultiTransform):
def forward(self, *inputs: Any) -> Any:
for idx in torch.randperm(len(self._transforms)):
transform = self._transforms[idx]
inputs = transform(*inputs)
return inputs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment