Unverified Commit dadda9ed authored by Dušan Malić's avatar Dušan Malić Committed by GitHub
Browse files

Reflect the changes of the default behavior of grid_sample and affine_grid (#993)



* Reflect the changes of the default behavior of grid_sample and affine_grid. If PyTorch > 1.3.0 align_corners will be set to False by default. This is undesired behavior.

* Support grid_sample and affine_grid for PyTorch<1.3 as well.
Co-authored-by: default avatarDusan Malic <dusan.malic@icg.tugraz.at>
parent 519b1564
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -16,6 +18,11 @@ class Sampler(nn.Module):
self.mode = mode
self.padding_mode = padding_mode
if torch.__version__ >= '1.3':
self.grid_sample = partial(F.grid_sample, align_corners=True)
else:
self.grid_sample = F.grid_sample
def forward(self, input_features, grid):
"""
Samples input using sampling grid
......@@ -26,5 +33,5 @@ class Sampler(nn.Module):
output_features: (B, C, X, Y, Z) Output voxel features
"""
# Sample from grid
output = F.grid_sample(input=input_features, grid=grid, mode=self.mode, padding_mode=self.padding_mode)
output = self.grid_sample(input=input_features, grid=grid, mode=self.mode, padding_mode=self.padding_mode)
return output
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from .roi_head_template import RoIHeadTemplate
from ...utils import common_utils, loss_utils
......@@ -31,6 +34,13 @@ class SECONDHead(RoIHeadTemplate):
)
self.init_weights(weight_init='xavier')
if torch.__version__ >= '1.3':
self.affine_grid = partial(F.affine_grid, align_corners=True)
self.grid_sample = partial(F.grid_sample, align_corners=True)
else:
self.affine_grid = F.affine_grid
self.grid_sample = F.grid_sample
def init_weights(self, weight_init='xavier'):
if weight_init == 'kaiming':
init_func = nn.init.kaiming_normal_
......@@ -92,12 +102,12 @@ class SECONDHead(RoIHeadTemplate):
), dim=1).view(-1, 2, 3).float()
grid_size = self.model_cfg.ROI_GRID_POOL.GRID_SIZE
grid = nn.functional.affine_grid(
grid = self.affine_grid(
theta,
torch.Size((rois.size(1), spatial_features_2d.size(1), grid_size, grid_size))
)
pooled_features = nn.functional.grid_sample(
pooled_features = self.grid_sample(
spatial_features_2d[b_id].unsqueeze(0).expand(rois.size(1), spatial_features_2d.size(1), height, width),
grid
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment