# test_masked_conv2d.py -- tests for mmcv.ops.MaskedConv2d.
import torch

class TestMaskedConv2d:
    """Tests for the ``mmcv.ops.MaskedConv2d`` operator (CUDA only)."""

    def test_masked_conv2d(self):
        """Run one forward pass of ``MaskedConv2d`` on random CUDA tensors.

        Raises:
            unittest.SkipTest: if CUDA is unavailable, so the test runner
                reports the test as skipped rather than as passed.
        """
        from unittest import SkipTest

        if not torch.cuda.is_available():
            # A bare ``return`` would make this count as a passing test even
            # though nothing was exercised; pytest and unittest both report
            # SkipTest as a skip.
            raise SkipTest('MaskedConv2d requires CUDA')

        # Imported after the CUDA check so the skip path never touches
        # mmcv.ops.
        from mmcv.ops import MaskedConv2d

        features = torch.randn(
            1, 3, 16, 16, requires_grad=True, device='cuda')
        mask = torch.randn(1, 16, 16, requires_grad=True, device='cuda')
        # NOTE(review): assumes nn.Conv2d-style arguments
        # (in_channels, out_channels, kernel_size) and default padding=0, in
        # which case the output is spatially smaller than the 16x16 mask --
        # confirm the op handles mask positions outside the output.
        conv = MaskedConv2d(3, 3, 3).cuda()

        output = conv(features, mask)

        assert output is not None
        # Batch size must match the input (1) and the channel dim must match
        # the conv's out_channels (3); the spatial size depends on padding.
        assert output.shape[:2] == (1, 3)