ops.py 3.85 KB
Newer Older
comfyanonymous's avatar
comfyanonymous committed
1
import torch
2
3
4
5
6
7
8
9
10
11
import comfy.model_management

def cast_bias_weight(s, input):
    """Return ``(weight, bias)`` of module *s* converted to *input*'s device and dtype.

    The casting forward paths use this helper to get parameters that match
    the activation tensor. The tensors are produced with ``Tensor.to``, which
    returns converted tensors and leaves ``s.weight`` / ``s.bias`` themselves
    untouched.

    Args:
        s: module exposing ``weight`` and ``bias`` attributes. Either may be
            ``None``: a layer built with ``bias=False``, or a
            LayerNorm/GroupNorm built without affine parameters.
        input: tensor whose ``device`` and ``dtype`` the parameters are
            converted to.

    Returns:
        Tuple ``(weight, bias)``. Each entry is the converted tensor, or
        ``None`` when the corresponding attribute on *s* is ``None``.
    """
    # Project helper decides whether copies to this device may be non-blocking.
    non_blocking = comfy.model_management.device_supports_non_blocking(input.device)

    def _cast(t):
        # Absent parameters pass through as None instead of raising
        # AttributeError. Weight-less norms rely on this when casting is on.
        if t is None:
            return None
        return t.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)

    # Bias is converted before weight, keeping the original transfer order.
    bias = _cast(s.bias)
    weight = _cast(s.weight)
    return weight, bias

comfyanonymous's avatar
comfyanonymous committed
12

comfyanonymous's avatar
comfyanonymous committed
13
14
class disable_weight_init:
    """Container of torch.nn layer subclasses that skip parameter initialization.

    Every nested layer:

    * overrides ``reset_parameters`` with a no-op, so constructing the layer
      does not run torch's default initializer for its parameters;
    * carries a class-level ``comfy_cast_weights`` switch (False here).
      When the switch is on, ``forward`` delegates to
      ``forward_comfy_cast_weights``. That method obtains the parameters via
      ``cast_bias_weight(self, input)`` and runs the matching
      ``torch.nn.functional`` op. When the switch is off, the stock torch
      ``forward`` runs unchanged.

    ``conv_nd`` looks the conv layers up on the class it is invoked on, so
    subclasses of this container that override the nested layers get their
    own versions.
    """

    class Linear(torch.nn.Linear):
        comfy_cast_weights = False

        def reset_parameters(self):
            # Deliberate no-op: skip torch's default parameter initialization.
            pass

        def forward(self, *args, **kwargs):
            # Stock torch path unless this class opts into parameter casting.
            if not self.comfy_cast_weights:
                return super().forward(*args, **kwargs)
            return self.forward_comfy_cast_weights(*args, **kwargs)

        def forward_comfy_cast_weights(self, input):
            cast_w, cast_b = cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, cast_w, cast_b)

    class Conv2d(torch.nn.Conv2d):
        comfy_cast_weights = False

        def reset_parameters(self):
            # Deliberate no-op: skip torch's default parameter initialization.
            pass

        def forward(self, *args, **kwargs):
            # Stock torch path unless this class opts into parameter casting.
            if not self.comfy_cast_weights:
                return super().forward(*args, **kwargs)
            return self.forward_comfy_cast_weights(*args, **kwargs)

        def forward_comfy_cast_weights(self, input):
            cast_w, cast_b = cast_bias_weight(self, input)
            # _conv_forward applies this layer's stride/padding/dilation/groups.
            return self._conv_forward(input, cast_w, cast_b)

    class Conv3d(torch.nn.Conv3d):
        comfy_cast_weights = False

        def reset_parameters(self):
            # Deliberate no-op: skip torch's default parameter initialization.
            pass

        def forward(self, *args, **kwargs):
            # Stock torch path unless this class opts into parameter casting.
            if not self.comfy_cast_weights:
                return super().forward(*args, **kwargs)
            return self.forward_comfy_cast_weights(*args, **kwargs)

        def forward_comfy_cast_weights(self, input):
            cast_w, cast_b = cast_bias_weight(self, input)
            # _conv_forward applies this layer's stride/padding/dilation/groups.
            return self._conv_forward(input, cast_w, cast_b)

    class GroupNorm(torch.nn.GroupNorm):
        comfy_cast_weights = False

        def reset_parameters(self):
            # Deliberate no-op: skip torch's default parameter initialization.
            pass

        def forward(self, *args, **kwargs):
            # Stock torch path unless this class opts into parameter casting.
            if not self.comfy_cast_weights:
                return super().forward(*args, **kwargs)
            return self.forward_comfy_cast_weights(*args, **kwargs)

        def forward_comfy_cast_weights(self, input):
            cast_w, cast_b = cast_bias_weight(self, input)
            return torch.nn.functional.group_norm(
                input, self.num_groups, cast_w, cast_b, self.eps
            )

    class LayerNorm(torch.nn.LayerNorm):
        comfy_cast_weights = False

        def reset_parameters(self):
            # Deliberate no-op: skip torch's default parameter initialization.
            pass

        def forward(self, *args, **kwargs):
            # Stock torch path unless this class opts into parameter casting.
            if not self.comfy_cast_weights:
                return super().forward(*args, **kwargs)
            return self.forward_comfy_cast_weights(*args, **kwargs)

        def forward_comfy_cast_weights(self, input):
            cast_w, cast_b = cast_bias_weight(self, input)
            return torch.nn.functional.layer_norm(
                input, self.normalized_shape, cast_w, cast_b, self.eps
            )

    @classmethod
    def conv_nd(cls, dims, *args, **kwargs):
        """Build this container's Conv2d (``dims == 2``) or Conv3d (``dims == 3``).

        Remaining positional/keyword arguments are forwarded to the layer
        constructor.

        Raises:
            ValueError: if *dims* is neither 2 nor 3.
        """
        if dims not in (2, 3):
            raise ValueError(f"unsupported dimensions: {dims}")
        layer_cls = cls.Conv2d if dims == 2 else cls.Conv3d
        return layer_cls(*args, **kwargs)
98

99

comfyanonymous's avatar
comfyanonymous committed
100
101
class manual_cast(disable_weight_init):
    """Layers from ``disable_weight_init`` with parameter casting switched on.

    Each nested layer subclasses its ``disable_weight_init`` counterpart and
    only sets ``comfy_cast_weights = True``. That makes the inherited
    ``forward`` use ``forward_comfy_cast_weights`` instead of the stock torch
    forward. ``conv_nd`` is inherited and resolves the conv layers on the
    class it is called on, so ``manual_cast.conv_nd`` returns the casting
    Conv2d/Conv3d defined here.
    """

    class Linear(disable_weight_init.Linear):
        """Linear that casts its weight/bias on every forward call."""
        comfy_cast_weights = True

    class Conv2d(disable_weight_init.Conv2d):
        """Conv2d that casts its weight/bias on every forward call."""
        comfy_cast_weights = True

    class Conv3d(disable_weight_init.Conv3d):
        """Conv3d that casts its weight/bias on every forward call."""
        comfy_cast_weights = True

    class GroupNorm(disable_weight_init.GroupNorm):
        """GroupNorm that casts its weight/bias on every forward call."""
        comfy_cast_weights = True

    class LayerNorm(disable_weight_init.LayerNorm):
        """LayerNorm that casts its weight/bias on every forward call."""
        comfy_cast_weights = True