test_device_id.py
import pytest
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision, is_turing


@pytest.mark.skipif(
    is_turing() or torch.cuda.device_count() <= 1, reason="Skip tests due to using Turing GPUs or single GPU"
)
def test_device_id():
    precision = get_precision()  # auto-detect whether to use 'int4' or 'fp4' precision based on your GPU
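    # Turing GPUs lack native bfloat16 support, so fall back to float16 on them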
    torch_dtype = torch.float16 if is_turing("cuda:1") else torch.bfloat16
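    # load the quantized transformer onto the second GPU to exercise non-default device placement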
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
        f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors",
        torch_dtype=torch_dtype,
        device="cuda:1",
    )
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch_dtype
    ).to("cuda:1")
    pipeline(
        "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
    )