# Copyright (c) OpenMMLab. All rights reserved.
import os
import platform

import numpy as np
import pytest
import torch
import torch.distributed as dist

from mmcv.cnn.bricks import ConvModule
from mmcv.cnn.utils import revert_sync_batchnorm

if platform.system() == 'Windows':
    import regex as re
else:
    import re


@pytest.mark.skipif(
    torch.__version__ == 'parrots', reason='not supported in parrots now')
def test_revert_syncbn():
    """SyncBN must raise on CPU until reverted to a plain BatchNorm layer."""
    module = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN'))
    inputs = torch.randn(1, 3, 10, 10)
    # Running a SyncBN-normalized module on CPU is unsupported and raises.
    with pytest.raises(ValueError):
        module(inputs)
    # After the revert, the forward pass succeeds on CPU with the
    # expected spatial shrink from the 2x2 convolution (10 -> 9).
    module = revert_sync_batchnorm(module)
    outputs = module(inputs)
    assert outputs.shape == (1, 8, 9, 9)


def test_revert_mmsyncbn():
    """Check that reverting MMSyncBN to BatchNorm preserves the outputs.

    Requires a slurm allocation with at least two processes (one GPU
    each); otherwise the test prints usage instructions and returns
    without asserting anything.
    """
    env = os.environ
    if 'SLURM_NTASKS' not in env or int(env['SLURM_NTASKS']) < 2:
        print('Must run on slurm with more than 1 process!\n'
              'srun -p test --gres=gpu:2 -n2')
        return
    rank = int(env['SLURM_PROCID'])
    world_size = int(env['SLURM_NTASKS'])
    local_rank = int(env['SLURM_LOCALID'])
    node_list = str(env['SLURM_NODELIST'])

    # Build the master address from the numeric fields of the slurm
    # node list (fields 1..4 form the dotted IPv4 address).
    node_parts = re.findall('[0-9]+', node_list)
    env['MASTER_ADDR'] = '.'.join(node_parts[1:5])
    env['MASTER_PORT'] = '12341'
    env['WORLD_SIZE'] = str(world_size)
    env['RANK'] = str(rank)

    dist.init_process_group('nccl')
    torch.cuda.set_device(local_rank)
    # Broadcast a single random input so every rank evaluates the same
    # tensor through the synchronized batch norm.
    data = torch.randn(1, 3, 10, 10).cuda()
    dist.broadcast(data, src=0)
    module = ConvModule(3, 8, 2, norm_cfg=dict(type='MMSyncBN')).cuda()
    module.eval()
    out_syncbn = module(data).detach().cpu().numpy()
    # Reverting must leave the forward result unchanged (up to tolerance).
    module = revert_sync_batchnorm(module)
    out_bn = module(data).detach().cpu().numpy()
    assert np.all(np.isclose(out_bn, out_syncbn, 1e-3))
    # The reverted module must also run on CPU and agree with the GPU pass.
    module, data = module.to('cpu'), data.to('cpu')
    out_bn_cpu = module(data).detach().numpy()
    assert np.all(np.isclose(out_bn, out_bn_cpu, 1e-3))