"examples/research_projects/README.md" did not exist on "86896de064f166b9ea347f139c9698c248c7cc4a"
conv2d.py 2.13 KB
Newer Older
luopl's avatar
init  
luopl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
from torch import distributed as dist
from torch import nn
from torch.nn import functional as F

from ..base_module import BaseModule
from ...utils import DistriConfig


class DistriConv2dTP(BaseModule):
    """Tensor-parallel wrapper around nn.Conv2d that shards the layer along the
    input-channel dimension across the n_device_per_batch ranks."""

    def __init__(self, module: nn.Conv2d, distri_config: DistriConfig):
        super(DistriConv2dTP, self).__init__(module, distri_config)
        assert (
            module.in_channels % distri_config.n_device_per_batch == 0
        ), "in_channels must be divisible by n_device_per_batch"

        # Build a Conv2d that only holds this rank's slice of the input channels.
        sharded_module = nn.Conv2d(
            module.in_channels // distri_config.n_device_per_batch,
            module.out_channels,
            module.kernel_size,
            module.stride,
            module.padding,
            module.dilation,
            module.groups,
            module.bias is not None,
            module.padding_mode,
            device=module.weight.device,
            dtype=module.weight.dtype,
        )
        # Copy the weight slice that belongs to this rank's input channels.
        start_idx = distri_config.split_idx() * (module.in_channels // distri_config.n_device_per_batch)
        end_idx = (distri_config.split_idx() + 1) * (module.in_channels // distri_config.n_device_per_batch)
        sharded_module.weight.data.copy_(module.weight.data[:, start_idx:end_idx])
        if module.bias is not None:
            # The full bias is kept on every rank and applied once after the all-reduce.
            sharded_module.bias.data.copy_(module.bias.data)

        self.module = sharded_module
        del module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        distri_config = self.distri_config

        # Every rank receives the full input; slice out this rank's input channels.
        b, c, h, w = x.shape
        start_idx = distri_config.split_idx() * (c // distri_config.n_device_per_batch)
        end_idx = (distri_config.split_idx() + 1) * (c // distri_config.n_device_per_batch)
        # Convolve the local channel shard without the bias; the partial outputs of
        # all ranks are summed below, so the bias must be added only once.
        output = F.conv2d(
            x[:, start_idx:end_idx],
            self.module.weight,
            bias=None,
            stride=self.module.stride,
            padding=self.module.padding,
            dilation=self.module.dilation,
            groups=self.module.groups,
        )
        # Sum the partial results across the ranks in the batch group.
        dist.all_reduce(output, op=dist.ReduceOp.SUM, group=distri_config.batch_group, async_op=False)
        if self.module.bias is not None:
            output = output + self.module.bias.view(1, -1, 1, 1)

        self.counter += 1
        return output
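

# Usage sketch (illustrative only; the DistriConfig construction below is hypothetical).
# It assumes torch.distributed is already initialized and that DistriConfig exposes
# n_device_per_batch, split_idx(), and batch_group as used above.
#
#     conv = nn.Conv2d(64, 128, kernel_size=3, padding=1).cuda()
#     distri_config = DistriConfig(...)  # constructor arguments depend on the project setup
#     tp_conv = DistriConv2dTP(conv, distri_config)
#     y = tp_conv(x)  # x: (B, 64, H, W); each rank convolves its channel shard,
#                     # and the partial outputs are summed with all_reduce.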