import os
import types

import comfy.sd
import folder_paths
import torch
from torch import nn
from transformers import T5EncoderModel

from nunchaku import NunchakuT5EncoderModel


def svdquant_t5_forward(
    self: T5EncoderModel,
    input_ids: torch.LongTensor,
    attention_mask,
muyangli's avatar
muyangli committed
17
    embeds=None,
18
19
20
    intermediate_output=None,
    final_layer_norm_intermediate=True,
    dtype: str | torch.dtype = torch.bfloat16,
muyangli's avatar
muyangli committed
21
    **kwargs,
22
23
24
25
):
    assert attention_mask is None
    assert intermediate_output is None
    assert final_layer_norm_intermediate
muyangli's avatar
muyangli committed
26
    outputs = self.encoder(input_ids=input_ids, inputs_embeds=embeds, attention_mask=attention_mask)
27
28
29
30
31
    hidden_states = outputs["last_hidden_state"]
    hidden_states = hidden_states.to(dtype=dtype)
    return hidden_states, None


class WrappedEmbedding(nn.Module):
    def __init__(self, embedding: nn.Embedding):
        super().__init__()
        self.embedding = embedding

    def forward(self, input: torch.Tensor, out_dtype: torch.dtype | None = None):
        return self.embedding(input)

    @property
    def weight(self):
        return self.embedding.weight


class SVDQuantTextEncoderLoader:
    """ComfyUI node that loads a dual text-encoder CLIP (e.g. for FLUX) and can
    optionally replace its T5-XXL transformer with a Nunchaku INT4-quantized
    model."""

    @classmethod
    def INPUT_TYPES(cls):
        # Start with the known remote model name, then offer any local model
        # folders found under the configured ``text_encoders`` directories.
        model_paths = ["mit-han-lab/svdq-flux.1-t5"]
        prefixes = folder_paths.folder_names_and_paths["text_encoders"][0]
        local_folders = set()
        for prefix in prefixes:
            # isdir() already implies existence, so no separate exists() check.
            if os.path.isdir(prefix):
                local_folders.update(
                    folder
                    for folder in os.listdir(prefix)
                    if not folder.startswith(".") and os.path.isdir(os.path.join(prefix, folder))
                )
        model_paths.extend(sorted(local_folders))
        return {
            "required": {
                "model_type": (["flux"],),
                "text_encoder1": (folder_paths.get_filename_list("text_encoders"),),
                "text_encoder2": (folder_paths.get_filename_list("text_encoders"),),
                "t5_min_length": (
                    "INT",
                    {"default": 512, "min": 256, "max": 1024, "step": 128, "display": "number", "lazy": True},
                ),
                "t5_precision": (["BF16", "INT4"],),
                "int4_model": (model_paths, {"tooltip": "The name of the INT4 model."}),
            }
        }

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "load_text_encoder"

    CATEGORY = "SVDQuant"

    TITLE = "SVDQuant Text Encoder Loader"

    def load_text_encoder(
        self,
        model_type: str,
        text_encoder1: str,
        text_encoder2: str,
        t5_min_length: int,
        t5_precision: str,
        int4_model: str,
    ):
        """Load the two text encoders and, for ``t5_precision == "INT4"``,
        swap in the Nunchaku quantized T5 transformer.

        Returns:
            A one-element tuple ``(clip,)`` as ComfyUI node conventions require.

        Raises:
            ValueError: if ``model_type`` is not a recognized type.
        """
        text_encoder_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder1)
        text_encoder_path2 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder2)
        if model_type == "flux":
            clip_type = comfy.sd.CLIPType.FLUX
        else:
            raise ValueError(f"Unknown type {model_type}")

        clip = comfy.sd.load_clip(
            ckpt_paths=[text_encoder_path1, text_encoder_path2],
            embedding_directory=folder_paths.get_folder_paths("embeddings"),
            clip_type=clip_type,
        )

        if model_type == "flux":
            clip.tokenizer.t5xxl.min_length = t5_min_length

        if t5_precision == "INT4":
            # Remember the dtype/device of the transformer being replaced so
            # the quantized replacement can be moved to match.
            original = clip.cond_stage_model.t5xxl.transformer
            param = next(original.parameters())
            dtype = param.dtype
            device = param.device

            # Resolve ``int4_model`` against the local ``diffusion_models``
            # folders; if nothing matches, pass the name through unchanged
            # (presumably a hub repo id like the default — TODO confirm).
            model_path = int4_model
            for prefix in folder_paths.folder_names_and_paths["diffusion_models"][0]:
                candidate = os.path.join(prefix, int4_model)
                if os.path.exists(candidate):
                    model_path = candidate
                    break
            transformer = NunchakuT5EncoderModel.from_pretrained(model_path)
            # Bind the ComfyUI-compatible forward onto this instance, and wrap
            # the shared embedding so it tolerates the extra `out_dtype` arg.
            transformer.forward = types.MethodType(svdquant_t5_forward, transformer)
            transformer.shared = WrappedEmbedding(transformer.shared)

            # Only move/cast when the original encoder lived on CUDA.
            clip.cond_stage_model.t5xxl.transformer = (
                transformer.to(device=device, dtype=dtype) if device.type == "cuda" else transformer
            )

        return (clip,)