Unverified commit 2d3e42fc, authored by tripleMu and committed by GitHub

Add type hints for mmcv/ops (#1995)



* Merge Master

* Add typehint in mmcv/ops/*

* Fix

* Update mmcv/ops/roi_align.py
Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>

* Fix

* Fix

* Fix

* Update mmcv/ops/riroi_align_rotated.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/ops/riroi_align_rotated.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* remove type hints of all symbolic methods

* remove type hints of all symbolic methods

* minor refinement

* minor refinement

* minor fix
Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: zhouzaida <zhouzaida@163.com>
parent 3dd2a21b
 # Modified from https://github.com/hszhao/semseg/blob/master/lib/psa
+from typing import Optional, Tuple
+
+import torch
 from torch import nn
 from torch.autograd import Function
 from torch.nn.modules.utils import _pair
@@ -20,7 +23,8 @@ class PSAMaskFunction(Function):
             mask_size_i=mask_size)

     @staticmethod
-    def forward(ctx, input, psa_type, mask_size):
+    def forward(ctx, input: torch.Tensor, psa_type: str,
+                mask_size: int) -> torch.Tensor:
         ctx.psa_type = psa_type
         ctx.mask_size = _pair(mask_size)
         ctx.save_for_backward(input)
@@ -45,7 +49,9 @@ class PSAMaskFunction(Function):
         return output

     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(
+            ctx, grad_output: torch.Tensor
+    ) -> Tuple[torch.Tensor, None, None, None]:
         input = ctx.saved_tensors[0]
         psa_type = ctx.psa_type
         h_mask, w_mask = ctx.mask_size
@@ -71,7 +77,7 @@ psa_mask = PSAMaskFunction.apply
 class PSAMask(nn.Module):

-    def __init__(self, psa_type, mask_size=None):
+    def __init__(self, psa_type: str, mask_size: Optional[tuple] = None):
         super().__init__()
         assert psa_type in ['collect', 'distribute']
         if psa_type == 'collect':
@@ -82,7 +88,7 @@ class PSAMask(nn.Module):
         self.mask_size = mask_size
         self.psa_type = psa_type

-    def forward(self, input):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         return psa_mask(input, self.psa_type_enum, self.mask_size)

     def __repr__(self):
......
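A minimal usage sketch of the now-typed PSAMask module, for review context. The shapes are assumptions based on the op's contract (the input channel dim must equal h_mask * w_mask), not something this diff asserts:

    import torch
    from mmcv.ops import PSAMask

    psa = PSAMask('collect', mask_size=(9, 9))
    x = torch.rand(2, 9 * 9, 16, 16)  # (N, h_mask * w_mask, H, W)
    out = psa(x)                      # (N, H * W, H, W)
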
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Optional, Tuple, Union
+
+import torch
 import torch.nn as nn
 from torch.autograd import Function
@@ -11,14 +14,14 @@ ext_module = ext_loader.load_ext(
 class RiRoIAlignRotatedFunction(Function):

     @staticmethod
-    def forward(ctx,
-                features,
-                rois,
-                out_size,
-                spatial_scale,
-                num_samples=0,
-                num_orientations=8,
-                clockwise=False):
+    def forward(ctx: Any,
+                features: torch.Tensor,
+                rois: torch.Tensor,
+                out_size: Union[int, tuple],
+                spatial_scale: float,
+                num_samples: int = 0,
+                num_orientations: int = 8,
+                clockwise: bool = False) -> torch.Tensor:
         if isinstance(out_size, int):
             out_h = out_size
             out_w = out_size
@@ -54,7 +57,9 @@ class RiRoIAlignRotatedFunction(Function):
         return output

     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(
+        ctx: Any, grad_output: torch.Tensor
+    ) -> Optional[Tuple[torch.Tensor, None, None, None, None, None, None]]:
         feature_size = ctx.feature_size
         spatial_scale = ctx.spatial_scale
         num_orientations = ctx.num_orientations
@@ -67,7 +72,7 @@ class RiRoIAlignRotatedFunction(Function):
         out_w = grad_output.size(3)
         out_h = grad_output.size(2)

-        grad_input = grad_rois = None
+        grad_input = None
         if ctx.needs_input_grad[0]:
             grad_input = rois.new_zeros(batch_size, num_channels, feature_h,
@@ -83,7 +88,8 @@ class RiRoIAlignRotatedFunction(Function):
                 num_orientations=num_orientations,
                 clockwise=clockwise)
-            return grad_input, grad_rois, None, None, None, None, None
+            return grad_input, None, None, None, None, None, None
+        return None

 riroi_align_rotated = RiRoIAlignRotatedFunction.apply
@@ -111,11 +117,11 @@ class RiRoIAlignRotated(nn.Module):
     """

     def __init__(self,
-                 out_size,
-                 spatial_scale,
-                 num_samples=0,
-                 num_orientations=8,
-                 clockwise=False):
+                 out_size: tuple,
+                 spatial_scale: float,
+                 num_samples: int = 0,
+                 num_orientations: int = 8,
+                 clockwise: bool = False):
         super().__init__()
         self.out_size = out_size
@@ -124,7 +130,8 @@ class RiRoIAlignRotated(nn.Module):
         self.num_orientations = int(num_orientations)
         self.clockwise = clockwise

-    def forward(self, features, rois):
+    def forward(self, features: torch.Tensor,
+                rois: torch.Tensor) -> torch.Tensor:
         return RiRoIAlignRotatedFunction.apply(features, rois, self.out_size,
                                                self.spatial_scale,
                                                self.num_samples,
......
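A hedged sketch of calling the typed module. It assumes a CUDA build of mmcv and the (batch_idx, cx, cy, w, h, angle) RoI layout used by the rotated ops; neither is stated in this hunk:

    import torch
    from mmcv.ops import RiRoIAlignRotated

    align = RiRoIAlignRotated(out_size=(7, 7), spatial_scale=1 / 8,
                              num_samples=2, num_orientations=8)
    feats = torch.rand(2, 64, 32, 32).cuda()  # channels split over 8 orientations
    rois = torch.tensor([[0., 16., 16., 8., 4., 0.3]]).cuda()
    out = align(feats, rois)
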
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any
+
 import torch
 import torch.nn as nn
 from torch.autograd import Function
@@ -62,14 +64,14 @@ class RoIAlignFunction(Function):
             mode_s=pool_mode)

     @staticmethod
-    def forward(ctx,
-                input,
-                rois,
-                output_size,
-                spatial_scale=1.0,
-                sampling_ratio=0,
-                pool_mode='avg',
-                aligned=True):
+    def forward(ctx: Any,
+                input: torch.Tensor,
+                rois: torch.Tensor,
+                output_size: int,
+                spatial_scale: float = 1.0,
+                sampling_ratio: int = 0,
+                pool_mode: str = 'avg',
+                aligned: bool = True) -> torch.Tensor:
         ctx.output_size = _pair(output_size)
         ctx.spatial_scale = spatial_scale
         ctx.sampling_ratio = sampling_ratio
@@ -108,7 +110,7 @@ class RoIAlignFunction(Function):

     @staticmethod
     @once_differentiable
-    def backward(ctx, grad_output):
+    def backward(ctx: Any, grad_output: torch.Tensor) -> tuple:
         rois, argmax_y, argmax_x = ctx.saved_tensors
         grad_input = grad_output.new_zeros(ctx.input_shape)
         # complex head architecture may cause grad_output uncontiguous.
@@ -175,12 +177,12 @@ class RoIAlign(nn.Module):
         },
         cls_name='RoIAlign')
     def __init__(self,
-                 output_size,
-                 spatial_scale=1.0,
-                 sampling_ratio=0,
-                 pool_mode='avg',
-                 aligned=True,
-                 use_torchvision=False):
+                 output_size: tuple,
+                 spatial_scale: float = 1.0,
+                 sampling_ratio: int = 0,
+                 pool_mode: str = 'avg',
+                 aligned: bool = True,
+                 use_torchvision: bool = False):
         super().__init__()
         self.output_size = _pair(output_size)
@@ -190,7 +192,7 @@ class RoIAlign(nn.Module):
         self.aligned = aligned
         self.use_torchvision = use_torchvision

-    def forward(self, input, rois):
+    def forward(self, input: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
         """
         Args:
             input: NCHW images
......
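For reference, a minimal sketch of the typed RoIAlign API; the (batch_idx, x1, y1, x2, y2) RoI layout follows the class docstring, while the concrete shapes are illustrative assumptions:

    import torch
    from mmcv.ops import RoIAlign

    roi_align = RoIAlign(output_size=(7, 7), spatial_scale=1 / 16,
                         sampling_ratio=2, pool_mode='avg', aligned=True)
    feat = torch.rand(1, 256, 64, 64)
    rois = torch.tensor([[0., 0., 0., 512., 512.]])  # (K, 5), image-scale coords
    out = roi_align(feat, rois)  # (1, 256, 7, 7)
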
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Optional, Tuple, Union
+
+import torch
 import torch.nn as nn
 from torch.autograd import Function
 from torch.nn.modules.utils import _pair
@@ -37,14 +40,14 @@ class RoIAlignRotatedFunction(Function):
             clockwise_i=clockwise)

     @staticmethod
-    def forward(ctx,
-                input,
-                rois,
-                output_size,
-                spatial_scale,
-                sampling_ratio=0,
-                aligned=True,
-                clockwise=False):
+    def forward(ctx: Any,
+                input: torch.Tensor,
+                rois: torch.Tensor,
+                output_size: Union[int, tuple],
+                spatial_scale: float,
+                sampling_ratio: int = 0,
+                aligned: bool = True,
+                clockwise: bool = False) -> torch.Tensor:
         ctx.output_size = _pair(output_size)
         ctx.spatial_scale = spatial_scale
         ctx.sampling_ratio = sampling_ratio
@@ -71,7 +74,10 @@ class RoIAlignRotatedFunction(Function):
         return output

     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(
+        ctx: Any, grad_output: torch.Tensor
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], None, None,
+               None, None, None]:
         feature_size = ctx.feature_size
         rois = ctx.saved_tensors[0]
         assert feature_size is not None
@@ -151,11 +157,11 @@ class RoIAlignRotated(nn.Module):
         },
         cls_name='RoIAlignRotated')
     def __init__(self,
-                 output_size,
-                 spatial_scale,
-                 sampling_ratio=0,
-                 aligned=True,
-                 clockwise=False):
+                 output_size: Union[int, tuple],
+                 spatial_scale: float,
+                 sampling_ratio: int = 0,
+                 aligned: bool = True,
+                 clockwise: bool = False):
         super().__init__()
         self.output_size = _pair(output_size)
@@ -164,7 +170,7 @@ class RoIAlignRotated(nn.Module):
         self.aligned = aligned
         self.clockwise = clockwise

-    def forward(self, input, rois):
+    def forward(self, input: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
         return RoIAlignRotatedFunction.apply(input, rois, self.output_size,
                                              self.spatial_scale,
                                              self.sampling_ratio, self.aligned,
......
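A minimal sketch for the rotated variant; rois are assumed to be (batch_idx, cx, cy, w, h, angle) with the angle in radians:

    import torch
    from mmcv.ops import RoIAlignRotated

    align = RoIAlignRotated(output_size=7, spatial_scale=1 / 16,
                            sampling_ratio=2, aligned=True, clockwise=False)
    feat = torch.rand(1, 256, 64, 64)
    rois = torch.tensor([[0., 256., 256., 128., 64., 0.5]])  # (K, 6)
    out = align(feat, rois)  # (1, 256, 7, 7)
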
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Tuple, Union
+
 import torch
 import torch.nn as nn
 from torch.autograd import Function
@@ -23,7 +25,11 @@ class RoIPoolFunction(Function):
             spatial_scale_f=spatial_scale)

     @staticmethod
-    def forward(ctx, input, rois, output_size, spatial_scale=1.0):
+    def forward(ctx: Any,
+                input: torch.Tensor,
+                rois: torch.Tensor,
+                output_size: Union[int, tuple],
+                spatial_scale: float = 1.0) -> torch.Tensor:
         ctx.output_size = _pair(output_size)
         ctx.spatial_scale = spatial_scale
         ctx.input_shape = input.size()
@@ -49,7 +55,9 @@ class RoIPoolFunction(Function):

     @staticmethod
     @once_differentiable
-    def backward(ctx, grad_output):
+    def backward(
+            ctx: Any, grad_output: torch.Tensor
+    ) -> Tuple[torch.Tensor, None, None, None]:
         rois, argmax = ctx.saved_tensors
         grad_input = grad_output.new_zeros(ctx.input_shape)
@@ -70,13 +78,15 @@ roi_pool = RoIPoolFunction.apply
 class RoIPool(nn.Module):

-    def __init__(self, output_size, spatial_scale=1.0):
+    def __init__(self,
+                 output_size: Union[int, tuple],
+                 spatial_scale: float = 1.0):
         super().__init__()
         self.output_size = _pair(output_size)
         self.spatial_scale = float(spatial_scale)

-    def forward(self, input, rois):
+    def forward(self, input: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
         return roi_pool(input, rois, self.output_size, self.spatial_scale)

     def __repr__(self):
......
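A sketch mirroring the RoIAlign example above (the .cuda() placement is an assumption, as this op is typically run on GPU):

    import torch
    from mmcv.ops import RoIPool

    roi_pool = RoIPool(output_size=(7, 7), spatial_scale=1 / 16)
    feat = torch.rand(1, 256, 64, 64).cuda()
    rois = torch.tensor([[0., 0., 0., 512., 512.]]).cuda()  # (K, 5)
    out = roi_pool(feat, rois)  # (1, 256, 7, 7)
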
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Tuple, Union
+
 import torch
 from torch import nn as nn
 from torch.autograd import Function
@@ -25,7 +27,10 @@ class RoIAwarePool3d(nn.Module):
             Default: 'max'.
     """

-    def __init__(self, out_size, max_pts_per_voxel=128, mode='max'):
+    def __init__(self,
+                 out_size: Union[int, tuple],
+                 max_pts_per_voxel: int = 128,
+                 mode: str = 'max'):
         super().__init__()
         self.out_size = out_size
@@ -34,7 +39,8 @@ class RoIAwarePool3d(nn.Module):
         pool_mapping = {'max': 0, 'avg': 1}
         self.mode = pool_mapping[mode]

-    def forward(self, rois, pts, pts_feature):
+    def forward(self, rois: torch.Tensor, pts: torch.Tensor,
+                pts_feature: torch.Tensor) -> torch.Tensor:
         """
         Args:
             rois (torch.Tensor): [N, 7], in LiDAR coordinate,
@@ -55,8 +61,9 @@ class RoIAwarePool3d(nn.Module):
 class RoIAwarePool3dFunction(Function):

     @staticmethod
-    def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel,
-                mode):
+    def forward(ctx: Any, rois: torch.Tensor, pts: torch.Tensor,
+                pts_feature: torch.Tensor, out_size: Union[int, tuple],
+                max_pts_per_voxel: int, mode: int) -> torch.Tensor:
         """
         Args:
             rois (torch.Tensor): [N, 7], in LiDAR coordinate,
@@ -108,7 +115,9 @@ class RoIAwarePool3dFunction(Function):
         return pooled_features

     @staticmethod
-    def backward(ctx, grad_out):
+    def backward(
+            ctx: Any, grad_out: torch.Tensor
+    ) -> Tuple[None, None, torch.Tensor, None, None, None]:
         ret = ctx.roiaware_pool3d_for_backward
         pts_idx_of_voxels, argmax, mode, num_pts, num_channels = ret
......
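A hedged sketch of the 3D pooling call; the shapes follow the docstring in the hunk above, and the CUDA placement is an assumption:

    import torch
    from mmcv.ops import RoIAwarePool3d

    pool = RoIAwarePool3d(out_size=4, max_pts_per_voxel=128, mode='max')
    rois = torch.rand(2, 7).cuda()            # (N, 7) boxes in LiDAR coords
    pts = torch.rand(512, 3).cuda()           # (npoints, 3)
    pts_feature = torch.rand(512, 16).cuda()  # (npoints, C)
    pooled = pool(rois, pts, pts_feature)     # (N, out, out, out, C)
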
+from typing import Any, Tuple
+
+import torch
 from torch import nn as nn
 from torch.autograd import Function
@@ -17,11 +20,12 @@ class RoIPointPool3d(nn.Module):
             Default: 512.
     """

-    def __init__(self, num_sampled_points=512):
+    def __init__(self, num_sampled_points: int = 512):
         super().__init__()
         self.num_sampled_points = num_sampled_points

-    def forward(self, points, point_features, boxes3d):
+    def forward(self, points: torch.Tensor, point_features: torch.Tensor,
+                boxes3d: torch.Tensor) -> Tuple[torch.Tensor]:
         """
         Args:
             points (torch.Tensor): Input points whose shape is (B, N, C).
@@ -41,7 +45,13 @@ class RoIPointPool3d(nn.Module):
 class RoIPointPool3dFunction(Function):

     @staticmethod
-    def forward(ctx, points, point_features, boxes3d, num_sampled_points=512):
+    def forward(
+            ctx: Any,
+            points: torch.Tensor,
+            point_features: torch.Tensor,
+            boxes3d: torch.Tensor,
+            num_sampled_points: int = 512
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
             points (torch.Tensor): Input points whose shape is (B, N, C).
@@ -73,5 +83,5 @@ class RoIPointPool3dFunction(Function):
         return pooled_features, pooled_empty_flag

     @staticmethod
-    def backward(ctx, grad_out):
+    def backward(ctx: Any, grad_out: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError
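A sketch of the typed forward; the return pair matches the new Tuple[torch.Tensor, torch.Tensor] annotation, while the CUDA placement is assumed:

    import torch
    from mmcv.ops import RoIPointPool3d

    pool = RoIPointPool3d(num_sampled_points=512)
    points = torch.rand(1, 2048, 3).cuda()           # (B, N, 3)
    point_features = torch.rand(1, 2048, 16).cuda()  # (B, N, C)
    boxes3d = torch.rand(1, 4, 7).cuda()             # (B, M, 7)
    pooled_features, pooled_empty_flag = pool(points, point_features, boxes3d)
    # pooled_features: (B, M, 512, 3 + C); pooled_empty_flag: (B, M)
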
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any
+
 import torch
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable
@@ -31,7 +33,8 @@ class RotatedFeatureAlignFunction(Function):
             points_i=points)

     @staticmethod
-    def forward(ctx, features, best_rbboxes, spatial_scale, points):
+    def forward(ctx: Any, features: torch.Tensor, best_rbboxes: torch.Tensor,
+                spatial_scale: float, points: int) -> torch.Tensor:
         """
         Args:
             features (torch.Tensor): Input features with shape [N,C,H,W].
@@ -60,7 +63,7 @@ class RotatedFeatureAlignFunction(Function):

     @staticmethod
     @once_differentiable
-    def backward(ctx, grad_output):
+    def backward(ctx: Any, grad_output: torch.Tensor) -> tuple:
         """
         Args:
             grad_output (torch.Tensor): The gradiant of output features
@@ -84,9 +87,9 @@ class RotatedFeatureAlignFunction(Function):
         return grad_input, None, None, None


-def rotated_feature_align(features,
-                          best_rbboxes,
-                          spatial_scale=1 / 8,
-                          points=1):
+def rotated_feature_align(features: torch.Tensor,
+                          best_rbboxes: torch.Tensor,
+                          spatial_scale: float = 1 / 8,
+                          points: int = 1) -> torch.Tensor:
     return RotatedFeatureAlignFunction.apply(features, best_rbboxes,
                                              spatial_scale, points)
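A sketch of the typed functional wrapper; the (N, H, W, 5) layout of best_rbboxes and the CUDA placement are assumptions, not part of this diff:

    import torch
    from mmcv.ops import rotated_feature_align

    feat = torch.rand(1, 64, 32, 32).cuda()
    best_rbboxes = torch.rand(1, 32, 32, 5).cuda()  # one rotated box per location
    out = rotated_feature_align(feat, best_rbboxes, spatial_scale=1 / 8,
                                points=1)  # same shape as feat
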
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, List, Optional, Tuple
+
 import torch
 import torch.nn.functional as F
 from torch import nn
@@ -14,7 +16,10 @@ ext_module = ext_loader.load_ext(
 class _DynamicScatter(Function):

     @staticmethod
-    def forward(ctx, feats, coors, reduce_type='max'):
+    def forward(ctx: Any,
+                feats: torch.Tensor,
+                coors: torch.Tensor,
+                reduce_type: str = 'max') -> Tuple[torch.Tensor, torch.Tensor]:
         """convert kitti points(N, >=3) to voxels.

         Args:
@@ -42,7 +47,9 @@ class _DynamicScatter(Function):
         return voxel_feats, voxel_coors

     @staticmethod
-    def backward(ctx, grad_voxel_feats, grad_voxel_coors=None):
+    def backward(ctx: Any,
+                 grad_voxel_feats: torch.Tensor,
+                 grad_voxel_coors: Optional[torch.Tensor] = None) -> tuple:
         (feats, voxel_feats, point2voxel_map,
          voxel_points_count) = ctx.saved_tensors
         grad_feats = torch.zeros_like(feats)
@@ -73,14 +80,17 @@ class DynamicScatter(nn.Module):
             into voxel.
     """

-    def __init__(self, voxel_size, point_cloud_range, average_points: bool):
+    def __init__(self, voxel_size: List, point_cloud_range: List,
+                 average_points: bool):
         super().__init__()
         self.voxel_size = voxel_size
         self.point_cloud_range = point_cloud_range
         self.average_points = average_points

-    def forward_single(self, points, coors):
+    def forward_single(
+            self, points: torch.Tensor,
+            coors: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """Scatters points into voxels.

         Args:
@@ -97,7 +107,8 @@ class DynamicScatter(nn.Module):
         reduce = 'mean' if self.average_points else 'max'
         return dynamic_scatter(points.contiguous(), coors.contiguous(), reduce)

-    def forward(self, points, coors):
+    def forward(self, points: torch.Tensor,
+                coors: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """Scatters points/features into voxels.

         Args:
......
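A hedged sketch of the typed DynamicScatter module (the CUDA placement and the int32 coordinate dtype are assumptions drawn from the op's usual contract):

    import torch
    from mmcv.ops import DynamicScatter

    scatter = DynamicScatter(voxel_size=[0.1, 0.1, 0.1],
                             point_cloud_range=[0, -40, -3, 70.4, 40, 1],
                             average_points=False)  # False -> 'max' reduce
    feats = torch.rand(1000, 4).cuda()  # per-point features
    coors = torch.randint(0, 10, (1000, 3), dtype=torch.int32).cuda()
    voxel_feats, voxel_coors = scatter(feats, coors)
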
@@ -11,7 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any

+import torch
 from torch.autograd import Function

 from . import sparse_ops as ops
@@ -25,8 +27,9 @@ class SparseConvFunction(Function):
     """

     @staticmethod
-    def forward(ctx, features, filters, indice_pairs, indice_pair_num,
-                num_activate_out):
+    def forward(ctx: Any, features: torch.Tensor, filters: torch.nn.Parameter,
+                indice_pairs: torch.Tensor, indice_pair_num: torch.Tensor,
+                num_activate_out: torch.Tensor) -> torch.Tensor:
         """
         Args:
             features (torch.Tensor): Features that needs to convolute.
@@ -44,7 +47,7 @@ class SparseConvFunction(Function):
                                indice_pair_num, num_activate_out, False)

     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx: Any, grad_output: torch.Tensor) -> tuple:
         indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
         input_bp, filters_bp = ops.indice_conv_backward(
             features, filters, grad_output, indice_pairs, indice_pair_num,
@@ -56,8 +59,9 @@ class SparseConvFunction(Function):
 class SparseInverseConvFunction(Function):

     @staticmethod
-    def forward(ctx, features, filters, indice_pairs, indice_pair_num,
-                num_activate_out):
+    def forward(ctx: Any, features: torch.Tensor, filters: torch.nn.Parameter,
+                indice_pairs: torch.Tensor, indice_pair_num: torch.Tensor,
+                num_activate_out: torch.Tensor) -> torch.Tensor:
         """
         Args:
             features (torch.Tensor): Features that needs to convolute.
@@ -75,7 +79,7 @@ class SparseInverseConvFunction(Function):
                                indice_pair_num, num_activate_out, True, False)

     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx: Any, grad_output: torch.Tensor) -> tuple:
         indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
         input_bp, filters_bp = ops.indice_conv_backward(
             features, filters, grad_output, indice_pairs, indice_pair_num,
@@ -87,8 +91,9 @@ class SparseInverseConvFunction(Function):
 class SubMConvFunction(Function):

     @staticmethod
-    def forward(ctx, features, filters, indice_pairs, indice_pair_num,
-                num_activate_out):
+    def forward(ctx: Any, features: torch.Tensor, filters: torch.nn.Parameter,
+                indice_pairs: torch.Tensor, indice_pair_num: torch.Tensor,
+                num_activate_out: torch.Tensor) -> torch.Tensor:
         """
         Args:
             features (torch.Tensor): Features that needs to convolute.
@@ -106,7 +111,7 @@ class SubMConvFunction(Function):
                                indice_pair_num, num_activate_out, False, True)

     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx: Any, grad_output: torch.Tensor) -> tuple:
         indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
         input_bp, filters_bp = ops.indice_conv_backward(
             features, filters, grad_output, indice_pairs, indice_pair_num,
@@ -118,8 +123,9 @@ class SubMConvFunction(Function):
 class SparseMaxPoolFunction(Function):

     @staticmethod
-    def forward(ctx, features, indice_pairs, indice_pair_num,
-                num_activate_out):
+    def forward(ctx, features: torch.Tensor, indice_pairs: torch.Tensor,
+                indice_pair_num: torch.Tensor,
+                num_activate_out: torch.Tensor) -> torch.Tensor:
         """
         Args:
             features (torch.Tensor): Features that needs to convolute.
@@ -137,7 +143,7 @@ class SparseMaxPoolFunction(Function):
         return out

     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx: Any, grad_output: torch.Tensor) -> tuple:
         indice_pairs, indice_pair_num, features, out = ctx.saved_tensors
         input_bp = ops.indice_maxpool_backward(features, out, grad_output,
                                                indice_pairs, indice_pair_num)
......
@@ -13,6 +13,7 @@
 # limitations under the License.
 import sys
 from collections import OrderedDict
+from typing import Any, List, Optional, Union

 import torch
 from torch import nn
@@ -20,17 +21,18 @@ from torch import nn
 from .sparse_structure import SparseConvTensor


-def is_spconv_module(module):
+def is_spconv_module(module: nn.Module) -> bool:
     spconv_modules = (SparseModule, )
     return isinstance(module, spconv_modules)


-def is_sparse_conv(module):
+def is_sparse_conv(module: nn.Module) -> bool:
     from .sparse_conv import SparseConvolution
     return isinstance(module, SparseConvolution)


-def _mean_update(vals, m_vals, t):
+def _mean_update(vals: Union[int, List], m_vals: Union[int, List],
+                 t: float) -> List:
     outputs = []
     if not isinstance(vals, list):
         vals = [vals]
@@ -101,7 +103,7 @@ class SparseSequential(SparseModule):
                 self.add_module(name, module)
         self._sparity_dict = {}

-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int) -> torch.Tensor:
         if not (-len(self) <= idx < len(self)):
             raise IndexError(f'index {idx} is out of range')
         if idx < 0:
@@ -118,14 +120,14 @@ class SparseSequential(SparseModule):
     def sparity_dict(self):
         return self._sparity_dict

-    def add(self, module, name=None):
+    def add(self, module: Any, name: Optional[str] = None) -> None:
         if name is None:
             name = str(len(self._modules))
         if name in self._modules:
             raise KeyError('name exists')
         self.add_module(name, module)

-    def forward(self, input):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         for k, module in self._modules.items():
             if is_spconv_module(module):
                 assert isinstance(input, SparseConvTensor)
......
+from typing import List, Optional, Tuple, Union
+
 import numpy as np
 import torch


-def scatter_nd(indices, updates, shape):
+def scatter_nd(indices: torch.Tensor, updates: torch.Tensor,
+               shape: torch.Tensor) -> torch.Tensor:
     """pytorch edition of tensorflow scatter_nd.

     this function don't contain except handle code. so use this carefully when
@@ -21,18 +24,18 @@ def scatter_nd(indices, updates, shape):
 class SparseConvTensor:

     def __init__(self,
-                 features,
-                 indices,
-                 spatial_shape,
-                 batch_size,
-                 grid=None):
+                 features: torch.Tensor,
+                 indices: torch.Tensor,
+                 spatial_shape: Union[List, Tuple],
+                 batch_size: int,
+                 grid: Optional[torch.Tensor] = None):
         self.features = features
         self.indices = indices
         if self.indices.dtype != torch.int32:
             self.indices.int()
         self.spatial_shape = spatial_shape
         self.batch_size = batch_size
-        self.indice_dict = {}
+        self.indice_dict: dict = {}
         self.grid = grid

     @property
@@ -46,7 +49,7 @@ class SparseConvTensor:
             return self.indice_dict[key]
         return None

-    def dense(self, channels_first=True):
+    def dense(self, channels_first: bool = True) -> torch.Tensor:
         output_shape = [self.batch_size] + list(
             self.spatial_shape) + [self.features.shape[1]]
         res = scatter_nd(self.indices.long(), self.features, output_shape)
......
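A minimal sketch of the typed SparseConvTensor container; the (batch_idx, z, y, x) index layout is my assumption about the coordinate convention:

    import torch
    from mmcv.ops import SparseConvTensor

    features = torch.rand(4, 16)  # (num_active, C)
    indices = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0],
                            [0, 2, 3, 1], [0, 3, 3, 3]], dtype=torch.int32)
    x = SparseConvTensor(features, indices, spatial_shape=[4, 4, 4],
                         batch_size=1)
    dense = x.dense()  # (batch, C, *spatial_shape) with channels_first=True
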
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
@@ -35,8 +37,10 @@ class SyncBatchNormFunction(Function):
             stats_mode=stats_mode)

     @staticmethod
-    def forward(self, input, running_mean, running_var, weight, bias, momentum,
-                eps, group, group_size, stats_mode):
+    def forward(self, input: torch.Tensor, running_mean: torch.Tensor,
+                running_var: torch.Tensor, weight: torch.Tensor,
+                bias: torch.Tensor, momentum: float, eps: float, group: int,
+                group_size: int, stats_mode: str) -> torch.Tensor:
         self.momentum = momentum
         self.eps = eps
         self.group = group
@@ -126,7 +130,7 @@ class SyncBatchNormFunction(Function):

     @staticmethod
     @once_differentiable
-    def backward(self, grad_output):
+    def backward(self, grad_output: torch.Tensor) -> tuple:
         norm, std, weight = self.saved_tensors
         grad_weight = torch.zeros_like(weight)
         grad_bias = torch.zeros_like(weight)
@@ -191,13 +195,13 @@ class SyncBatchNorm(Module):
     """

     def __init__(self,
-                 num_features,
-                 eps=1e-5,
-                 momentum=0.1,
-                 affine=True,
-                 track_running_stats=True,
-                 group=None,
-                 stats_mode='default'):
+                 num_features: int,
+                 eps: float = 1e-5,
+                 momentum: float = 0.1,
+                 affine: bool = True,
+                 track_running_stats: bool = True,
+                 group: Optional[int] = None,
+                 stats_mode: str = 'default'):
         super().__init__()
         self.num_features = num_features
         self.eps = eps
@@ -239,7 +243,7 @@ class SyncBatchNorm(Module):
             self.weight.data.uniform_()  # pytorch use ones_()
             self.bias.data.zero_()

-    def forward(self, input):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         if input.dim() < 2:
             raise ValueError(
                 f'expected at least 2D input, got {input.dim()}D input')
......
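A usage sketch for the typed SyncBatchNorm; cross-GPU statistics only make sense inside an initialized torch.distributed process group, which this snippet assumes is already set up:

    import torch
    from mmcv.ops import SyncBatchNorm

    bn = SyncBatchNorm(num_features=64, stats_mode='default').cuda()
    x = torch.rand(8, 64, 32, 32).cuda()
    y = bn(x)  # same shape as x
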
-from typing import Tuple
+from typing import Any, Tuple

 import torch
 from torch.autograd import Function
@@ -17,7 +17,7 @@ class ThreeInterpolate(Function):
     """

     @staticmethod
-    def forward(ctx, features: torch.Tensor, indices: torch.Tensor,
+    def forward(ctx: Any, features: torch.Tensor, indices: torch.Tensor,
                 weight: torch.Tensor) -> torch.Tensor:
         """
         Args:
......
-from typing import Tuple
+from typing import Any, Tuple

 import torch
 from torch.autograd import Function
@@ -16,7 +16,7 @@ class ThreeNN(Function):
     """

     @staticmethod
-    def forward(ctx, target: torch.Tensor,
+    def forward(ctx: Any, target: torch.Tensor,
                 source: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
......
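The two ops above are typically chained; a hedged sketch (the CUDA placement and the inverse-distance weighting are assumptions borrowed from common PointNet++ usage, not from these hunks):

    import torch
    from mmcv.ops import three_interpolate, three_nn

    target = torch.rand(1, 256, 3).cuda()  # (B, n, 3) points to interpolate at
    source = torch.rand(1, 64, 3).cuda()   # (B, m, 3) points with features
    dist, idx = three_nn(target, source)   # both (B, n, 3)
    weight = 1.0 / (dist + 1e-8)
    weight = weight / weight.sum(dim=2, keepdim=True)
    features = torch.rand(1, 16, 64).cuda()  # (B, C, m)
    interpolated = three_interpolate(features, idx, weight)  # (B, C, n)
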
@@ -95,6 +95,8 @@
 # =======================================================================

+from typing import Any, List, Tuple, Union
+
 import torch
 from torch.autograd import Function
 from torch.nn import functional as F
@@ -108,8 +110,10 @@ upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d'])
 class UpFirDn2dBackward(Function):

     @staticmethod
-    def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad,
-                in_size, out_size):
+    def forward(ctx: Any, grad_output: torch.Tensor, kernel: torch.Tensor,
+                grad_kernel: torch.Tensor, up: tuple, down: tuple, pad: tuple,
+                g_pad: tuple, in_size: Union[List, Tuple],
+                out_size: Union[List, Tuple]) -> torch.Tensor:
         up_x, up_y = up
         down_x, down_y = down
@@ -149,7 +153,7 @@ class UpFirDn2dBackward(Function):
         return grad_input

     @staticmethod
-    def backward(ctx, gradgrad_input):
+    def backward(ctx: Any, gradgrad_input: torch.Tensor) -> tuple:
         kernel, = ctx.saved_tensors

         gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2],
@@ -177,7 +181,8 @@ class UpFirDn2dBackward(Function):
 class UpFirDn2d(Function):

     @staticmethod
-    def forward(ctx, input, kernel, up, down, pad):
+    def forward(ctx: Any, input: torch.Tensor, kernel: torch.Tensor, up: tuple,
+                down: tuple, pad: tuple) -> torch.Tensor:
         up_x, up_y = up
         down_x, down_y = down
         pad_x0, pad_x1, pad_y0, pad_y1 = pad
@@ -222,7 +227,7 @@ class UpFirDn2d(Function):
         return out

     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx: Any, grad_output: torch.Tensor) -> tuple:
         kernel, grad_kernel = ctx.saved_tensors

         grad_input = UpFirDn2dBackward.apply(
@@ -240,7 +245,12 @@ class UpFirDn2d(Function):
         return grad_input, None, None, None, None


-def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+def upfirdn2d(
+        input: torch.Tensor,
+        kernel: torch.Tensor,
+        up: Union[int, tuple] = 1,
+        down: Union[int, tuple] = 1,
+        pad: tuple = (0, 0)) -> torch.Tensor:  # noqa E125
     """UpFRIDn for 2d features.

     UpFIRDn is short for upsample, apply FIR filter and downsample. More
@@ -264,14 +274,14 @@ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
     """
     if input.device.type == 'cpu':
         if len(pad) == 2:
-            pad = (pad[0], pad[1], pad[0], pad[1])
+            pad = (pad[0], pad[1], pad[0], pad[1])  # type: ignore

-        up = to_2tuple(up)
-        down = to_2tuple(down)
+        _up = to_2tuple(up)
+        _down = to_2tuple(down)

-        out = upfirdn2d_native(input, kernel, up[0], up[1], down[0], down[1],
-                               pad[0], pad[1], pad[2], pad[3])
+        out = upfirdn2d_native(input, kernel, _up[0], _up[1], _down[0],
+                               _down[1], pad[0], pad[1], pad[2], pad[3])
     else:
         _up = to_2tuple(up)
@@ -287,8 +297,9 @@ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
     return out


-def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
-                     pad_y0, pad_y1):
+def upfirdn2d_native(input: torch.Tensor, kernel: torch.Tensor, up_x: int,
+                     up_y: int, down_x: int, down_y: int, pad_x0: int,
+                     pad_x1: int, pad_y0: int, pad_y1: int) -> torch.Tensor:
     _, channel, in_h, in_w = input.shape
     input = input.reshape(-1, in_h, in_w, 1)
......
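A sketch of the typed top-level entry point; the kernel here is a toy 3x3 box filter, and the CUDA placement is just for illustration (the diff itself shows a CPU fallback through upfirdn2d_native):

    import torch
    from mmcv.ops import upfirdn2d

    img = torch.rand(1, 3, 16, 16).cuda()
    kernel = torch.ones(3, 3).cuda() / 9.  # normalized FIR kernel
    out = upfirdn2d(img, kernel, up=2, down=1, pad=(1, 1))  # 2x upsample + filter
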
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, List, Tuple, Union
+
 import torch
 from torch import nn
 from torch.autograd import Function
@@ -13,13 +15,14 @@ ext_module = ext_loader.load_ext(
 class _Voxelization(Function):

     @staticmethod
-    def forward(ctx,
-                points,
-                voxel_size,
-                coors_range,
-                max_points=35,
-                max_voxels=20000,
-                deterministic=True):
+    def forward(
+            ctx: Any,
+            points: torch.Tensor,
+            voxel_size: Union[tuple, float],
+            coors_range: Union[tuple, float],
+            max_points: int = 35,
+            max_voxels: int = 20000,
+            deterministic: bool = True) -> Union[Tuple[torch.Tensor], Tuple]:
         """Convert kitti points(N, >=3) to voxels.

         Args:
@@ -111,11 +114,11 @@ class Voxelization(nn.Module):
     """

     def __init__(self,
-                 voxel_size,
-                 point_cloud_range,
-                 max_num_points,
-                 max_voxels=20000,
-                 deterministic=True):
+                 voxel_size: List,
+                 point_cloud_range: List,
+                 max_num_points: int,
+                 max_voxels: Union[tuple, int] = 20000,
+                 deterministic: bool = True):
         """
         Args:
             voxel_size (list): list [x, y, z] size of three dimension
@@ -149,8 +152,9 @@ class Voxelization(nn.Module):
         point_cloud_range = torch.tensor(
             point_cloud_range, dtype=torch.float32)
         voxel_size = torch.tensor(voxel_size, dtype=torch.float32)
-        grid_size = (point_cloud_range[3:] -
-                     point_cloud_range[:3]) / voxel_size
+        grid_size = (
+            point_cloud_range[3:] -  # type: ignore
+            point_cloud_range[:3]) / voxel_size  # type: ignore
         grid_size = torch.round(grid_size).long()
         input_feat_shape = grid_size[:2]
         self.grid_size = grid_size
@@ -158,7 +162,7 @@ class Voxelization(nn.Module):
         # [w, h, d] -> [d, h, w]
         self.pcd_shape = [*input_feat_shape, 1][::-1]

-    def forward(self, input):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.training:
             max_voxels = self.max_voxels[0]
         else:
......
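Finally, a hedged sketch of the typed Voxelization module; the three-tensor return and the CUDA placement are assumptions based on the op's docstring conventions:

    import torch
    from mmcv.ops import Voxelization

    voxelize = Voxelization(voxel_size=[0.5, 0.5, 0.5],
                            point_cloud_range=[0, -40, -3, 70.4, 40, 1],
                            max_num_points=35, max_voxels=20000)
    points = torch.rand(1000, 4).cuda()  # (N, >=3): x, y, z, extra features
    voxels, coors, num_points_per_voxel = voxelize(points)
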