utils.py
import torch
from diffusers import FluxPipeline
from peft.tuners import lora

from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
from vars import LORA_PATHS, SVDQ_LORA_PATHS


def hash_str_to_int(s: str) -> int:
    """Hash a string to an integer with a polynomial rolling hash (stable across runs)."""
    modulus = 10**9 + 7  # large prime modulus keeps the result bounded
    hash_int = 0
    for char in s:
        hash_int = (hash_int * 31 + ord(char)) % modulus
    return hash_int
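
# Illustrative usage sketch (assumption: the hash is used to derive reproducible
# sampling seeds from prompt strings; the prompt below is hypothetical):
#
#     seed = hash_str_to_int("a photo of an astronaut riding a horse")
#     generator = torch.Generator(device="cuda").manual_seed(seed)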


def get_pipeline(
    model_name: str,
    precision: str,
    use_qencoder: bool = False,
    lora_name: str = "None",
    lora_weight: float = 1,
    device: str | torch.device = "cuda",
    pipeline_init_kwargs: dict | None = None,
) -> FluxPipeline:
    # Avoid a mutable default argument: this dict is mutated below, so a shared
    # default would leak state (e.g. a stale transformer) across calls.
    if pipeline_init_kwargs is None:
        pipeline_init_kwargs = {}
    if model_name == "schnell":
        if precision in ["int4", "fp4"]:
            assert torch.device(device).type == "cuda", "int4/fp4 quantization is only supported on CUDA devices"
            if precision == "int4":
                transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
            else:
                assert precision == "fp4"
                transformer = NunchakuFluxTransformer2dModel.from_pretrained(
                    "/home/muyang/nunchaku_models/flux.1-schnell-nvfp4-svdq-gptq", precision="fp4"
                )
            pipeline_init_kwargs["transformer"] = transformer
            if use_qencoder:
                from nunchaku.models.text_encoder import NunchakuT5EncoderModel

                text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
                pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
        else:
            assert precision == "bf16"
        pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
        )
    elif model_name == "dev":
        if precision == "int4":
            transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
            if lora_name not in ["All", "None"]:
                # SVDQuant-format LoRA weights are fused into the quantized transformer directly.
                transformer.update_lora_params(SVDQ_LORA_PATHS[lora_name])
                transformer.set_lora_strength(lora_weight)
            pipeline_init_kwargs["transformer"] = transformer
            if use_qencoder:
                from nunchaku.models.text_encoder import NunchakuT5EncoderModel

                text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
                pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
            pipeline = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
            )
        else:
            assert precision == "bf16"
            pipeline = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
            )
            if lora_name == "All":
                # Pre-load all the LoRA weights for demo use
                for name, path in LORA_PATHS.items():
                    pipeline.load_lora_weights(path["name_or_path"], weight_name=path["weight_name"], adapter_name=name)
                for m in pipeline.transformer.modules():
                    if isinstance(m, lora.LoraLayer):
                        m.set_adapter(m.scaling.keys())
                        for name in m.scaling.keys():
                            m.scaling[name] = 0
            elif lora_name != "None":
                path = LORA_PATHS[lora_name]
                pipeline.load_lora_weights(
                    path["name_or_path"], weight_name=path["weight_name"], adapter_name=lora_name
                )
                for m in pipeline.transformer.modules():
                    if isinstance(m, lora.LoraLayer):
                        for name in m.scaling.keys():
                            m.scaling[name] = lora_weight
    else:
        raise NotImplementedError(f"Model {model_name} not implemented")
    pipeline = pipeline.to(device)

    return pipeline
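

if __name__ == "__main__":
    # Minimal smoke-test sketch (assumptions: a CUDA GPU is available and the
    # "mit-han-lab/svdq-int4-flux.1-schnell" checkpoint referenced above can be
    # downloaded; the prompt and output filename are illustrative).
    pipe = get_pipeline("schnell", precision="int4")
    image = pipe(
        "A cat holding a sign that says hello world",
        num_inference_steps=4,  # FLUX.1-schnell is distilled for few-step sampling
        guidance_scale=0.0,  # schnell is guidance-distilled, so CFG is disabled
        generator=torch.Generator(device="cuda").manual_seed(hash_str_to_int("demo")),
    ).images[0]
    image.save("flux-schnell-demo.png")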