sd3_clip.py 4.2 KB
Newer Older
comfyanonymous's avatar
comfyanonymous committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from comfy import sd1_clip
from comfy import sdxl_clip
from transformers import T5TokenizerFast
import comfy.t5
import torch
import os
import comfy.model_management

class T5XXLModel(sd1_clip.SDClipModel):
    """T5-XXL text encoder built on top of ``sd1_clip.SDClipModel``.

    The model is instantiated with ``comfy.t5.T5`` as its model class and is
    configured from the ``t5_config_xxl.json`` file located in the same
    directory as this module.
    """

    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
        # Resolve the config relative to this file (symlinks resolved), so the
        # lookup does not depend on the current working directory.
        module_dir = os.path.dirname(os.path.realpath(__file__))
        config_path = os.path.join(module_dir, "t5_config_xxl.json")
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config=config_path,
            dtype=dtype,
            special_tokens={"end": 1, "pad": 0},
            model_class=comfy.t5.T5,
        )

class T5XXLTokenizer(sd1_clip.SDTokenizer):
    """Tokenizer for the T5-XXL text encoder.

    Wraps ``T5TokenizerFast`` loaded from the ``t5_tokenizer`` directory that
    sits next to this module.

    Args:
        embedding_directory: Forwarded to ``sd1_clip.SDTokenizer`` so that
            custom embeddings (looked up under the ``t5xxl`` embedding key
            with size 4096) can be resolved.
    """

    def __init__(self, embedding_directory=None):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
        # embedding_directory was previously accepted but silently dropped,
        # so the base tokenizer never received it; forward it explicitly.
        super().__init__(
            tokenizer_path,
            embedding_directory=embedding_directory,
            pad_with_end=False,
            embedding_size=4096,
            embedding_key='t5xxl',
            tokenizer_class=T5TokenizerFast,
            has_start_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=77,
        )

class SDT5XXLTokenizer(sd1_clip.SD1Tokenizer):
    """``sd1_clip.SD1Tokenizer`` wired to ``T5XXLTokenizer`` under the name ``t5xxl``."""

    def __init__(self, embedding_directory=None):
        super().__init__(
            embedding_directory=embedding_directory,
            clip_name="t5xxl",
            tokenizer=T5XXLTokenizer,
        )

class SDT5XXLModel(sd1_clip.SD1ClipModel):
    """``sd1_clip.SD1ClipModel`` wired to ``T5XXLModel`` under the name ``t5xxl``.

    Any extra keyword arguments are passed through to the base constructor.
    """

    def __init__(self, device="cpu", dtype=None, **kwargs):
        super().__init__(
            device=device,
            dtype=dtype,
            clip_name="t5xxl",
            clip_model=T5XXLModel,
            **kwargs,
        )



class SD3Tokenizer:
    """Runs one prompt through three tokenizers: CLIP-L, CLIP-G and T5-XXL.

    Attributes are ``clip_l``, ``clip_g`` and ``t5xxl``; each is constructed
    with the same ``embedding_directory``.
    """

    def __init__(self, embedding_directory=None):
        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
        self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)

    def tokenize_with_weights(self, text: str, return_word_ids=False):
        """Tokenize ``text`` with every tokenizer.

        Returns:
            A dict with keys ``"g"``, ``"l"`` and ``"t5xxl"`` (in that order),
            each holding the result of the matching tokenizer's
            ``tokenize_with_weights`` call.
        """
        return {
            "g": self.clip_g.tokenize_with_weights(text, return_word_ids),
            "l": self.clip_l.tokenize_with_weights(text, return_word_ids),
            "t5xxl": self.t5xxl.tokenize_with_weights(text, return_word_ids),
        }

    def untokenize(self, token_weight_pair):
        """Delegate untokenization to the CLIP-G tokenizer."""
        clip_g_tokenizer = self.clip_g
        return clip_g_tokenizer.untokenize(token_weight_pair)

class SD3ClipModel(torch.nn.Module):
    """Combined text encoder holding CLIP-L, CLIP-G and T5-XXL sub-models.

    The sub-models are stored as ``clip_l``, ``clip_g`` and ``t5xxl``.
    """

    def __init__(self, device="cpu", dtype=None):
        super().__init__()
        self.clip_l = sd1_clip.SDClipModel(
            layer="hidden",
            layer_idx=-2,
            device=device,
            dtype=dtype,
            layer_norm_hidden_state=False,
            return_projected_pooled=False,
        )
        self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
        self.t5xxl = T5XXLModel(device=device, dtype=dtype)

    def set_clip_options(self, options):
        """Apply ``options`` to every sub-model (CLIP-L, CLIP-G, then T5-XXL)."""
        for encoder in (self.clip_l, self.clip_g, self.t5xxl):
            encoder.set_clip_options(options)

    def reset_clip_options(self):
        """Reset options on every sub-model (CLIP-G, CLIP-L, then T5-XXL)."""
        for encoder in (self.clip_g, self.clip_l, self.t5xxl):
            encoder.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        """Encode a tokenized prompt with all three sub-models.

        Args:
            token_weight_pairs: Mapping with keys ``"l"``, ``"g"`` and
                ``"t5xxl"`` holding the per-tokenizer token/weight data.

        Returns:
            ``(out, pooled)``. When either CLIP input is non-empty, ``out`` is
            the CLIP-L and CLIP-G outputs joined on the last dim, zero-padded
            to 4096 there, then concatenated with the T5 output along dim -2;
            ``pooled`` is the two CLIP pooled outputs joined on the last dim.
            When both CLIP inputs are empty, ``out`` is the T5 output alone and
            ``pooled`` is a zero tensor of shape ``(1, 1280 + 768)``.
        """
        pairs_l = token_weight_pairs["l"]
        pairs_g = token_weight_pairs["g"]
        pairs_t5 = token_weight_pairs["t5xxl"]

        if len(pairs_g) == 0 and len(pairs_l) == 0:
            # No CLIP input at all: only a zero pooled placeholder is produced.
            clip_hidden = None
            pooled = torch.zeros((1, 1280 + 768), device=comfy.model_management.intermediate_device())
        else:
            l_hidden, l_pooled = self.clip_l.encode_token_weights(pairs_l)
            g_hidden, g_pooled = self.clip_g.encode_token_weights(pairs_g)
            clip_hidden = torch.cat([l_hidden, g_hidden], dim=-1)
            # Right-pad the last dim with zeros up to 4096 so it can be
            # concatenated with the T5 output below (presumably 4096 wide).
            pad_amount = 4096 - clip_hidden.shape[-1]
            clip_hidden = torch.nn.functional.pad(clip_hidden, (0, pad_amount))
            pooled = torch.cat((l_pooled, g_pooled), dim=-1)

        # The T5 pooled output is not used.
        t5_hidden, _ = self.t5xxl.encode_token_weights(pairs_t5)
        if clip_hidden is None:
            return t5_hidden, pooled
        return torch.cat([clip_hidden, t5_hidden], dim=-2), pooled

    def load_sd(self, sd):
        """Load a state dict into the sub-model it belongs to.

        The target is picked by probing for marker keys: a layer-30 MLP weight
        selects CLIP-G, otherwise a layer-1 MLP weight selects CLIP-L, and
        anything else goes to T5-XXL. Returns the chosen sub-model's
        ``load_sd`` result.
        """
        if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
            target = self.clip_g
        elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
            target = self.clip_l
        else:
            target = self.t5xxl
        return target.load_sd(sd)