Unverified Commit dadda9ed authored by Dušan Malić's avatar Dušan Malić Committed by GitHub
Browse files

Reflect the changes of the default behavior of grid_sample and affine_grid (#993)



* Reflect the changes of the default behavior of grid_sample and affine_grid. If PyTorch > 1.3.0 align_corners will be set to False by default. This is undesired behavior.

* Support grid_sample and affine_grid for PyTorch<1.3 as well.
Co-authored-by: default avatarDusan Malic <dusan.malic@icg.tugraz.at>
parent 519b1564
from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
...@@ -16,6 +18,11 @@ class Sampler(nn.Module): ...@@ -16,6 +18,11 @@ class Sampler(nn.Module):
self.mode = mode self.mode = mode
self.padding_mode = padding_mode self.padding_mode = padding_mode
if torch.__version__ >= '1.3':
self.grid_sample = partial(F.grid_sample, align_corners=True)
else:
self.grid_sample = F.grid_sample
def forward(self, input_features, grid): def forward(self, input_features, grid):
""" """
Samples input using sampling grid Samples input using sampling grid
...@@ -26,5 +33,5 @@ class Sampler(nn.Module): ...@@ -26,5 +33,5 @@ class Sampler(nn.Module):
output_features: (B, C, X, Y, Z) Output voxel features output_features: (B, C, X, Y, Z) Output voxel features
""" """
# Sample from grid # Sample from grid
output = F.grid_sample(input=input_features, grid=grid, mode=self.mode, padding_mode=self.padding_mode) output = self.grid_sample(input=input_features, grid=grid, mode=self.mode, padding_mode=self.padding_mode)
return output return output
from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from .roi_head_template import RoIHeadTemplate from .roi_head_template import RoIHeadTemplate
from ...utils import common_utils, loss_utils from ...utils import common_utils, loss_utils
...@@ -31,6 +34,13 @@ class SECONDHead(RoIHeadTemplate): ...@@ -31,6 +34,13 @@ class SECONDHead(RoIHeadTemplate):
) )
self.init_weights(weight_init='xavier') self.init_weights(weight_init='xavier')
if torch.__version__ >= '1.3':
self.affine_grid = partial(F.affine_grid, align_corners=True)
self.grid_sample = partial(F.grid_sample, align_corners=True)
else:
self.affine_grid = F.affine_grid
self.grid_sample = F.grid_sample
def init_weights(self, weight_init='xavier'): def init_weights(self, weight_init='xavier'):
if weight_init == 'kaiming': if weight_init == 'kaiming':
init_func = nn.init.kaiming_normal_ init_func = nn.init.kaiming_normal_
...@@ -92,12 +102,12 @@ class SECONDHead(RoIHeadTemplate): ...@@ -92,12 +102,12 @@ class SECONDHead(RoIHeadTemplate):
), dim=1).view(-1, 2, 3).float() ), dim=1).view(-1, 2, 3).float()
grid_size = self.model_cfg.ROI_GRID_POOL.GRID_SIZE grid_size = self.model_cfg.ROI_GRID_POOL.GRID_SIZE
grid = nn.functional.affine_grid( grid = self.affine_grid(
theta, theta,
torch.Size((rois.size(1), spatial_features_2d.size(1), grid_size, grid_size)) torch.Size((rois.size(1), spatial_features_2d.size(1), grid_size, grid_size))
) )
pooled_features = nn.functional.grid_sample( pooled_features = self.grid_sample(
spatial_features_2d[b_id].unsqueeze(0).expand(rois.size(1), spatial_features_2d.size(1), height, width), spatial_features_2d[b_id].unsqueeze(0).expand(rois.size(1), spatial_features_2d.size(1), height, width),
grid grid
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment