ops.py 5.66 KB
Newer Older
comfyanonymous's avatar
comfyanonymous committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Stability AI

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""

comfyanonymous's avatar
comfyanonymous committed
19
import torch
20
21
22
23
24
25
26
27
28
29
import comfy.model_management

def cast_bias_weight(s, input):
    """Return ``(weight, bias)`` of ``s`` converted to match ``input``.

    Both tensors are converted with ``.to`` using ``input.device`` and
    ``input.dtype``.  The ``non_blocking`` flag is whatever
    ``comfy.model_management.device_supports_non_blocking`` reports for
    the input's device.  ``s`` itself is not reassigned; only the
    converted tensors are returned.

    Args:
        s: object exposing ``weight`` and ``bias`` attributes; ``bias``
            may be ``None``.
        input: tensor whose ``device`` and ``dtype`` are the cast target.

    Returns:
        A ``(weight, bias)`` tuple; ``bias`` is ``None`` when ``s.bias``
        is ``None``.
    """
    # Build the conversion arguments once so bias and weight share them.
    target = dict(
        device=input.device,
        dtype=input.dtype,
        non_blocking=comfy.model_management.device_supports_non_blocking(input.device),
    )
    bias = s.bias.to(**target) if s.bias is not None else None
    weight = s.weight.to(**target)
    return weight, bias

comfyanonymous's avatar
comfyanonymous committed
30

