import torch
from diffusers import FluxPipeline
from peft.tuners import lora

from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
from vars import LORA_PATHS, SVDQ_LORA_PATHS
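# NOTE (assumption, inferred from how these are used below): LORA_PATHS maps a
# display name to {"name_or_path": <HF repo or local dir>, "weight_name":
# <safetensors filename>}, while SVDQ_LORA_PATHS maps the same names to LoRA
# checkpoints pre-converted to SVDQuant format.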


def hash_str_to_int(s: str) -> int:
    """Hash a string to an integer with a base-31 polynomial rolling hash.

    >>> hash_str_to_int("ab")  # (0 * 31 + 97) * 31 + 98, mod 10**9 + 7
    3105
    """
    modulus = 10**9 + 7  # large prime modulus keeps the hash bounded
    hash_int = 0
    for char in s:
        hash_int = (hash_int * 31 + ord(char)) % modulus
    return hash_int
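
# Assumed typical use (not required by get_pipeline below): derive a
# reproducible torch seed from a prompt string, e.g.
#   torch.Generator("cuda").manual_seed(hash_str_to_int(prompt))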


def get_pipeline(
    model_name: str,
    precision: str,
    use_qencoder: bool = False,
    lora_name: str = "None",
    lora_weight: float = 1.0,
    device: str | torch.device = "cuda",
    pipeline_init_kwargs: dict | None = None,
) -> FluxPipeline:
    """Build a FLUX.1 "schnell" or "dev" pipeline in int4 (SVDQuant) or bf16
    precision, optionally with a quantized T5 encoder and LoRA weights."""
    # A mutable default (`{}`) would be shared across calls and is mutated
    # below, so default to None and create a fresh dict per call instead.
    if pipeline_init_kwargs is None:
        pipeline_init_kwargs = {}
    if model_name == "schnell":
        if precision == "int4":
            assert torch.device(device).type == "cuda", "int4 only supported on CUDA devices"
            transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
            pipeline_init_kwargs["transformer"] = transformer
            if use_qencoder:
                from nunchaku.models.text_encoder import NunchakuT5EncoderModel

                text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
                pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
        else:
            assert precision == "bf16"
        pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
        )
    elif model_name == "dev":
        if precision == "int4":
            transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
            if lora_name not in ["All", "None"]:
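                # int4 LoRAs must be pre-converted to SVDQuant format; they are
                # fused into the quantized transformer rather than loaded via peft.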
                transformer.update_lora_params(SVDQ_LORA_PATHS[lora_name])
                transformer.set_lora_strength(lora_weight)
            pipeline_init_kwargs["transformer"] = transformer
            if use_qencoder:
                from nunchaku.models.text_encoder import NunchakuT5EncoderModel

                text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
                pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
            pipeline = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
            )
        else:
            assert precision == "bf16"
            pipeline = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
            )
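            # For the bf16 pipeline, LoRAs are loaded through peft; strengths
            # are set by writing to each LoraLayer's `scaling` dict below.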
            if lora_name == "All":
                # Pre-load all the LoRA weights for demo use
                for name, path in LORA_PATHS.items():
                    pipeline.load_lora_weights(path["name_or_path"], weight_name=path["weight_name"], adapter_name=name)
                for m in pipeline.transformer.modules():
                    if isinstance(m, lora.LoraLayer):
                        m.set_adapter(m.scaling.keys())
                        for name in m.scaling.keys():
                            m.scaling[name] = 0
            elif lora_name != "None":
                path = LORA_PATHS[lora_name]
                pipeline.load_lora_weights(
                    path["name_or_path"], weight_name=path["weight_name"], adapter_name=lora_name
                )
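                # Apply the requested strength directly through peft's scaling dict.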
                for m in pipeline.transformer.modules():
                    if isinstance(m, lora.LoraLayer):
                        for name in m.scaling.keys():
                            m.scaling[name] = lora_weight
    else:
        raise NotImplementedError(f"Model {model_name} not implemented")
    pipeline = pipeline.to(device)

    return pipeline
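

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original demo): assumes a
    # CUDA GPU and network access to download the checkpoints referenced above.
    pipe = get_pipeline("schnell", "int4")
    image = pipe(
        "A cat holding a sign that says hello world",
        num_inference_steps=4,  # FLUX.1-schnell is distilled for few-step sampling
        guidance_scale=0.0,  # schnell is used without classifier-free guidance
        generator=torch.Generator("cuda").manual_seed(hash_str_to_int("demo")),
    ).images[0]
    image.save("demo.png")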