"lm_eval/evaluator_judge.py" did not exist on "5ccd65d4173a48bfdbc52603f82f7bf636973eb7"
utils.py 1.37 KB
Newer Older
muyangli's avatar
muyangli committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
from diffusers import SanaPAGPipeline

from nunchaku.models.transformer_sana import NunchakuSanaTransformer2DModel


def hash_str_to_int(s: str) -> int:
    """Hash a string to an integer."""
    modulus = 10**9 + 7  # Large prime modulus
    hash_int = 0
    for char in s:
        hash_int = (hash_int * 31 + ord(char)) % modulus
    return hash_int


def get_pipeline(
    precision: str, use_qencoder: bool = False, device: str | torch.device = "cuda", pipeline_init_kwargs: dict = {}
) -> SanaPAGPipeline:
    if precision == "int4":
        assert torch.device(device).type == "cuda", "int4 only supported on CUDA devices"
        transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m", pag_layers=8)

        pipeline_init_kwargs["transformer"] = transformer
        if use_qencoder:
            raise NotImplementedError("Quantized encoder not supported for Sana for now")
    else:
        assert precision == "bf16"
    pipeline = SanaPAGPipeline.from_pretrained(
        "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
        variant="bf16",
        torch_dtype=torch.bfloat16,
        pag_applied_layers="transformer_blocks.8",
        **pipeline_init_kwargs
    )
    if precision == "int4":
        pipeline._set_pag_attn_processor = lambda *args, **kwargs: None

    pipeline = pipeline.to(device)
    return pipeline