Unverified Commit 039e06e4 authored by tianyuandu's avatar tianyuandu Committed by GitHub
Browse files

use torch.cummax in corner_pool for torch 1.5+ (#390)

* synchronize from mmdetection

* fix parrots
parent b11c5660
import torch
from torch import nn from torch import nn
from torch.autograd import Function from torch.autograd import Function
...@@ -98,10 +99,27 @@ class CornerPool(nn.Module): ...@@ -98,10 +99,27 @@ class CornerPool(nn.Module):
'top': TopPoolFunction, 'top': TopPoolFunction,
} }
cummax_dim_flip = {
'bottom': (2, False),
'left': (3, True),
'right': (3, False),
'top': (2, True),
}
def __init__(self, mode): def __init__(self, mode):
super(CornerPool, self).__init__() super(CornerPool, self).__init__()
assert mode in self.pool_functions assert mode in self.pool_functions
self.mode = mode
self.corner_pool = self.pool_functions[mode] self.corner_pool = self.pool_functions[mode]
def forward(self, x): def forward(self, x):
if torch.__version__ != 'parrots' and torch.__version__ >= '1.5.0':
dim, flip = self.cummax_dim_flip[self.mode]
if flip:
x = x.flip(dim)
pool_tensor, _ = torch.cummax(x, dim=dim)
if flip:
pool_tensor = pool_tensor.flip(dim)
return pool_tensor
else:
return self.corner_pool.apply(x) return self.corner_pool.apply(x)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment