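# Tokenizers and text-encoder wrappers for the Stable Diffusion 3 triple text
# encoder stack: CLIP-L, CLIP-G and T5-XXL.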
from comfy import sd1_clip
from comfy import sdxl_clip
from transformers import T5TokenizerFast
import comfy.text_encoders.t5
import torch
import os
import comfy.model_management
import logging

class T5XXLModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)

class T5XXLTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
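        # T5 has no BOS token and no practical length cap here; prompts are padded
        # out to at least 77 tokens to match the CLIP context length.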
        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)

class SDT5XXLTokenizer(sd1_clip.SD1Tokenizer):
    def __init__(self, embedding_directory=None):
        super().__init__(embedding_directory=embedding_directory, clip_name="t5xxl", tokenizer=T5XXLTokenizer)

class SDT5XXLModel(sd1_clip.SD1ClipModel):
    def __init__(self, device="cpu", dtype=None, **kwargs):
        super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, **kwargs)



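# Tokenizes a prompt with all three tokenizers at once; the result is a dict with
# "l", "g" and "t5xxl" entries consumed by SD3ClipModel.encode_token_weights.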
class SD3Tokenizer:
    def __init__(self, embedding_directory=None):
        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
        self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)

    def tokenize_with_weights(self, text:str, return_word_ids=False):
        out = {}
        out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
        return out

    def untokenize(self, token_weight_pair):
        return self.clip_g.untokenize(token_weight_pair)

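# Holds the three SD3 text encoders. Any of them can be disabled, in which case
# zero tensors of the expected shape are substituted at encode time.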
class SD3ClipModel(torch.nn.Module):
    def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
        super().__init__()
        self.dtypes = set()
        if clip_l:
            self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
            self.dtypes.add(dtype)
        else:
            self.clip_l = None

        if clip_g:
            self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
            self.dtypes.add(dtype)
        else:
            self.clip_g = None

        if t5:
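            # Choose the T5 dtype: default to the main dtype, never use a wider
            # dtype than the main one, and fall back to the main dtype when the
            # device cannot cast to the requested one.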
            if dtype_t5 is None:
                dtype_t5 = dtype
            elif comfy.model_management.dtype_size(dtype_t5) > comfy.model_management.dtype_size(dtype):
                dtype_t5 = dtype

            if not comfy.model_management.supports_cast(device, dtype_t5):
                dtype_t5 = dtype

            self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
            self.dtypes.add(dtype_t5)
        else:
            self.t5xxl = None

        logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}".format(clip_l, clip_g, t5, dtype_t5))

    def set_clip_options(self, options):
        if self.clip_l is not None:
            self.clip_l.set_clip_options(options)
        if self.clip_g is not None:
            self.clip_g.set_clip_options(options)
        if self.t5xxl is not None:
            self.t5xxl.set_clip_options(options)

    def reset_clip_options(self):
        if self.clip_l is not None:
            self.clip_l.reset_clip_options()
        if self.clip_g is not None:
            self.clip_g.reset_clip_options()
        if self.t5xxl is not None:
            self.t5xxl.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        token_weight_pairs_l = token_weight_pairs["l"]
        token_weight_pairs_g = token_weight_pairs["g"]
        token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
        lg_out = None
        pooled = None
        out = None

        if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
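            # Run whichever CLIP encoders exist; a missing encoder contributes a zero
            # pooled vector of its expected width (768 for CLIP-L, 1280 for CLIP-G).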
            if self.clip_l is not None:
                lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
            else:
                l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())

            if self.clip_g is not None:
                g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
                if lg_out is not None:
                    lg_out = torch.cat([lg_out, g_out], dim=-1)
                else:
                    lg_out = torch.nn.functional.pad(g_out, (768, 0))
            else:
                g_out = None
                g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())

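            # Zero-pad the CLIP hidden states up to 4096 features so they line up
            # with the T5-XXL width.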
            if lg_out is not None:
                lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
                out = lg_out
            pooled = torch.cat((l_pooled, g_pooled), dim=-1)

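        # T5-XXL hidden states are appended after the CLIP tokens along the sequence
        # dimension (both are 4096 features wide at this point).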
        if self.t5xxl is not None:
            t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
            if lg_out is not None:
                out = torch.cat([lg_out, t5_out], dim=-2)
            else:
                out = t5_out

        if out is None:
            out = torch.zeros((1, 77, 4096), device=comfy.model_management.intermediate_device())

        if pooled is None:
            pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())

        return out, pooled

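    # Route a loaded state dict to the matching encoder based on which transformer
    # layer keys it contains: a layer-30 key identifies CLIP-G (32 layers), otherwise
    # a layer-1 key identifies CLIP-L (12 layers), and anything else is assumed to be
    # T5-XXL weights.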
    def load_sd(self, sd):
        if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
            return self.clip_g.load_sd(sd)
        elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
            return self.clip_l.load_sd(sd)
        else:
            return self.t5xxl.load_sd(sd)

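# Factory that captures the chosen encoder combination in a subclass, so callers
# can instantiate the text encoder later with only device/dtype arguments.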
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
    class SD3ClipModel_(SD3ClipModel):
        def __init__(self, device="cpu", dtype=None):
            super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
    return SD3ClipModel_
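

# Rough usage sketch (illustrative only; inside ComfyUI these objects are normally
# constructed and weight-loaded by the CLIP loading code, not built by hand):
#
#   tokenizer = SD3Tokenizer()
#   te_class = sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=torch.float16)
#   te = te_class(device="cpu", dtype=torch.float16)
#   # ... load weights with te.load_sd(state_dict), once per encoder checkpoint ...
#   tokens = tokenizer.tokenize_with_weights("a photo of a cat")
#   cond, pooled = te.encode_token_weights(tokens)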