Unverified Commit d57f929d authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Move Permute layer to ops. (#6055)

parent 77cad127
...@@ -87,8 +87,9 @@ TorchVision provides commonly used building blocks as layers: ...@@ -87,8 +87,9 @@ TorchVision provides commonly used building blocks as layers:
DeformConv2d DeformConv2d
DropBlock2d DropBlock2d
DropBlock3d DropBlock3d
MLP
FrozenBatchNorm2d FrozenBatchNorm2d
MLP
Permute
SqueezeExcitation SqueezeExcitation
StochasticDepth StochasticDepth
......
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
from torch import nn, Tensor from torch import nn, Tensor
from torch.nn import functional as F from torch.nn import functional as F
from ..ops.misc import Conv2dNormActivation from ..ops.misc import Conv2dNormActivation, Permute
from ..ops.stochastic_depth import StochasticDepth from ..ops.stochastic_depth import StochasticDepth
from ..transforms._presets import ImageClassification from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
...@@ -35,15 +35,6 @@ class LayerNorm2d(nn.LayerNorm): ...@@ -35,15 +35,6 @@ class LayerNorm2d(nn.LayerNorm):
return x return x
class Permute(nn.Module):
def __init__(self, dims: List[int]):
super().__init__()
self.dims = dims
def forward(self, x):
return torch.permute(x, self.dims)
class CNBlock(nn.Module): class CNBlock(nn.Module):
def __init__( def __init__(
self, self,
......
...@@ -5,14 +5,13 @@ import torch ...@@ -5,14 +5,13 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn, Tensor from torch import nn, Tensor
from ..ops.misc import MLP from ..ops.misc import MLP, Permute
from ..ops.stochastic_depth import StochasticDepth from ..ops.stochastic_depth import StochasticDepth
from ..transforms._presets import ImageClassification, InterpolationMode from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param from ._utils import _ovewrite_named_param
from .convnext import Permute # TODO: move Permute on ops
__all__ = [ __all__ = [
......
...@@ -19,7 +19,7 @@ from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d ...@@ -19,7 +19,7 @@ from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d
from .feature_pyramid_network import FeaturePyramidNetwork from .feature_pyramid_network import FeaturePyramidNetwork
from .focal_loss import sigmoid_focal_loss from .focal_loss import sigmoid_focal_loss
from .giou_loss import generalized_box_iou_loss from .giou_loss import generalized_box_iou_loss
from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation, MLP from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation, MLP, Permute
from .poolers import MultiScaleRoIAlign from .poolers import MultiScaleRoIAlign
from .ps_roi_align import ps_roi_align, PSRoIAlign from .ps_roi_align import ps_roi_align, PSRoIAlign
from .ps_roi_pool import ps_roi_pool, PSRoIPool from .ps_roi_pool import ps_roi_pool, PSRoIPool
...@@ -62,6 +62,7 @@ __all__ = [ ...@@ -62,6 +62,7 @@ __all__ = [
"Conv3dNormActivation", "Conv3dNormActivation",
"SqueezeExcitation", "SqueezeExcitation",
"MLP", "MLP",
"Permute",
"generalized_box_iou_loss", "generalized_box_iou_loss",
"distance_box_iou_loss", "distance_box_iou_loss",
"complete_box_iou_loss", "complete_box_iou_loss",
......
...@@ -10,7 +10,6 @@ from ..utils import _log_api_usage_once ...@@ -10,7 +10,6 @@ from ..utils import _log_api_usage_once
interpolate = torch.nn.functional.interpolate interpolate = torch.nn.functional.interpolate
# This is not in nn
class FrozenBatchNorm2d(torch.nn.Module): class FrozenBatchNorm2d(torch.nn.Module):
""" """
BatchNorm2d where the batch statistics and the affine parameters are fixed BatchNorm2d where the batch statistics and the affine parameters are fixed
...@@ -297,3 +296,18 @@ class MLP(torch.nn.Sequential): ...@@ -297,3 +296,18 @@ class MLP(torch.nn.Sequential):
super().__init__(*layers) super().__init__(*layers)
_log_api_usage_once(self) _log_api_usage_once(self)
class Permute(torch.nn.Module):
"""This module returns a view of the tensor input with its dimensions permuted.
Args:
dims (List[int]): The desired ordering of dimensions
"""
def __init__(self, dims: List[int]):
super().__init__()
self.dims = dims
def forward(self, x: Tensor) -> Tensor:
return torch.permute(x, self.dims)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment