test_flux_memory.py 1.56 KB
Newer Older
muyangli's avatar
muyangli committed
1
2
3
4
5
import pytest
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
6
from nunchaku.utils import get_precision, is_turing
muyangli's avatar
muyangli committed
7
8


9
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs")
@pytest.mark.parametrize(
    "use_qencoder,cpu_offload,memory_limit",
    [
        (False, False, 17),
        (False, True, 13),
        (True, False, 12),
        (True, True, 6),
    ],
)
def test_flux_schnell_memory(use_qencoder: bool, cpu_offload: bool, memory_limit: float):
    """Check that one FLUX.1-schnell generation stays under a peak GPU memory budget.

    Builds a ``FluxPipeline`` around the Nunchaku quantized transformer (and,
    optionally, the Nunchaku T5 text encoder), runs a single 1024x1024
    generation, and asserts that ``torch.cuda.max_memory_reserved(0)`` stays
    below ``memory_limit``.

    Args:
        use_qencoder: Replace the pipeline's ``text_encoder_2`` with
            ``NunchakuT5EncoderModel`` loaded from ``mit-han-lab/svdq-flux.1-t5``.
        cpu_offload: Load the transformer with ``offload=True`` and call
            ``enable_sequential_cpu_offload()`` instead of moving the whole
            pipeline to CUDA.
        memory_limit: Upper bound in GiB (bytes / 1024**3) for the peak
            reserved CUDA memory.
    """
    import gc

    # Start from a clean allocator state so the peak statistic reflects only this run.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    precision = get_precision()

    pipeline_init_kwargs = {
        "transformer": NunchakuFluxTransformer2dModel.from_pretrained(
            f"mit-han-lab/svdq-{precision}-flux.1-schnell", offload=cpu_offload
        )
    }
    if use_qencoder:
        pipeline_init_kwargs["text_encoder_2"] = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
    )
    # Drop the extra references to the modules so deleting the pipeline below
    # actually lets their memory be freed.
    del pipeline_init_kwargs

    try:
        if cpu_offload:
            pipeline.enable_sequential_cpu_offload()
        else:
            pipeline = pipeline.to("cuda")

        pipeline(
            "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=50, guidance_scale=0
        )
        memory = torch.cuda.max_memory_reserved(0) / 1024**3
    finally:
        # Release the GPU memory even if generation fails; a failing frame that is
        # kept alive (e.g. by the reported traceback) would otherwise pin the
        # pipeline and inflate the measurements of later parametrizations.
        del pipeline
        gc.collect()
        torch.cuda.empty_cache()

    # Asserted only after cleanup so a failure here does not leak the pipeline.
    assert memory < memory_limit, f"peak reserved memory {memory:.2f} GiB exceeds the {memory_limit} GiB limit"