# Copyright (c) OpenMMLab. All rights reserved.


def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
    """Wrap module in non-distributed environment by device type.

    - For CUDA, wrap as :obj:`mmcv.parallel.MMDataParallel`.
    - For NPU, wrap as :obj:`mmcv.device.npu.NPUDataParallel`.
    - For MPS, wrap as :obj:`mmcv.device.mps.MPSDataParallel`.
    - For CPU & IPU, the model is not wrapped.

    Args:
        model(:class:`nn.Module`): model to be parallelized.
        device(str): device type, cuda, npu, mps, cpu or ipu. Defaults to cuda.
        dim(int): Dimension used to scatter the data. Defaults to 0.

    Returns:
        model(nn.Module): the wrapped model.
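
    Example:
        A minimal usage sketch (assumes ``mmcv`` is installed and a CUDA
        device is available; ``toy_model`` is only an illustrative name)::

            >>> import torch.nn as nn
            >>> toy_model = nn.Linear(8, 2)
            >>> wrapped = wrap_non_distributed_model(toy_model, device='cuda')
            >>> # ``wrapped`` is an ``MMDataParallel`` instance and the
            >>> # original model is available as ``wrapped.module``.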
    """
    if device == 'npu':
        from mmcv.device.npu import NPUDataParallel
        model = NPUDataParallel(model.npu(), dim=dim, *args, **kwargs)
    elif device == 'cuda':
        from mmcv.parallel import MMDataParallel
        model = MMDataParallel(model.cuda(), dim=dim, *args, **kwargs)
    elif device == 'cpu':
        model = model.cpu()
    elif device == 'ipu':
        model = model.cpu()
    elif device == 'mps':
        from mmcv.device import mps
        model = mps.MPSDataParallel(model.to('mps'), dim=dim, *args, **kwargs)
    else:
        raise RuntimeError(f'Unavailable device "{device}"')

    return model


def wrap_distributed_model(model, device='cuda', *args, **kwargs):
    """Build DistributedDataParallel module by device type.

    - For CUDA, wrap as :obj:`mmcv.parallel.MMDistributedDataParallel`.
    - For NPU, wrap as :obj:`mmcv.device.npu.NPUDistributedDataParallel`.
    - Other device types are not supported yet.

    Args:
        model(:class:`nn.Module`): module to be parallelized.
        device(str): device type, npu or cuda. Defaults to cuda.

    Returns:
        model(:class:`nn.Module`): the wrapped module.
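
    Example:
        A minimal sketch (assumes ``mmcv`` is installed, CUDA is available
        and the default ``torch.distributed`` process group has already been
        initialized, e.g. by the training launcher; ``toy_model`` is only an
        illustrative name)::

            >>> import torch.nn as nn
            >>> toy_model = nn.Linear(8, 2)
            >>> wrapped = wrap_distributed_model(
            ...     toy_model, device='cuda', broadcast_buffers=False)
            >>> # Extra keyword arguments such as ``broadcast_buffers`` are
            >>> # forwarded to the underlying ``DistributedDataParallel``.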

    References:
        .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
               DistributedDataParallel.html
    """
    if device == 'npu':
        from mmcv.device.npu import NPUDistributedDataParallel
        from torch.npu import current_device
        model = NPUDistributedDataParallel(
            model.npu(), *args, device_ids=[current_device()], **kwargs)
    elif device == 'cuda':
        from mmcv.parallel import MMDistributedDataParallel
        from torch.cuda import current_device
        model = MMDistributedDataParallel(
            model.cuda(), *args, device_ids=[current_device()], **kwargs)
    else:
        raise RuntimeError(f'Unavailable device "{device}"')

    return model