comfyanonymous's avatar
comfyanonymous committed
31
32
class disable_weight_init:
    """Namespace of ``torch.nn`` layer subclasses with no-op weight init.

    Every nested layer overrides ``reset_parameters`` to do nothing and
    routes ``forward`` through ``forward_comfy_cast_weights`` when its
    ``comfy_cast_weights`` class attribute is true.  In that mode the
    weight and bias are converted to the input's device and dtype (via
    ``cast_bias_weight``) for each call.  With the flag false, the stock
    ``torch.nn`` forward is used unchanged.
    """

    class Linear(torch.nn.Linear):
        """``torch.nn.Linear`` with no-op init and optional weight casting."""

        # When true, forward() casts weight/bias to match the input.
        comfy_cast_weights = False

        def reset_parameters(self):
            """Do nothing, leaving the parameters as allocated."""
            return None

        def forward_comfy_cast_weights(self, input):
            """Apply the linear map with weight/bias cast to match ``input``."""
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, weight, bias)

        def forward(self, *args, **kwargs):
            """Dispatch to the casting forward or the stock forward."""
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            return super().forward(*args, **kwargs)

    class Conv2d(torch.nn.Conv2d):
        """``torch.nn.Conv2d`` with no-op init and optional weight casting."""

        # When true, forward() casts weight/bias to match the input.
        comfy_cast_weights = False

        def reset_parameters(self):
            """Do nothing, leaving the parameters as allocated."""
            return None

        def forward_comfy_cast_weights(self, input):
            """Run the convolution with weight/bias cast to match ``input``."""
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            """Dispatch to the casting forward or the stock forward."""
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            return super().forward(*args, **kwargs)

    class Conv3d(torch.nn.Conv3d):
        """``torch.nn.Conv3d`` with no-op init and optional weight casting."""

        # When true, forward() casts weight/bias to match the input.
        comfy_cast_weights = False

        def reset_parameters(self):
            """Do nothing, leaving the parameters as allocated."""
            return None

        def forward_comfy_cast_weights(self, input):
            """Run the convolution with weight/bias cast to match ``input``."""
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            """Dispatch to the casting forward or the stock forward."""
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            return super().forward(*args, **kwargs)

    class GroupNorm(torch.nn.GroupNorm):
        """``torch.nn.GroupNorm`` with no-op init and optional weight casting."""

        # When true, forward() casts weight/bias to match the input.
        comfy_cast_weights = False

        def reset_parameters(self):
            """Do nothing, leaving the parameters as allocated."""
            return None

        def forward_comfy_cast_weights(self, input):
            """Normalize ``input`` with weight/bias cast to match it.

            A layer built with ``affine=False`` has no weight, so there
            is nothing to cast; the normalization then runs without an
            affine transform, mirroring the guard in ``LayerNorm``.
            """
            if self.weight is not None:
                weight, bias = cast_bias_weight(self, input)
            else:
                weight = None
                bias = None
            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

        def forward(self, *args, **kwargs):
            """Dispatch to the casting forward or the stock forward."""
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            return super().forward(*args, **kwargs)

    class LayerNorm(torch.nn.LayerNorm):
        """``torch.nn.LayerNorm`` with no-op init and optional weight casting."""

        # When true, forward() casts weight/bias to match the input.
        comfy_cast_weights = False

        def reset_parameters(self):
            """Do nothing, leaving the parameters as allocated."""
            return None

        def forward_comfy_cast_weights(self, input):
            """Normalize ``input`` with weight/bias cast to match it.

            When the layer has no weight, both weight and bias are passed
            to ``layer_norm`` as ``None``.
            """
            if self.weight is not None:
                weight, bias = cast_bias_weight(self, input)
            else:
                weight = None
                bias = None
            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

        def forward(self, *args, **kwargs):
            """Dispatch to the casting forward or the stock forward."""
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            return super().forward(*args, **kwargs)

    class ConvTranspose2d(torch.nn.ConvTranspose2d):
        """``torch.nn.ConvTranspose2d`` with no-op init and optional casting."""

        # When true, forward() casts weight/bias to match the input.
        comfy_cast_weights = False

        def reset_parameters(self):
            """Do nothing, leaving the parameters as allocated."""
            return None

        def forward_comfy_cast_weights(self, input, output_size=None):
            """Run the transposed convolution with cast weight/bias.

            ``output_size`` is forwarded to ``_output_padding`` to derive
            the output padding passed to ``conv_transpose2d``.
            """
            num_spatial_dims = 2
            output_padding = self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size,
                num_spatial_dims, self.dilation)

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.conv_transpose2d(
                input, weight, bias, self.stride, self.padding,
                output_padding, self.groups, self.dilation)

        def forward(self, *args, **kwargs):
            """Dispatch to the casting forward or the stock forward."""
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            return super().forward(*args, **kwargs)

    @classmethod
    def conv_nd(s, dims, *args, **kwargs):
        """Build this namespace's ``Conv2d`` or ``Conv3d`` for ``dims``.

        The lookup goes through ``s``, so a subclass namespace returns
        its own layer classes.

        Args:
            dims: number of spatial dimensions; 2 or 3.
            *args, **kwargs: forwarded to the layer constructor.

        Raises:
            ValueError: if ``dims`` is neither 2 nor 3.
        """
        if dims == 2:
            return s.Conv2d(*args, **kwargs)
        elif dims == 3:
            return s.Conv3d(*args, **kwargs)
        else:
            raise ValueError(f"unsupported dimensions: {dims}")
142

143

comfyanonymous's avatar
comfyanonymous committed
144
145
class manual_cast(disable_weight_init):
    """Variant of ``disable_weight_init`` with weight casting switched on.

    Each nested layer subclasses its ``disable_weight_init`` counterpart
    and only flips ``comfy_cast_weights`` to ``True``, so ``forward``
    takes the ``forward_comfy_cast_weights`` path.
    """

    class Linear(disable_weight_init.Linear):
        """Linear layer that casts weight/bias on each forward call."""
        comfy_cast_weights = True

    class Conv2d(disable_weight_init.Conv2d):
        """2-D convolution that casts weight/bias on each forward call."""
        comfy_cast_weights = True

    class Conv3d(disable_weight_init.Conv3d):
        """3-D convolution that casts weight/bias on each forward call."""
        comfy_cast_weights = True

    class GroupNorm(disable_weight_init.GroupNorm):
        """Group normalization that casts weight/bias on each forward call."""
        comfy_cast_weights = True

    class LayerNorm(disable_weight_init.LayerNorm):
        """Layer normalization that casts weight/bias on each forward call."""
        comfy_cast_weights = True

    class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
        """2-D transposed convolution that casts weight/bias on each call."""
        comfy_cast_weights = True