ops.py 8.79 KB
Newer Older
comfyanonymous's avatar
comfyanonymous committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Stability AI

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""

comfyanonymous's avatar
comfyanonymous committed
19
import torch
20
21
import comfy.model_management

22
23
24
25

def cast_to(weight, dtype=None, device=None, non_blocking=False):
    """Convert ``weight`` to ``dtype`` and/or move it to ``device``.

    A ``None`` dtype or device leaves that property unchanged. ``non_blocking``
    is forwarded to ``Tensor.to``; when nothing needs to change, torch hands
    back ``weight`` itself rather than a copy.
    """
    target = {"device": device, "dtype": dtype, "non_blocking": non_blocking}
    return weight.to(**target)

26
def cast_to_input(weight, input, non_blocking=False):
    """Cast ``weight`` to match the dtype and device of the tensor ``input``.

    Thin convenience wrapper around ``cast_to``; ``non_blocking`` is passed
    through unchanged.
    """
    like_dtype = input.dtype
    like_device = input.device
    return cast_to(weight, like_dtype, like_device, non_blocking=non_blocking)

def cast_bias_weight(s, input=None, dtype=None, device=None):
    """Return the ``(weight, bias)`` of module ``s`` cast for a forward pass.

    Args:
        s: Module exposing ``weight`` and, optionally, ``bias``,
            ``weight_function`` and ``bias_function`` attributes.
        input: Optional tensor whose dtype/device fill in ``dtype``/``device``
            when those are not given explicitly.
        dtype: Target dtype; ``None`` leaves the tensors' dtype unchanged.
        device: Target device; ``None`` leaves the tensors' device unchanged.

    Returns:
        Tuple ``(weight, bias)``. ``bias`` is ``None`` when the module has no
        bias (or no ``bias`` attribute at all). When ``weight_function`` /
        ``bias_function`` are set, they are applied to the tensor after casting
        and their result is returned instead.
    """
    if input is not None:
        if dtype is None:
            dtype = input.dtype
        if device is None:
            device = input.device

    # Whether the copies may be issued asynchronously is decided per target
    # device by model_management.
    non_blocking = comfy.model_management.device_should_use_non_blocking(device)

    # getattr instead of attribute access: a module may not define these at
    # all. For example, a torch Embedding constructed with ``_weight=`` never
    # runs reset_parameters, which is where disable_weight_init.Embedding sets
    # ``self.bias = None``.
    bias = None
    raw_bias = getattr(s, "bias", None)
    if raw_bias is not None:
        bias = cast_to(raw_bias, dtype, device, non_blocking=non_blocking)
        bias_function = getattr(s, "bias_function", None)
        if bias_function is not None:
            bias = bias_function(bias)

    weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
    weight_function = getattr(s, "weight_function", None)
    if weight_function is not None:
        weight = weight_function(weight)
    return weight, bias

comfyanonymous's avatar
comfyanonymous committed
47
48
49
50
class CastWeightBiasOp:
    """Mixin supplying class-level defaults for the cast-on-forward layers."""

    # Routing switch: layers whose class sets this to True run
    # forward_comfy_cast_weights instead of the stock torch forward.
    comfy_cast_weights = False

    # Optional callables that cast_bias_weight applies to the cast tensors.
    weight_function = None
    bias_function = None
comfyanonymous's avatar
comfyanonymous committed
51

