# Copyright (c) OpenMMLab. All rights reserved.
import os
import platform

import numpy as np
import pytest
import torch
import torch.distributed as dist

from mmcv.cnn.bricks import ConvModule
from mmcv.cnn.utils import revert_sync_batchnorm

if platform.system() == 'Windows':
    import regex as re
else:
    import re


@pytest.mark.skipif(
    torch.__version__ == 'parrots', reason='not supported in parrots now')
def test_revert_syncbn():
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN'))
    x = torch.randn(1, 3, 10, 10)
    # SyncBN cannot run on CPU, so the forward pass should raise a ValueError
    with pytest.raises(ValueError):
        y = conv(x)
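    # After reverting SyncBN to a regular BN layer, the forward pass runs on CPU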
    conv = revert_sync_batchnorm(conv)
    y = conv(x)
    assert y.shape == (1, 8, 9, 9)


def test_revert_mmsyncbn():
    if 'SLURM_NTASKS' not in os.environ or int(os.environ['SLURM_NTASKS']) < 2:
        print('Must run on slurm with more than 1 process!\n'
              'srun -p test --gres=gpu:2 -n2')
        return
    rank = int(os.environ['SLURM_PROCID'])
    world_size = int(os.environ['SLURM_NTASKS'])
    local_rank = int(os.environ['SLURM_LOCALID'])
    node_list = str(os.environ['SLURM_NODELIST'])

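    # Derive MASTER_ADDR from the numeric fields of SLURM_NODELIST
    # (fields 1-4 are assumed to be the master node's IP address)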
    node_parts = re.findall('[0-9]+', node_list)
    os.environ['MASTER_ADDR'] = (f'{node_parts[1]}.{node_parts[2]}' +
                                 f'.{node_parts[3]}.{node_parts[4]}')
    os.environ['MASTER_PORT'] = '12341'
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(rank)

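    # init_process_group reads MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE
    # from the environment variables set above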
    dist.init_process_group('nccl')
    torch.cuda.set_device(local_rank)
    x = torch.randn(1, 3, 10, 10).cuda()
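    # Broadcast the input from rank 0 so every process computes on the same data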
    dist.broadcast(x, src=0)
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='MMSyncBN')).cuda()
    conv.eval()
    y_mmsyncbn = conv(x).detach().cpu().numpy()
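    # In eval mode, reverting MMSyncBN to plain BN should not change the output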
    conv = revert_sync_batchnorm(conv)
    y_bn = conv(x).detach().cpu().numpy()
    assert np.all(np.isclose(y_bn, y_mmsyncbn, 1e-3))
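    # The reverted module should also run on CPU and give the same result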
    conv, x = conv.to('cpu'), x.to('cpu')
    y_bn_cpu = conv(x).detach().numpy()
    assert np.all(np.isclose(y_bn, y_bn_cpu, 1e-3))