# conv2d.py
import torch
from torch import distributed as dist
from torch import nn
from torch.nn import functional as F

from ..base_module import BaseModule
from ...utils import DistriConfig


class DistriConv2dPP(BaseModule):
    """nn.Conv2d wrapper that splits the convolution along the height dimension across devices.

    If ``is_first_layer`` is set, the layer receives the full input and computes only its own
    output slice (``sliced_forward``). Otherwise it operates on a local height shard and
    exchanges ``padding``-sized boundary rows (a halo) with the neighbouring devices.
    """

    def __init__(self, module: nn.Conv2d, distri_config: DistriConfig, is_first_layer: bool = False):
        super(DistriConv2dPP, self).__init__(module, distri_config)
        self.is_first_layer = is_first_layer

    def naive_forward(self, x: torch.Tensor) -> torch.Tensor:
        #  x: [B, C, H, W]
        output = self.module(x)
        return output

    def sliced_forward(self, x: torch.Tensor) -> torch.Tensor:
        config = self.distri_config
        b, c, h, w = x.shape
        assert h % config.n_device_per_batch == 0

        stride = self.module.stride[0]
        padding = self.module.padding[0]

        # Each device is responsible for `output_h` rows of the output feature map.
        output_h = x.shape[2] // stride // config.n_device_per_batch
        idx = config.split_idx()
        # Input rows needed for this output slice, including `padding` halo rows on each side.
        h_begin = output_h * idx * stride - padding
        h_end = output_h * (idx + 1) * stride + padding
        final_padding = [padding, padding, 0, 0]  # [W_left, W_right, H_top, H_bottom]
        if h_begin < 0:
            # Topmost slice: no upper neighbour, zero-pad the top instead.
            h_begin = 0
            final_padding[2] = padding
        if h_end > h:
            # Bottommost slice: no lower neighbour, zero-pad the bottom instead.
            h_end = h
            final_padding[3] = padding
        sliced_input = x[:, :, h_begin:h_end, :]
        padded_input = F.pad(sliced_input, final_padding, mode="constant")
        return F.conv2d(padded_input, self.module.weight, self.module.bias, stride=stride, padding="valid")

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        distri_config = self.distri_config

        # Wait for any pending asynchronous boundary exchange before reusing its buffer.
        if self.comm_manager is not None and self.comm_manager.handles is not None and self.idx is not None:
            if self.comm_manager.handles[self.idx] is not None:
                self.comm_manager.handles[self.idx].wait()
                self.comm_manager.handles[self.idx] = None

        if distri_config.n_device_per_batch == 1:
            # Single device: plain convolution.
            output = self.naive_forward(x)
        else:
            if self.is_first_layer:
                # The first layer sees the full input; compute only this device's output slice.
                full_x = x
                output = self.sliced_forward(full_x)
            else:
                # Later layers see only a local height shard and need `padding` boundary rows
                # from the neighbouring shards (halo exchange).
                boundary_size = self.module.padding[0]
                if self.buffer_list is None:
                    if self.comm_manager.buffer_list is None:
                        # Register a buffer of shape [2, B, C, boundary, W]: top and bottom halo rows.
                        self.idx = self.comm_manager.register_tensor(
                            shape=[2, x.shape[0], x.shape[1], boundary_size, x.shape[3]],
                            torch_dtype=x.dtype,
                            layer_type="conv2d",
                        )
                    else:
                        self.buffer_list = self.comm_manager.get_buffer_list(self.idx)
                if self.buffer_list is None:
                    # Buffers not allocated yet (first pass): fall back to a plain convolution.
                    output = self.naive_forward(x)
                else:

                    def create_padded_x():
                        # Assemble this shard plus the halo rows received from the neighbouring
                        # devices; the outermost shards are zero-padded on their open side.
                        if distri_config.split_idx() == 0:
                            concat_x = torch.cat([x, self.buffer_list[distri_config.split_idx() + 1][0]], dim=2)
                            padded_x = F.pad(concat_x, [0, 0, boundary_size, 0], mode="constant")
                        elif distri_config.split_idx() == distri_config.n_device_per_batch - 1:
                            concat_x = torch.cat([self.buffer_list[distri_config.split_idx() - 1][1], x], dim=2)
                            padded_x = F.pad(concat_x, [0, 0, 0, boundary_size], mode="constant")
                        else:
                            padded_x = torch.cat(
                                [
                                    self.buffer_list[distri_config.split_idx() - 1][1],
                                    x,
                                    self.buffer_list[distri_config.split_idx() + 1][0],
                                ],
                                dim=2,
                            )
                        return padded_x

                    # This shard's own top and bottom boundary rows, to be shared with the neighbours.
                    boundary = torch.stack([x[:, :, :boundary_size, :], x[:, :, -boundary_size:, :]], dim=0)

                    if distri_config.mode == "full_sync" or self.counter <= distri_config.warmup_steps:
                        # Warm-up / fully synchronous mode: exchange boundaries now, then convolve.
                        dist.all_gather(self.buffer_list, boundary, group=distri_config.batch_group, async_op=False)
                        padded_x = create_padded_x()
                        output = F.conv2d(
                            padded_x,
                            self.module.weight,
                            self.module.bias,
                            stride=self.module.stride[0],
                            padding=(0, self.module.padding[1]),
                        )
                    else:
                        # Asynchronous mode: reuse the boundaries gathered in the previous step
                        # and enqueue this step's boundaries for a non-blocking exchange.
                        padded_x = create_padded_x()
                        output = F.conv2d(
                            padded_x,
                            self.module.weight,
                            self.module.bias,
                            stride=self.module.stride[0],
                            padding=(0, self.module.padding[1]),
                        )
                        if distri_config.mode != "no_sync":
                            self.comm_manager.enqueue(self.idx, boundary)

        self.counter += 1
        return output
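

# ----------------------------------------------------------------------------
# Hedged sketch, not part of the original module: a single-process check of the
# height-splitting idea used above. It assumes an ordinary square-kernel
# nn.Conv2d (kernel 3, stride 1, padding 1) and verifies that (a) the
# sliced_forward-style slicing with halo rows and (b) the forward-style halo
# exchange both reproduce the corresponding rows of the full convolution.
# All names below are local to the sketch; nothing here touches DistriConfig
# or the communication manager.
if __name__ == "__main__":
    torch.manual_seed(0)
    conv = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1)
    x = torch.randn(1, 3, 16, 16)
    full = conv(x)

    n_devices = 2
    stride, padding = conv.stride[0], conv.padding[0]
    out_h = x.shape[2] // stride // n_devices

    # (a) sliced_forward: slice the input with halo rows, pad W manually, run a valid conv.
    for idx in range(n_devices):
        h_begin = out_h * idx * stride - padding
        h_end = out_h * (idx + 1) * stride + padding
        pad = [padding, padding, 0, 0]  # [W_left, W_right, H_top, H_bottom]
        if h_begin < 0:
            h_begin, pad[2] = 0, padding
        if h_end > x.shape[2]:
            h_end, pad[3] = x.shape[2], padding
        piece = F.conv2d(F.pad(x[:, :, h_begin:h_end, :], pad), conv.weight, conv.bias, stride=stride)
        assert torch.allclose(piece, full[:, :, out_h * idx : out_h * (idx + 1), :], atol=1e-5)

    # (b) forward's halo exchange: each "device" holds one height shard and receives
    # `padding` boundary rows from its neighbours before convolving with zero H-padding.
    shards = x.chunk(n_devices, dim=2)
    for idx, shard in enumerate(shards):
        top = (
            torch.zeros_like(shard[:, :, :padding, :])  # topmost shard: zero rows above
            if idx == 0
            else shards[idx - 1][:, :, -padding:, :]
        )
        bottom = (
            torch.zeros_like(shard[:, :, :padding, :])  # bottommost shard: zero rows below
            if idx == n_devices - 1
            else shards[idx + 1][:, :, :padding, :]
        )
        padded = torch.cat([top, shard, bottom], dim=2)
        piece = F.conv2d(padded, conv.weight, conv.bias, stride=stride, padding=(0, conv.padding[1]))
        assert torch.allclose(piece, full[:, :, out_h * idx : out_h * (idx + 1), :], atol=1e-5)

    print("sliced and halo-exchanged conv2d match the full convolution")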