ops.py 3.88 KB
Newer Older
comfyanonymous's avatar
comfyanonymous committed
1
import torch
2
from contextlib import contextmanager
3
4
5
6
7
8
9
10
11
12
import comfy.model_management

def cast_bias_weight(s, input):
    """Copy a module's weight and bias onto the device and dtype of ``input``.

    Args:
        s: a module exposing ``weight`` and ``bias`` attributes; either may be
            None (e.g. a norm layer constructed without affine parameters).
        input: tensor whose ``device`` and ``dtype`` the parameters are cast to.

    Returns:
        ``(weight, bias)`` cast to ``input``'s device/dtype. An entry is None
        when the corresponding attribute on ``s`` is None.
    """
    # Whether the copies may be issued non-blocking is decided per device by
    # comfy.model_management.
    non_blocking = comfy.model_management.device_supports_non_blocking(input.device)

    def _cast(t):
        # Absent parameters (None) are passed through instead of crashing on
        # ``None.to``; previously only the bias was guarded.
        if t is None:
            return None
        return t.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)

    # Same order as before: bias first, then weight.
    bias = _cast(s.bias)
    weight = _cast(s.weight)
    return weight, bias


class disable_weight_init:
    """Namespace of torch.nn layer subclasses that skip weight initialization.

    Each nested layer overrides ``reset_parameters`` with a no-op, so the
    torch constructor's call to it does not initialize the parameters.

    When a layer's ``comfy_cast_weights`` is True, ``forward`` routes through
    ``forward_comfy_cast_weights``. That path casts weight/bias to the input's
    device and dtype (via ``cast_bias_weight``) before computing. Otherwise
    the stock torch ``forward`` is used.
    """

    class Linear(torch.nn.Linear):
        # When True, forward() casts weight/bias to the input's device/dtype first.
        comfy_cast_weights = False

        def reset_parameters(self):
            # Intentionally skip torch's default parameter initialization.
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv2d(torch.nn.Conv2d):
        # When True, forward() casts weight/bias to the input's device/dtype first.
        comfy_cast_weights = False

        def reset_parameters(self):
            # Intentionally skip torch's default parameter initialization.
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            # _conv_forward applies this module's stride/padding/dilation/groups
            # (and padding_mode) with the supplied weight and bias.
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv3d(torch.nn.Conv3d):
        # When True, forward() casts weight/bias to the input's device/dtype first.
        comfy_cast_weights = False

        def reset_parameters(self):
            # Intentionally skip torch's default parameter initialization.
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class GroupNorm(torch.nn.GroupNorm):
        # When True, forward() casts weight/bias to the input's device/dtype first.
        comfy_cast_weights = False

        def reset_parameters(self):
            # Intentionally skip torch's default parameter initialization.
            return None

        def forward_comfy_cast_weights(self, input):
            # GroupNorm(affine=False) has weight and bias set to None; there is
            # nothing to cast, and group_norm accepts None for both.
            if self.weight is not None:
                weight, bias = cast_bias_weight(self, input)
            else:
                weight, bias = None, None
            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class LayerNorm(torch.nn.LayerNorm):
        # When True, forward() casts weight/bias to the input's device/dtype first.
        comfy_cast_weights = False

        def reset_parameters(self):
            # Intentionally skip torch's default parameter initialization.
            return None

        def forward_comfy_cast_weights(self, input):
            # LayerNorm(elementwise_affine=False) has weight and bias set to
            # None; there is nothing to cast, and layer_norm accepts None.
            if self.weight is not None:
                weight, bias = cast_bias_weight(self, input)
            else:
                weight, bias = None, None
            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    @classmethod
    def conv_nd(s, dims, *args, **kwargs):
        """Build this namespace's Conv2d (dims == 2) or Conv3d (dims == 3).

        Because this is a classmethod, subclasses (e.g. a casting variant)
        get their own overridden Conv2d/Conv3d.

        Raises:
            ValueError: for any other ``dims``.
        """
        if dims == 2:
            return s.Conv2d(*args, **kwargs)
        elif dims == 3:
            return s.Conv3d(*args, **kwargs)
        else:
            raise ValueError(f"unsupported dimensions: {dims}")


class manual_cast(disable_weight_init):
    """Variant of ``disable_weight_init`` whose layers always cast parameters.

    Every nested layer flips ``comfy_cast_weights`` to True, so its
    ``forward`` takes the ``forward_comfy_cast_weights`` path. ``conv_nd`` is
    inherited; as a classmethod it resolves ``Conv2d``/``Conv3d`` on this
    class and therefore returns the casting versions below.
    """

    class Linear(disable_weight_init.Linear):
        # Always cast weight/bias to the input's device/dtype in forward().
        comfy_cast_weights = True

    class Conv2d(disable_weight_init.Conv2d):
        # Always cast weight/bias to the input's device/dtype in forward().
        comfy_cast_weights = True

    class Conv3d(disable_weight_init.Conv3d):
        # Always cast weight/bias to the input's device/dtype in forward().
        comfy_cast_weights = True

    class GroupNorm(disable_weight_init.GroupNorm):
        # Always cast weight/bias to the input's device/dtype in forward().
        comfy_cast_weights = True

    class LayerNorm(disable_weight_init.LayerNorm):
        # Always cast weight/bias to the input's device/dtype in forward().
        comfy_cast_weights = True