import torch
from diffusers import FluxPipeline
from peft.tuners import lora

from nunchaku.pipelines import flux as nunchaku_flux
from vars import LORA_PATHS, SVDQ_LORA_PATHS


def hash_str_to_int(s: str) -> int:
    """Deterministically hash a string to an integer with a polynomial rolling
    hash (unlike the built-in ``hash``, which is salted per process)."""
    modulus = 10**9 + 7  # large prime modulus keeps values bounded
    hash_int = 0
    for char in s:
        hash_int = (hash_int * 31 + ord(char)) % modulus
    return hash_int
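
# Illustrative usage (the prompt string is hypothetical): derive a seed that is
# stable across runs, which the salted built-in hash() cannot guarantee:
#
#     seed = hash_str_to_int("a photo of an astronaut riding a horse")
#     generator = torch.Generator("cuda").manual_seed(seed)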


def get_pipeline(
    model_name: str,
    precision: str,
    use_qencoder: bool = False,
    lora_name: str = "None",
    lora_weight: float = 1,
    device: str | torch.device = "cuda",
) -> FluxPipeline:
    """Construct a FLUX.1 pipeline for the given model variant ("schnell" or
    "dev"), precision ("int4" or "bf16"), and optional LoRA configuration."""
    if model_name == "schnell":
        if precision == "int4":
            assert torch.device(device).type == "cuda", "int4 only supported on CUDA devices"
            pipeline = nunchaku_flux.from_pretrained(
                "black-forest-labs/FLUX.1-schnell",
                torch_dtype=torch.bfloat16,
                qmodel_path="mit-han-lab/svdq-int4-flux.1-schnell",
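                # Optionally swap in the W4A16-quantized T5 text encoder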
                qencoder_path="mit-han-lab/svdquant-models/svdq-w4a16-t5.pt" if use_qencoder else None,
                qmodel_device=device,
            )
        else:
            assert precision == "bf16"
            pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
    elif model_name == "dev":
        if precision == "int4":
            assert torch.device(device).type == "cuda", "int4 only supported on CUDA devices"
            pipeline = nunchaku_flux.from_pretrained(
                "black-forest-labs/FLUX.1-dev",
                torch_dtype=torch.bfloat16,
                qmodel_path="mit-han-lab/svdq-int4-flux.1-dev",
                qencoder_path="mit-han-lab/svdquant-models/svdq-w4a16-t5.pt" if use_qencoder else None,
                qmodel_device=device,
            )
            if lora_name not in ["All", "None"]:
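                # Load the SVDQuant-converted LoRA weights and apply the requested strength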
                pipeline.transformer.nunchaku_update_params(SVDQ_LORA_PATHS[lora_name])
                pipeline.transformer.nunchaku_set_lora_scale(lora_weight)
        else:
            assert precision == "bf16"
            pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
            if lora_name == "All":
                # Pre-load all the LoRA weights for demo use
                for name, path in LORA_PATHS.items():
                    pipeline.load_lora_weights(path["name_or_path"], weight_name=path["weight_name"], adapter_name=name)
                for m in pipeline.transformer.modules():
                    if isinstance(m, lora.LoraLayer):
                        # Activate every adapter but zero its scale, so no LoRA
                        # takes effect until a nonzero weight is set later.
                        m.set_adapter(list(m.scaling.keys()))
                        for name in m.scaling.keys():
                            m.scaling[name] = 0
            elif lora_name != "None":
                path = LORA_PATHS[lora_name]
                pipeline.load_lora_weights(
                    path["name_or_path"], weight_name=path["weight_name"], adapter_name=lora_name
                )
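                # Apply the requested strength to the freshly loaded adapter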
                for m in pipeline.transformer.modules():
                    if isinstance(m, lora.LoraLayer):
                        for name in m.scaling.keys():
                            m.scaling[name] = lora_weight
    else:
        raise NotImplementedError(f"Model {model_name} not implemented")
    pipeline = pipeline.to(device)

    return pipeline
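
# Illustrative usage (the prompt and sampling settings below are hypothetical,
# not the demo's defaults):
#
#     pipe = get_pipeline("schnell", "int4", use_qencoder=False, device="cuda")
#     image = pipe("a cat holding a sign", num_inference_steps=4, guidance_scale=0.0).images[0]
#     image.save("out.png")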