utils.py 4.84 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
import torch
from diffusers import FluxPipeline
from peft.tuners import lora
Muyang Li's avatar
Muyang Li committed
4
from vars import LORA_PATHS, SVDQ_LORA_PATHS
Zhekai Zhang's avatar
Zhekai Zhang committed
5

muyangli's avatar
muyangli committed
6
from nunchaku import NunchakuFluxTransformer2dModel
7
from nunchaku.models.transformers.transformer_flux_v2 import NunchakuFluxTransformer2DModelV2
Zhekai Zhang's avatar
Zhekai Zhang committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25


def hash_str_to_int(s: str) -> int:
    """Map a string to a deterministic integer in ``[0, 10**9 + 7)``.

    This is a base-31 polynomial rolling hash over the code points of ``s``;
    the accumulator is reduced modulo ``10**9 + 7`` after every character.
    The empty string hashes to ``0``.
    """
    prime = 1_000_000_007  # 10**9 + 7
    acc = 0
    for code_point in map(ord, s):
        acc = (31 * acc + code_point) % prime
    return acc


def get_pipeline(
    model_name: str,
    precision: str,
    use_qencoder: bool = False,
    lora_name: str = "None",
    lora_weight: float = 1,
    device: str | torch.device = "cuda",
    pipeline_init_kwargs: dict | None = None,
) -> FluxPipeline:
    """Build a FLUX.1 pipeline for the requested model variant and precision.

    Args:
        model_name: ``"schnell"``, ``"schnell_v2"`` or ``"dev"``.
        precision: For ``"schnell"``: ``"int4"``, ``"fp4"`` or ``"bf16"``.
            For ``"dev"``: ``"int4"`` or ``"bf16"``. For ``"schnell_v2"`` the
            value is interpolated into the quantized checkpoint file name.
        use_qencoder: When true, also load the AWQ int4 T5 text encoder as
            ``text_encoder_2``. Only consulted on the quantized ``"schnell"``
            and the int4 ``"dev"`` paths.
        lora_name: ``"None"`` for no LoRA, ``"All"``, or a key of
            ``SVDQ_LORA_PATHS`` (int4 ``"dev"``) / ``LORA_PATHS`` (bf16
            ``"dev"``). Only consulted for ``"dev"``. With bf16, ``"All"``
            pre-loads every adapter in ``LORA_PATHS`` with scaling 0; with
            int4, ``"All"`` loads nothing.
        lora_weight: Strength applied to the selected LoRA.
        device: Device the pipeline is moved to when ``precision`` is not
            ``"bf16"``. bf16 pipelines use model CPU offload instead.
        pipeline_init_kwargs: Extra keyword arguments forwarded to
            ``FluxPipeline.from_pretrained``. The mapping is copied, so the
            caller's dict is never modified.

    Returns:
        The configured ``FluxPipeline``.

    Raises:
        NotImplementedError: If ``model_name`` is not one of the known models.
        AssertionError: If ``precision`` is not supported for the model, or a
            quantized ``"schnell"`` model is requested on a non-CUDA device.
    """
    # Work on a private copy. The previous ``= {}`` default was mutated below,
    # so a quantized transformer / text encoder inserted by one call leaked
    # into every later call that relied on the default (and into the caller's
    # own dict when one was supplied).
    pipeline_init_kwargs = dict(pipeline_init_kwargs) if pipeline_init_kwargs else {}

    if model_name == "schnell":
        if precision in ["int4", "fp4"]:
            assert torch.device(device).type == "cuda", "int4 only supported on CUDA devices"
            if precision == "int4":
                transformer = NunchakuFluxTransformer2dModel.from_pretrained(
                    "mit-han-lab/nunchaku-flux.1-schnell/svdq-int4_r32-flux.1-schnell.safetensors"
                )
            else:
                assert precision == "fp4"
                transformer = NunchakuFluxTransformer2dModel.from_pretrained(
                    "mit-han-lab/nunchaku-flux.1-schnell/svdq-fp4_r32-flux.1-schnell.safetensors", precision="fp4"
                )
            pipeline_init_kwargs["transformer"] = transformer
            if use_qencoder:
                # Imported lazily so the quantized encoder module is only
                # loaded when it is actually requested.
                from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

                text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
                    "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
                )
                pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
        else:
            assert precision == "bf16"
        pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
        )
    elif model_name == "schnell_v2":
        transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(
            f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors"
        )
        pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell",
            transformer=transformer,
            torch_dtype=torch.bfloat16,
            **pipeline_init_kwargs,
        )
    elif model_name == "dev":
        if precision == "int4":
            transformer = NunchakuFluxTransformer2dModel.from_pretrained(
                "mit-han-lab/nunchaku-flux.1-dev/svdq-int4_r32-flux.1-dev.safetensors"
            )
            # A single named LoRA is applied to the quantized transformer
            # itself; "All" and "None" both leave it untouched here.
            if lora_name not in ["All", "None"]:
                transformer.update_lora_params(SVDQ_LORA_PATHS[lora_name])
                transformer.set_lora_strength(lora_weight)
            pipeline_init_kwargs["transformer"] = transformer
            if use_qencoder:
                from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

                text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
                    "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
                )
                pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
            pipeline = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
            )
        else:
            assert precision == "bf16"
            pipeline = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
            )
            if lora_name == "All":
                # Pre-load all the LoRA weights for demo use
                for name, path in LORA_PATHS.items():
                    pipeline.load_lora_weights(path["name_or_path"], weight_name=path["weight_name"], adapter_name=name)
                # Activate every adapter on each LoRA layer but zero its
                # scaling, so none of them contributes until a weight is set.
                for m in pipeline.transformer.modules():
                    if isinstance(m, lora.LoraLayer):
                        m.set_adapter(m.scaling.keys())
                        for name in m.scaling.keys():
                            m.scaling[name] = 0
            elif lora_name != "None":
                path = LORA_PATHS[lora_name]
                pipeline.load_lora_weights(
                    path["name_or_path"], weight_name=path["weight_name"], adapter_name=lora_name
                )
                # Override the scaling of every adapter on each LoRA layer
                # with the requested weight.
                for m in pipeline.transformer.modules():
                    if isinstance(m, lora.LoraLayer):
                        for name in m.scaling.keys():
                            m.scaling[name] = lora_weight
    else:
        raise NotImplementedError(f"Model {model_name} not implemented")

    # bf16 pipelines rely on model CPU offload; every other precision is
    # moved to the requested device as a whole.
    if precision == "bf16":
        pipeline.enable_model_cpu_offload()
    else:
        pipeline = pipeline.to(device)

    return pipeline