"""Utility helpers for building FLUX.1 diffusion pipelines."""
import torch
from diffusers import FluxPipeline
from peft.tuners import lora
from vars import LORA_PATHS, SVDQ_LORA_PATHS
from nunchaku import NunchakuFluxTransformer2dModel


def hash_str_to_int(s: str) -> int:
    """Map a string to a non-negative integer deterministically.

    Uses a base-31 polynomial rolling hash reduced modulo the prime
    10**9 + 7, so the result is stable across processes (unlike the
    built-in ``hash``, which is salted per interpreter run).
    """
    MOD = 10**9 + 7  # large prime: keeps the value bounded and well-distributed
    acc = 0
    for code in map(ord, s):
        acc = (acc * 31 + code) % MOD
    return acc


def get_pipeline(
    model_name: str,
    precision: str,
    use_qencoder: bool = False,
    lora_name: str = "None",
    lora_weight: float = 1,
    device: str | torch.device = "cuda",
    pipeline_init_kwargs: dict | None = None,
) -> FluxPipeline:
    """Build a FLUX.1 pipeline for the requested model, precision and LoRA.

    Args:
        model_name: ``"schnell"`` or ``"dev"``; anything else raises.
        precision: ``"int4"``, ``"fp4"`` (schnell only) or ``"bf16"``.
        use_qencoder: when quantized, also load the AWQ-int4 T5 text encoder.
        lora_name: key into ``SVDQ_LORA_PATHS`` (int4 dev) / ``LORA_PATHS``
            (bf16 dev); ``"None"`` skips LoRA, ``"All"`` (bf16 dev) pre-loads
            every known LoRA with strength 0.
        lora_weight: strength applied to the selected LoRA.
        device: target device for the finished pipeline; quantized precisions
            require CUDA.
        pipeline_init_kwargs: extra kwargs forwarded to
            ``FluxPipeline.from_pretrained``. Defaults to an empty dict.

    Returns:
        The constructed ``FluxPipeline`` moved to ``device``.

    Raises:
        NotImplementedError: for an unknown ``model_name``.
        AssertionError: for unsupported precision/device combinations.
    """
    # Fix for the mutable-default-argument bug (`dict = {}`): this function
    # inserts "transformer"/"text_encoder_2" keys below, which previously
    # leaked between calls through the shared default. Copy so the caller's
    # dict is never mutated either.
    pipeline_init_kwargs = dict(pipeline_init_kwargs) if pipeline_init_kwargs else {}

    if model_name == "schnell":
        if precision in ["int4", "fp4"]:
            assert torch.device(device).type == "cuda", "int4 only supported on CUDA devices"
            if precision == "int4":
                transformer = NunchakuFluxTransformer2dModel.from_pretrained(
                    "mit-han-lab/nunchaku-flux.1-schnell/svdq-int4_r32-flux.1-schnell.safetensors"
                )
            else:
                assert precision == "fp4"
                transformer = NunchakuFluxTransformer2dModel.from_pretrained(
                    "mit-han-lab/nunchaku-flux.1-schnell/svdq-fp4_r32-flux.1-schnell.safetensors", precision="fp4"
                )
            pipeline_init_kwargs["transformer"] = transformer
            if use_qencoder:
                # Imported lazily so the quantized encoder is only required
                # when actually requested.
                from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

                text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
                    "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
                )
                pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
        else:
            assert precision == "bf16"
        pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
        )
    elif model_name == "dev":
        if precision == "int4":
            transformer = NunchakuFluxTransformer2dModel.from_pretrained(
                "mit-han-lab/nunchaku-flux.1-dev/svdq-int4_r32-flux.1-dev.safetensors"
            )
            # For the quantized transformer, LoRA weights are fused into the
            # transformer itself rather than loaded through the pipeline.
            if lora_name not in ["All", "None"]:
                transformer.update_lora_params(SVDQ_LORA_PATHS[lora_name])
                transformer.set_lora_strength(lora_weight)
            pipeline_init_kwargs["transformer"] = transformer
            if use_qencoder:
                from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

                text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
                    "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
                )
                pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
            pipeline = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
            )
        else:
            assert precision == "bf16"
            pipeline = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
            )
            if lora_name == "All":
                # Pre-load all the LoRA weights for demo use
                for name, path in LORA_PATHS.items():
                    pipeline.load_lora_weights(path["name_or_path"], weight_name=path["weight_name"], adapter_name=name)
                for m in pipeline.transformer.modules():
                    if isinstance(m, lora.LoraLayer):
                        # Materialize the keys view before handing it to
                        # set_adapter; activate every adapter but zero its
                        # strength so they can be enabled later without
                        # reloading weights.
                        m.set_adapter(list(m.scaling.keys()))
                        for name in m.scaling.keys():
                            m.scaling[name] = 0
            elif lora_name != "None":
                path = LORA_PATHS[lora_name]
                pipeline.load_lora_weights(
                    path["name_or_path"], weight_name=path["weight_name"], adapter_name=lora_name
                )
                for m in pipeline.transformer.modules():
                    if isinstance(m, lora.LoraLayer):
                        for name in m.scaling.keys():
                            m.scaling[name] = lora_weight
    else:
        raise NotImplementedError(f"Model {model_name} not implemented")
    pipeline = pipeline.to(device)

    return pipeline