utils.py 4.92 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
import torch
from diffusers import FluxPipeline
from peft.tuners import lora
Muyang Li's avatar
Muyang Li committed
4
from vars import LORA_PATHS, SVDQ_LORA_PATHS
Zhekai Zhang's avatar
Zhekai Zhang committed
5

muyangli's avatar
muyangli committed
6
from nunchaku import NunchakuFluxTransformer2dModel
7
from nunchaku.models.transformers.transformer_flux_v2 import NunchakuFluxTransformer2DModelV2
Zhekai Zhang's avatar
Zhekai Zhang committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25


def hash_str_to_int(s: str) -> int:
    """Map a string to a non-negative integer with a base-31 polynomial rolling hash.

    Deterministic across processes (unlike builtin ``hash``), so it is safe to
    use e.g. as a reproducible seed derived from a prompt string.
    """
    _MOD = 10**9 + 7  # large prime keeps the digest bounded
    digest = 0
    for code in map(ord, s):
        digest = (digest * 31 + code) % _MOD
    return digest


def get_pipeline(
    model_name: str,
    precision: str,
    use_qencoder: bool = False,
    lora_name: str = "None",
    lora_weight: float = 1,
    device: str | torch.device = "cuda",
    pipeline_init_kwargs: dict | None = None,
) -> FluxPipeline:
    """Build a FLUX.1 pipeline for the requested model/precision combination.

    Args:
        model_name: One of ``"schnell"``, ``"schnell_v2"``, or ``"dev"``.
        precision: ``"int4"``, ``"fp4"``, or ``"bf16"`` (``schnell_v2`` interpolates
            it directly into the checkpoint name).
        use_qencoder: If True, also load the quantized INT4 T5 text encoder
            (only honored on the quantized branches).
        lora_name: LoRA to activate; ``"None"`` disables, ``"All"`` pre-loads every
            entry of ``LORA_PATHS`` (dev/bf16 branch only).
        lora_weight: Strength applied to the selected LoRA.
        device: Target device; quantized branches require CUDA.
        pipeline_init_kwargs: Extra kwargs forwarded to
            ``FluxPipeline.from_pretrained``. A fresh dict is created when omitted.

    Returns:
        The constructed ``FluxPipeline``, moved to ``device`` (quantized) or with
        model CPU offload enabled (bf16).

    Raises:
        NotImplementedError: If ``model_name`` is not recognized.
    """
    # BUG FIX: the default used to be a mutable `{}` shared across calls; since the
    # function mutates it (adds "transformer"/"text_encoder_2"), components from one
    # call leaked into later calls. Use None + fresh dict instead.
    if pipeline_init_kwargs is None:
        pipeline_init_kwargs = {}

    if model_name == "schnell":
        if precision in ["int4", "fp4"]:
            assert torch.device(device).type == "cuda", "int4 only supported on CUDA devices"
            if precision == "int4":
                transformer = NunchakuFluxTransformer2dModel.from_pretrained(
                    "nunchaku-tech/nunchaku-flux.1-schnell/svdq-int4_r32-flux.1-schnell.safetensors"
                )
            else:
                assert precision == "fp4"
                transformer = NunchakuFluxTransformer2dModel.from_pretrained(
                    "nunchaku-tech/nunchaku-flux.1-schnell/svdq-fp4_r32-flux.1-schnell.safetensors", precision="fp4"
                )
            transformer.set_attention_impl("nunchaku-fp16")
            pipeline_init_kwargs["transformer"] = transformer
            if use_qencoder:
                # Local import: keep the quantized T5 dependency optional.
                from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

                text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
                    "nunchaku-tech/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
                )
                pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
        else:
            assert precision == "bf16"
        pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
        )
    elif model_name == "schnell_v2":
        # V2 transformer path: the precision string is interpolated straight into
        # the checkpoint filename (e.g. "int4" or "fp4").
        transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(
            f"nunchaku-tech/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors"
        )
        pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell",
            transformer=transformer,
            torch_dtype=torch.bfloat16,
            **pipeline_init_kwargs,
        )
    elif model_name == "dev":
        if precision == "int4":
            transformer = NunchakuFluxTransformer2dModel.from_pretrained(
                "nunchaku-tech/nunchaku-flux.1-dev/svdq-int4_r32-flux.1-dev.safetensors"
            )
            # Quantized LoRAs are fused into the transformer directly; "All" is not
            # supported on this branch (only the single named LoRA).
            if lora_name not in ["All", "None"]:
                transformer.update_lora_params(SVDQ_LORA_PATHS[lora_name])
                transformer.set_lora_strength(lora_weight)
            pipeline_init_kwargs["transformer"] = transformer
            if use_qencoder:
                from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

                text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
                    "nunchaku-tech/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
                )
                pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
            pipeline = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
            )
        else:
            assert precision == "bf16"
            # NOTE(review): use_qencoder is ignored on this branch — presumably
            # intentional (bf16 uses the full-precision encoder); confirm.
            pipeline = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
            )
            if lora_name == "All":
                # Pre-load all the LoRA weights for demo use
                for name, path in LORA_PATHS.items():
                    pipeline.load_lora_weights(path["name_or_path"], weight_name=path["weight_name"], adapter_name=name)
                for m in pipeline.transformer.modules():
                    if isinstance(m, lora.LoraLayer):
                        # Activate every adapter, but zero its scale so callers can
                        # dial individual LoRAs up later without reloading.
                        m.set_adapter(m.scaling.keys())
                        for name in m.scaling.keys():
                            m.scaling[name] = 0
            elif lora_name != "None":
                path = LORA_PATHS[lora_name]
                pipeline.load_lora_weights(
                    path["name_or_path"], weight_name=path["weight_name"], adapter_name=lora_name
                )
                # Apply the requested strength to the freshly loaded adapter.
                for m in pipeline.transformer.modules():
                    if isinstance(m, lora.LoraLayer):
                        for name in m.scaling.keys():
                            m.scaling[name] = lora_weight
    else:
        raise NotImplementedError(f"Model {model_name} not implemented")

    # bf16 runs with CPU offload to fit in memory; quantized models go straight
    # to the (CUDA) device.
    if precision == "bf16":
        pipeline.enable_model_cpu_offload()
    else:
        pipeline = pipeline.to(device)

    return pipeline