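"""LPIPS regression test for the 4-bit quantized shuttle-jaguar FLUX model.

For each parameter set, this generates a bf16 reference batch and a 4-bit
quantized batch over MJHQ prompts, then asserts that the LPIPS distance
between the two stays within the expected bound.
"""
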
import os

import pytest
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from tests.data import get_dataset
from tests.flux.utils import run_pipeline
from tests.utils import already_generate, compute_lpips


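# Each case: (precision, height, width, steps, guidance scale, quantized text
# encoder, CPU offload, number of MJHQ prompts, LPIPS bound vs. the bf16 reference).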
@pytest.mark.parametrize(
    "precision,height,width,num_inference_steps,guidance_scale,use_qencoder,cpu_offload,max_dataset_size,expected_lpips",
    [
        ("int4", 1024, 1024, 4, 3.5, False, False, 16, 0.25),
        ("int4", 2048, 512, 4, 3.5, False, False, 16, 0.21),
    ],
)
def test_shuttle_jaguar(
    precision: str,
    height: int,
    width: int,
    num_inference_steps: int,
    guidance_scale: float,
    use_qencoder: bool,
    cpu_offload: bool,
    max_dataset_size: int,
    expected_lpips: float,
):
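    """End-to-end LPIPS comparison between bf16 and 4-bit shuttle-jaguar outputs."""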
    dataset = get_dataset(name="MJHQ", max_dataset_size=max_dataset_size)
    save_root = os.path.join("results", "shuttle-jaguar", f"w{width}h{height}t{num_inference_steps}g{guidance_scale}")
    forward_kwargs = {
        "height": height,
        "width": width,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
    }

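    # Stage 1: generate the bf16 reference images (skipped if already on disk).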
    save_dir_16bit = os.path.join(save_root, "bf16")
    if not already_generate(save_dir_16bit, max_dataset_size):
        pipeline = FluxPipeline.from_pretrained("shuttleai/shuttle-jaguar", torch_dtype=torch.bfloat16)
        pipeline = pipeline.to("cuda")

        run_pipeline(dataset, pipeline, save_dir=save_dir_16bit, forward_kwargs=forward_kwargs)
        del pipeline
        # release the gpu memory
        torch.cuda.empty_cache()

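    # Stage 2: generate images with the 4-bit quantized Nunchaku transformer.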
    save_dir_4bit = os.path.join(
        save_root,
        precision + ("-qencoder" if use_qencoder else "") + ("-cpuoffload" if cpu_offload else ""),
    )
    if not already_generate(save_dir_4bit, max_dataset_size):
        pipeline_init_kwargs = {}
        if precision == "int4":
            transformer = NunchakuFluxTransformer2dModel.from_pretrained(
                "mit-han-lab/svdq-int4-shuttle-jaguar", offload=cpu_offload
            )
        else:
            assert precision == "fp4"
            transformer = NunchakuFluxTransformer2dModel.from_pretrained(
                "mit-han-lab/svdq-fp4-shuttle-jaguar", precision="fp4", offload=cpu_offload
            )
        pipeline_init_kwargs["transformer"] = transformer
        if use_qencoder:
            # A quantized T5 text encoder is not wired up for this model yet.
            raise NotImplementedError("quantized text encoder is not yet supported for shuttle-jaguar")
            # text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
            # pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
        pipeline = FluxPipeline.from_pretrained(
            "shuttleai/shuttle-jaguar", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
        )
        pipeline = pipeline.to("cuda")
        if cpu_offload:
            pipeline.enable_sequential_cpu_offload()
        run_pipeline(dataset, pipeline, save_dir=save_dir_4bit, forward_kwargs=forward_kwargs)
        del pipeline
        # release the gpu memory
        torch.cuda.empty_cache()
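    # Stage 3: the quantized outputs must stay perceptually close to the bf16
    # references; the expected bound gets a 5% margin.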
    lpips = compute_lpips(save_dir_16bit, save_dir_4bit)
    print(f"lpips: {lpips}")
    assert lpips < expected_lpips * 1.05
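

# Example invocation (file path assumed from the `tests.flux.utils` import
# above; a CUDA GPU is required, since both pipelines are moved to "cuda"):
#   pytest tests/flux/test_shuttle_jaguar.py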