comfyanonymous's avatar
comfyanonymous committed
52
class disable_weight_init:
    """Namespace of torch.nn layer subclasses that skip weight initialization.

    Every layer overrides ``reset_parameters`` with a no-op. Constructing one
    therefore leaves its parameters as torch allocated them, without running
    torch's default init. Presumably the weights are loaded afterwards; confirm
    with callers.

    Each layer also mixes in ``CastWeightBiasOp``. When the class attribute
    ``comfy_cast_weights`` is true, ``forward`` routes through
    ``forward_comfy_cast_weights``, which gets weight/bias from
    ``cast_bias_weight`` before calling the functional op. Otherwise the stock
    torch ``forward`` runs unchanged.
    """

    class Linear(torch.nn.Linear, CastWeightBiasOp):
        def reset_parameters(self):
            # Intentionally a no-op: skip torch's default initialization.
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv1d(torch.nn.Conv1d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            # _conv_forward handles padding_mode just like the stock forward.
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            # With affine=False torch stores weight and bias as None. There is
            # nothing to cast, and cast_bias_weight would call ``.to`` on None,
            # so skip it (same guard as LayerNorm below).
            if self.weight is not None:
                weight, bias = cast_bias_weight(self, input)
            else:
                weight = None
                bias = None
            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            # elementwise_affine=False leaves weight as None: nothing to cast.
            if self.weight is not None:
                weight, bias = cast_bias_weight(self, input)
            else:
                weight = None
                bias = None
            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input, output_size=None):
            # Same output_padding computation as torch's own forward, so an
            # explicit output_size is honoured.
            num_spatial_dims = 2
            output_padding = self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size,
                num_spatial_dims, self.dilation)

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.conv_transpose2d(
                input, weight, bias, self.stride, self.padding,
                output_padding, self.groups, self.dilation)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class ConvTranspose1d(torch.nn.ConvTranspose1d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input, output_size=None):
            num_spatial_dims = 1
            output_padding = self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size,
                num_spatial_dims, self.dilation)

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.conv_transpose1d(
                input, weight, bias, self.stride, self.padding,
                output_padding, self.groups, self.dilation)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Embedding(torch.nn.Embedding, CastWeightBiasOp):
        def reset_parameters(self):
            # torch's Embedding has no bias; cast_bias_weight reads ``bias``.
            self.bias = None
            return None

        def forward_comfy_cast_weights(self, input, out_dtype=None):
            # Remember the requested result dtype before possibly clearing it.
            output_dtype = out_dtype
            # Half-precision tables are looked up in their own dtype and only
            # the result is converted afterwards. Presumably this avoids
            # converting the whole table. NOTE(review): confirm intent.
            if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                out_dtype = None
            # The input holds indices, so only its device is used, not its dtype.
            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                # torch's Embedding.forward does not accept ``out_dtype``. On
                # this path it is dropped and the output keeps the weight's dtype.
                kwargs.pop("out_dtype", None)
                return super().forward(*args, **kwargs)

    @classmethod
    def conv_nd(s, dims, *args, **kwargs):
        """Build this namespace's convolution layer for ``dims`` spatial dims.

        ``s`` is the class ``conv_nd`` was called on, so subclass namespaces
        get their own Conv classes. Remaining arguments go to the constructor.

        Raises:
            ValueError: if ``dims`` is not 1, 2 or 3.
        """
        if dims == 1:
            return s.Conv1d(*args, **kwargs)
        elif dims == 2:
            return s.Conv2d(*args, **kwargs)
        elif dims == 3:
            return s.Conv3d(*args, **kwargs)
        else:
            raise ValueError(f"unsupported dimensions: {dims}")
212

213

comfyanonymous's avatar
comfyanonymous committed
214
215
class manual_cast(disable_weight_init):
    """Variant of ``disable_weight_init`` whose layers always cast on forward.

    Each nested class below only sets ``comfy_cast_weights = True``. Its
    inherited ``forward`` therefore goes through
    ``forward_comfy_cast_weights``, which casts weight and bias before use.
    ``conv_nd`` is an inherited classmethod that looks its Conv classes up on
    the class it is called on, so ``manual_cast.conv_nd`` builds the classes
    defined here.
    """

    # --- dense / convolution layers ---
    class Linear(disable_weight_init.Linear):
        comfy_cast_weights = True

    class Conv1d(disable_weight_init.Conv1d):
        comfy_cast_weights = True

    class Conv2d(disable_weight_init.Conv2d):
        comfy_cast_weights = True

    class Conv3d(disable_weight_init.Conv3d):
        comfy_cast_weights = True

    # --- normalization layers ---
    class GroupNorm(disable_weight_init.GroupNorm):
        comfy_cast_weights = True

    class LayerNorm(disable_weight_init.LayerNorm):
        comfy_cast_weights = True

    # --- transposed convolutions ---
    class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
        comfy_cast_weights = True

    class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
        comfy_cast_weights = True

    # --- lookup tables ---
    class Embedding(disable_weight_init.Embedding):
        comfy_cast_weights = True