# Copyright (c) OpenMMLab. All rights reserved.
import torch

class TestMaskedConv2d:
6

limm's avatar
limm committed
7
8
9
    def test_masked_conv2d(self):
        if not torch.cuda.is_available():
            return
10
        from mmcv.ops import MaskedConv2d
limm's avatar
limm committed
11
12
13
        input = torch.randn(1, 3, 16, 16, requires_grad=True, device='cuda')
        mask = torch.randn(1, 16, 16, requires_grad=True, device='cuda')
        conv = MaskedConv2d(3, 3, 3).cuda()
14
        output = conv(input, mask)
limm's avatar
limm committed
15
        assert output is not None