Unverified Commit 2bb1160e authored by Wenwei Zhang's avatar Wenwei Zhang Committed by GitHub
Browse files

[fix]: fix wrapper comparison of pytorch version (#572)

parent afb73995
...@@ -30,7 +30,7 @@ class NewEmptyTensorOp(torch.autograd.Function): ...@@ -30,7 +30,7 @@ class NewEmptyTensorOp(torch.autograd.Function):
class Conv2d(nn.Conv2d): class Conv2d(nn.Conv2d):
def forward(self, x): def forward(self, x):
if x.numel() == 0 and torch.__version__ <= '1.4': if x.numel() == 0 and torch.__version__ <= '1.4.0':
out_shape = [x.shape[0], self.out_channels] out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size, for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
self.padding, self.stride, self.dilation): self.padding, self.stride, self.dilation):
...@@ -72,7 +72,7 @@ class MaxPool2d(nn.MaxPool2d): ...@@ -72,7 +72,7 @@ class MaxPool2d(nn.MaxPool2d):
def forward(self, x): def forward(self, x):
# PyTorch 1.6 does not support empty tensor inference yet # PyTorch 1.6 does not support empty tensor inference yet
if x.numel() == 0 and torch.__version__ <= '1.6': if x.numel() == 0 and torch.__version__ <= '1.6.0':
out_shape = list(x.shape[:2]) out_shape = list(x.shape[:2])
for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size), for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
_pair(self.padding), _pair(self.stride), _pair(self.padding), _pair(self.stride),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment