Unverified Commit dfa36dfe authored by Wenwei Zhang, committed by GitHub

Merge pull request #652 from dreamerlin/3d

[Feature] Add 3D support in wrapper
parents ec43b671 1a12ac75
# Copyright (c) Open-MMLab. All rights reserved.
from .alexnet import AlexNet
# yapf: disable
from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
                     PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
                     ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule,
                     ConvTranspose2d, ConvTranspose3d, ConvWS2d,
                     DepthwiseSeparableConvModule, GeneralizedAttention,
                     HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d,
                     NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish,
                     build_activation_layer, build_conv_layer,
                     build_norm_layer, build_padding_layer, build_plugin_layer,
                     build_upsample_layer, conv_ws_2d, is_norm)
# yapf: enable
from .resnet import ResNet, make_res_layer
from .utils import (bias_init_with_prob, caffe2_xavier_init, constant_init,
                    fuse_conv_bn, get_model_complexity_info, kaiming_init,
@@ -26,5 +29,6 @@ __all__ = [
    'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS',
    'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
    'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
    'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
    'MaxPool3d', 'Conv3d'
]
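With these exports in place the 3D wrappers become part of the public
mmcv.cnn namespace. A minimal smoke test of the new import surface
(illustrative only, not part of the patch):

import torch.nn as nn

from mmcv.cnn import Conv3d, ConvTranspose3d, MaxPool3d

# The wrappers subclass the matching torch.nn modules, so they act as
# drop-in replacements that additionally tolerate empty (zero-batch) inputs.
assert issubclass(Conv3d, nn.Conv3d)
assert issubclass(ConvTranspose3d, nn.ConvTranspose3d)
assert issubclass(MaxPool3d, nn.MaxPool3d)
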
@@ -17,7 +17,8 @@ from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
from .scale import Scale
from .swish import Swish
from .upsample import build_upsample_layer
from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
                       Linear, MaxPool2d, MaxPool3d)

__all__ = [
    'ConvModule', 'build_activation_layer', 'build_conv_layer',
@@ -27,5 +28,6 @@ __all__ = [
    'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
    'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
    'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
    'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
    'ConvTranspose3d', 'MaxPool3d', 'Conv3d'
]
@@ -8,7 +8,7 @@ import math
import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair, _triple

from .registry import CONV_LAYERS, UPSAMPLE_LAYERS
@@ -58,6 +58,27 @@ class Conv2d(nn.Conv2d):
        return super().forward(x)
@CONV_LAYERS.register_module('Conv3d', force=True)
class Conv3d(nn.Conv3d):

    def forward(self, x):
        # PyTorch 1.4 does not support empty tensor inference yet, so the
        # output shape is computed analytically and an empty tensor returned
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
            out_shape = [x.shape[0], self.out_channels]
            for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
                                     self.padding, self.stride,
                                     self.dilation):
                o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
                out_shape.append(o)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            if self.training:
                # produce dummy gradient to avoid DDP warning.
                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
                return empty + dummy
            else:
                return empty

        return super().forward(x)
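Because the class is registered under 'Conv3d' with force=True, it overrides
the plain nn.Conv3d entry in CONV_LAYERS, and config-driven code picks the
safe wrapper up automatically. A sketch of the intended use, with
illustrative channel sizes (not taken from the patch):

import torch

from mmcv.cnn import build_conv_layer

# build_conv_layer resolves the 'Conv3d' entry in CONV_LAYERS, which now
# points at the wrapper above rather than the raw nn.Conv3d.
conv = build_conv_layer(dict(type='Conv3d'), 4, 8, kernel_size=3, padding=1)
out = conv(torch.randn(0, 4, 16, 16, 16))  # zero-batch "empty" input
assert out.shape == (0, 8, 16, 16, 16)     # shape inferred analytically
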
@CONV_LAYERS.register_module()
@CONV_LAYERS.register_module('deconv')
@UPSAMPLE_LAYERS.register_module('deconv', force=True)
@@ -78,7 +99,30 @@ class ConvTranspose2d(nn.ConvTranspose2d):
            else:
                return empty

        return super().forward(x)
@CONV_LAYERS.register_module()
@CONV_LAYERS.register_module('deconv3d')
@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
class ConvTranspose3d(nn.ConvTranspose3d):

    def forward(self, x):
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
            out_shape = [x.shape[0], self.out_channels]
            for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
                                         self.padding, self.stride,
                                         self.dilation, self.output_padding):
                out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            if self.training:
                # produce dummy gradient to avoid DDP warning.
                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
                return empty + dummy
            else:
                return empty

        return super().forward(x)
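The out_shape expression above is the standard transposed-convolution size
formula; the 'deconv3d' alias also makes the wrapper reachable from
build_upsample_layer. A quick numeric check against nn.ConvTranspose3d,
with sizes chosen arbitrarily for illustration:

import torch
import torch.nn as nn

# For i=5, k=3, s=2, p=1, d=1, op=1:
# (i - 1) * s - 2 * p + (d * (k - 1) + 1) + op = 8 - 2 + 3 + 1 = 10
deconv = nn.ConvTranspose3d(
    2, 2, kernel_size=3, stride=2, padding=1, output_padding=1)
out = deconv(torch.randn(1, 2, 5, 5, 5))
assert out.shape[-3:] == (10, 10, 10)
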
class MaxPool2d(nn.MaxPool2d):
@@ -99,6 +143,25 @@ class MaxPool2d(nn.MaxPool2d):
        return super().forward(x)
class MaxPool3d(nn.MaxPool3d):

    def forward(self, x):
        # PyTorch 1.7 does not support empty tensor inference yet
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 7)):
            out_shape = list(x.shape[:2])
            for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
                                     _triple(self.padding),
                                     _triple(self.stride),
                                     _triple(self.dilation)):
                o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
                o = math.ceil(o) if self.ceil_mode else math.floor(o)
                out_shape.append(o)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            return empty
        return super().forward(x)
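The one subtlety relative to the conv wrappers is ceil_mode, which switches
the rounding applied to the pooled size. For example, on a hypothetical
7x7x7 input (not from the tests):

import torch
import torch.nn as nn

x = torch.randn(1, 1, 7, 7, 7)
# floor((7 - 2) / 2 + 1) = 3 versus ceil((7 - 2) / 2 + 1) = 4
assert nn.MaxPool3d(2, stride=2)(x).shape[-1] == 3
assert nn.MaxPool3d(2, stride=2, ceil_mode=True)(x).shape[-1] == 4
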
class Linear(torch.nn.Linear):

    def forward(self, x):
...
from unittest.mock import patch

import pytest
import torch
import torch.nn as nn

from mmcv.cnn.bricks import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
                             Linear, MaxPool2d, MaxPool3d)

@patch('torch.__version__', '1.1')
@pytest.mark.parametrize(
    'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
    [(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
                padding, dilation):
    """
    CommandLine:
        xdoctest -m tests/test_wrappers.py test_conv2d
    """
    # train mode
    # wrapper op with 0-dim input
    x_empty = torch.randn(0, in_channel, in_h, in_w)
    torch.manual_seed(0)
    wrapper = Conv2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    wrapper_out = wrapper(x_empty)

    # torch op with 3-dim input as shape reference
    x_normal = torch.randn(3, in_channel, in_h, in_w).requires_grad_(True)
    torch.manual_seed(0)
    ref = nn.Conv2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode
    x_empty = torch.randn(0, in_channel, in_h, in_w)
    wrapper = Conv2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    wrapper.eval()
    wrapper(x_empty)

@patch('torch.__version__', '1.1')
@pytest.mark.parametrize(
    'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation',  # noqa: E501
    [(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size, stride,
                padding, dilation):
    """
    CommandLine:
        xdoctest -m tests/test_wrappers.py test_conv3d
    """
    # train mode
    # wrapper op with 0-dim input
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
    torch.manual_seed(0)
    wrapper = Conv3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    wrapper_out = wrapper(x_empty)

    # torch op with 3-dim input as shape reference
    x_normal = torch.randn(3, in_channel, in_t, in_h,
                           in_w).requires_grad_(True)
    torch.manual_seed(0)
    ref = nn.Conv3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
    wrapper = Conv3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation)
    wrapper.eval()
    wrapper(x_empty)

@patch('torch.__version__', '1.1')
@pytest.mark.parametrize(
    'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
    [(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv_transposed_2d(in_w, in_h, in_channel, out_channel, kernel_size,
                            stride, padding, dilation):
    # wrapper op with 0-dim input
    x_empty = torch.randn(0, in_channel, in_h, in_w, requires_grad=True)
    # out padding must be smaller than either stride or dilation
    op = min(stride, dilation) - 1
    torch.manual_seed(0)
    wrapper = ConvTranspose2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    wrapper_out = wrapper(x_empty)

    # torch op with 3-dim input as shape reference
    x_normal = torch.randn(3, in_channel, in_h, in_w)
    torch.manual_seed(0)
    ref = nn.ConvTranspose2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode
    x_empty = torch.randn(0, in_channel, in_h, in_w)
    wrapper = ConvTranspose2d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    wrapper.eval()
    wrapper(x_empty)

@patch('torch.__version__', '1.1')
@pytest.mark.parametrize(
    'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation',  # noqa: E501
    [(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
def test_conv_transposed_3d(in_w, in_h, in_t, in_channel, out_channel,
                            kernel_size, stride, padding, dilation):
    # wrapper op with 0-dim input
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w, requires_grad=True)
    # out padding must be smaller than either stride or dilation
    op = min(stride, dilation) - 1
    torch.manual_seed(0)
    wrapper = ConvTranspose3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    wrapper_out = wrapper(x_empty)

    # torch op with 3-dim input as shape reference
    x_normal = torch.randn(3, in_channel, in_t, in_h, in_w)
    torch.manual_seed(0)
    ref = nn.ConvTranspose3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w)
    wrapper = ConvTranspose3d(
        in_channel,
        out_channel,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        output_padding=op)
    wrapper.eval()
    wrapper(x_empty)

@patch('torch.__version__', '1.1')
@pytest.mark.parametrize(
    'in_w,in_h,in_channel,out_channel,kernel_size,stride,padding,dilation',
    [(10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 3, 3, 5, 2, 1, 2)])
def test_max_pool_2d(in_w, in_h, in_channel, out_channel, kernel_size, stride,
                     padding, dilation):
    # wrapper op with 0-dim input
    x_empty = torch.randn(0, in_channel, in_h, in_w, requires_grad=True)
    wrapper = MaxPool2d(
        kernel_size, stride=stride, padding=padding, dilation=dilation)
    wrapper_out = wrapper(x_empty)

    # torch op with 3-dim input as shape reference
    x_normal = torch.randn(3, in_channel, in_h, in_w)
    ref = nn.MaxPool2d(
        kernel_size, stride=stride, padding=padding, dilation=dilation)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    assert torch.equal(wrapper(x_normal), ref_out)

@patch('torch.__version__', '1.1')
@pytest.mark.parametrize(
    'in_w,in_h,in_t,in_channel,out_channel,kernel_size,stride,padding,dilation',  # noqa: E501
    [(10, 10, 10, 1, 1, 3, 1, 0, 1), (20, 20, 20, 3, 3, 5, 2, 1, 2)])
def test_max_pool_3d(in_w, in_h, in_t, in_channel, out_channel, kernel_size,
                     stride, padding, dilation):
    # wrapper op with 0-dim input
    x_empty = torch.randn(0, in_channel, in_t, in_h, in_w, requires_grad=True)
    wrapper = MaxPool3d(
        kernel_size, stride=stride, padding=padding, dilation=dilation)
    wrapper_out = wrapper(x_empty)

    # torch op with 3-dim input as shape reference
    x_normal = torch.randn(3, in_channel, in_t, in_h, in_w)
    ref = nn.MaxPool3d(
        kernel_size, stride=stride, padding=padding, dilation=dilation)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]
    assert torch.equal(wrapper(x_normal), ref_out)

@patch('torch.__version__', '1.1')
@pytest.mark.parametrize('in_w,in_h,in_feature,out_feature', [(10, 10, 1, 1),
                                                              (20, 20, 3, 3)])
def test_linear(in_w, in_h, in_feature, out_feature):
    # wrapper op with 0-dim input
    x_empty = torch.randn(0, in_feature, requires_grad=True)
    torch.manual_seed(0)
    wrapper = Linear(in_feature, out_feature)
    wrapper_out = wrapper(x_empty)

    # torch op with 3-dim input as shape reference
    x_normal = torch.randn(3, in_feature)
    torch.manual_seed(0)
    ref = nn.Linear(in_feature, out_feature)
    ref_out = ref(x_normal)

    assert wrapper_out.shape[0] == 0
    assert wrapper_out.shape[1:] == ref_out.shape[1:]

    wrapper_out.sum().backward()
    assert wrapper.weight.grad is not None
    assert wrapper.weight.grad.shape == wrapper.weight.shape

    assert torch.equal(wrapper(x_normal), ref_out)

    # eval mode
    x_empty = torch.randn(0, in_feature)
...
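The CommandLine hints in the docstrings show one way to run individual
checks via xdoctest; running the new 3D cases through pytest also works,
for example:

pytest tests/test_wrappers.py -k 3d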