resnet.py 7.45 KB
Newer Older
luopl's avatar
init  
luopl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import torch.cuda
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D, USE_PEFT_BACKEND
from torch import distributed as dist
from torch import nn
from torch.nn import functional as F

from ..base_module import BaseModule
from ...utils import DistriConfig


class DistriResnetBlock2DTP(BaseModule):
    """Tensor-parallel wrapper around a diffusers ``ResnetBlock2D``.

    The block's inner channel dimension (the output channels of ``conv1``) is
    split evenly across the ``distri_config.n_device_per_batch`` devices of a
    batch group. Device ``distri_config.split_idx()`` keeps only its contiguous
    slice of:

    * ``conv1`` output channels (weight rows and bias entries),
    * ``time_emb_proj`` output features (weight rows and bias entries),
    * ``norm2`` channels, with ``num_groups`` divided by the device count,
    * ``conv2`` input channels (weight columns); ``conv2.bias`` is kept whole.

    In :meth:`forward` every device computes a partial ``conv2`` result over its
    channel slice (without bias). The partial results are summed with
    ``dist.all_reduce`` over ``distri_config.batch_group``, and the full
    ``conv2`` bias is added exactly once, after the reduction.

    Only blocks with a ``time_emb_proj`` and ``time_embedding_norm == "default"``
    are supported; this is checked in ``__init__``.
    """

    def __init__(self, module: ResnetBlock2D, distri_config: DistriConfig):
        """Replace ``module``'s conv1/conv2/time_emb_proj/norm2 with this device's shards.

        Args:
            module: The diffusers ResNet block to shard; it is modified in place.
            distri_config: Provides ``n_device_per_batch`` (shard count) and
                ``split_idx()`` (this device's shard index).

        Raises:
            AssertionError: If the block's shapes or configuration cannot be
                sharded as described in the class docstring.
        """
        super().__init__(module, distri_config)
        n_device = distri_config.n_device_per_batch
        conv1, conv2 = module.conv1, module.conv2
        time_emb_proj, norm2 = module.time_emb_proj, module.norm2

        # Validate everything before allocating any sharded layers.
        assert conv1.out_channels % n_device == 0, (
            f"conv1.out_channels ({conv1.out_channels}) must be divisible by n_device_per_batch ({n_device})"
        )
        # GroupNorm groups are contiguous channel ranges; they line up with the
        # channel shards only if every shard holds a whole number of groups.
        assert norm2.num_groups % n_device == 0, (
            f"norm2.num_groups ({norm2.num_groups}) must be divisible by n_device_per_batch ({n_device})"
        )
        # Slicing output (conv1) / input (conv2) channels is only a valid
        # decomposition of a dense (ungrouped) convolution.
        assert conv1.groups == 1 and conv2.groups == 1, "grouped conv1/conv2 cannot be channel-sharded"
        assert time_emb_proj is not None, "time_emb_proj is required"
        assert module.time_embedding_norm == "default", "only time_embedding_norm == 'default' is supported"

        mid_channels = conv1.out_channels // n_device
        split_idx = distri_config.split_idx()
        # This device's contiguous slice of the inner (conv1-output) channels.
        shard = slice(split_idx * mid_channels, (split_idx + 1) * mid_channels)

        # conv1: keep only this device's output channels.
        sharded_conv1 = nn.Conv2d(
            conv1.in_channels,
            mid_channels,
            conv1.kernel_size,
            conv1.stride,
            conv1.padding,
            conv1.dilation,
            conv1.groups,
            conv1.bias is not None,
            conv1.padding_mode,
            device=conv1.weight.device,
            dtype=conv1.weight.dtype,
        )
        sharded_conv1.weight.data.copy_(conv1.weight.data[shard])
        if conv1.bias is not None:
            sharded_conv1.bias.data.copy_(conv1.bias.data[shard])

        # conv2: keep only this device's input channels; the bias stays whole and
        # is applied once in forward() after the all-reduce.
        sharded_conv2 = nn.Conv2d(
            mid_channels,
            conv2.out_channels,
            conv2.kernel_size,
            conv2.stride,
            conv2.padding,
            conv2.dilation,
            conv2.groups,
            conv2.bias is not None,
            conv2.padding_mode,
            device=conv2.weight.device,
            dtype=conv2.weight.dtype,
        )
        sharded_conv2.weight.data.copy_(conv2.weight.data[:, shard])
        if conv2.bias is not None:
            sharded_conv2.bias.data.copy_(conv2.bias.data)

        # time_emb_proj: its output is added to conv1's output, so it is sharded
        # along the same channel slice.
        sharded_time_emb_proj = nn.Linear(
            time_emb_proj.in_features,
            mid_channels,
            bias=time_emb_proj.bias is not None,
            device=time_emb_proj.weight.device,
            dtype=time_emb_proj.weight.dtype,
        )
        sharded_time_emb_proj.weight.data.copy_(time_emb_proj.weight.data[shard])
        if time_emb_proj.bias is not None:
            sharded_time_emb_proj.bias.data.copy_(time_emb_proj.bias.data[shard])

        # norm2: same per-group statistics, restricted to this device's groups.
        # A non-affine GroupNorm has weight=None, so only forward device/dtype
        # when there are parameters to place.
        norm_factory_kwargs = (
            {"device": norm2.weight.device, "dtype": norm2.weight.dtype} if norm2.affine else {}
        )
        sharded_norm2 = nn.GroupNorm(
            norm2.num_groups // n_device,
            mid_channels,
            norm2.eps,
            norm2.affine,
            **norm_factory_kwargs,
        )
        if norm2.affine:
            sharded_norm2.weight.data.copy_(norm2.weight.data[shard])
            sharded_norm2.bias.data.copy_(norm2.bias.data[shard])

        module.conv1 = sharded_conv1
        module.conv2 = sharded_conv2
        module.time_emb_proj = sharded_time_emb_proj
        module.norm2 = sharded_norm2

        # Drop the last local references to the unsharded layers so their
        # storage can actually be released by empty_cache().
        del conv1, conv2, time_emb_proj, norm2
        torch.cuda.empty_cache()

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: torch.FloatTensor,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        """Run the sharded ResNet block and all-reduce the ``conv2`` partial sums.

        Args:
            input_tensor: Block input; also feeds the residual/shortcut branch.
            temb: Timestep embedding passed through the sharded ``time_emb_proj``.
            scale: Scale forwarded to diffusers layers; only ``1.0`` is supported.

        Returns:
            ``(shortcut(input_tensor) + hidden_states) / module.output_scale_factor``,
            where ``hidden_states`` is the all-reduced ``conv2`` output plus bias.
        """
        assert scale == 1.0

        distri_config = self.distri_config
        module = self.module

        hidden_states = input_tensor
        hidden_states = module.norm1(hidden_states)

        hidden_states = module.nonlinearity(hidden_states)

        if module.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = (
                module.upsample(input_tensor, scale=scale)
                if isinstance(module.upsample, Upsample2D)
                else module.upsample(input_tensor)
            )
            hidden_states = (
                module.upsample(hidden_states, scale=scale)
                if isinstance(module.upsample, Upsample2D)
                else module.upsample(hidden_states)
            )
        elif module.downsample is not None:
            input_tensor = (
                module.downsample(input_tensor, scale=scale)
                if isinstance(module.downsample, Downsample2D)
                else module.downsample(input_tensor)
            )
            hidden_states = (
                module.downsample(hidden_states, scale=scale)
                if isinstance(module.downsample, Downsample2D)
                else module.downsample(hidden_states)
            )

        # Sharded conv1: produces only this device's slice of the inner channels.
        hidden_states = module.conv1(hidden_states)

        if module.time_emb_proj is not None:
            if not module.skip_time_act:
                temb = module.nonlinearity(temb)
            temb = module.time_emb_proj(temb)[:, :, None, None]

        if temb is not None and module.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = module.norm2(hidden_states)

        if temb is not None and module.time_embedding_norm == "scale_shift":
            # Not expected to run: __init__ asserts time_embedding_norm == "default".
            # Locals are named so they do not shadow the `scale` argument, which is
            # still used for conv_shortcut below.
            temb_scale, temb_shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + temb_scale) + temb_shift

        hidden_states = module.nonlinearity(hidden_states)

        hidden_states = module.dropout(hidden_states)
        # Partial conv2 over this device's input channels, deliberately without
        # bias: the bias must be added once, after summing across devices.
        hidden_states = F.conv2d(
            hidden_states,
            module.conv2.weight,
            bias=None,
            stride=module.conv2.stride,
            padding=module.conv2.padding,
            dilation=module.conv2.dilation,
            groups=module.conv2.groups,
        )

        dist.all_reduce(hidden_states, op=dist.ReduceOp.SUM, group=distri_config.batch_group, async_op=False)
        if module.conv2.bias is not None:
            hidden_states = hidden_states + module.conv2.bias.view(1, -1, 1, 1)

        if module.conv_shortcut is not None:
            # conv_shortcut is a layer of the wrapped module, not of this wrapper.
            input_tensor = (
                module.conv_shortcut(input_tensor, scale)
                if not USE_PEFT_BACKEND
                else module.conv_shortcut(input_tensor)
            )

        output_tensor = (input_tensor + hidden_states) / module.output_scale_factor

        self.counter += 1

        return output_tensor