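"""Peak GPU memory regression tests for the Nunchaku 4-bit FLUX.1-schnell pipeline."""
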
import pytest
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from nunchaku.utils import get_precision, is_turing


@pytest.mark.skipif(is_turing(), reason="Skipped on Turing GPUs")
@pytest.mark.parametrize(
    "use_qencoder,cpu_offload,memory_limit",
    [
        (False, False, 17),
        (False, True, 13),
        (True, False, 12),
        (True, True, 6),
    ],
)
def test_flux_schnell_memory(use_qencoder: bool, cpu_offload: bool, memory_limit: float):
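    """One 4-step generation must keep peak reserved GPU memory below `memory_limit` GiB."""
    # Start from a clean slate so the peak-memory statistics reflect only this test.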
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
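    # get_precision() picks the quantization format (e.g. "int4" or "fp4") for the current GPU.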
    precision = get_precision()
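    # Load the 4-bit SVDQuant transformer, optionally with CPU offloading enabled.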
    pipeline_init_kwargs = {
        "transformer": NunchakuFluxTransformer2dModel.from_pretrained(
            f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors", offload=cpu_offload
        )
    }
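    # Optionally swap in the AWQ INT4-quantized T5 encoder to shrink text-encoder memory.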
    if use_qencoder:
        text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
            "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
        )
        pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
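    # Assemble the standard diffusers pipeline around the quantized components.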
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
    )

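    # Sequential CPU offload moves weights to the GPU one module at a time during inference.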
    if cpu_offload:
        pipeline.enable_sequential_cpu_offload()
    else:
        pipeline = pipeline.to("cuda")

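    # A single 1024x1024 generation; the schnell checkpoint runs in 4 steps without guidance.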
    pipeline(
        "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
    )
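    # Peak reserved memory, in GiB, must stay under the limit for this configuration.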
    memory = torch.cuda.max_memory_reserved(0) / 1024**3
    assert memory < memory_limit
    del pipeline
    # Release the cached GPU memory before the next parametrized case runs.
    torch.cuda.empty_cache()