Commit 30ba84c5 authored by muyangli's avatar muyangli
Browse files

update tests

parent 748be0ab
import pytest
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(
is_turing() or torch.cuda.device_count() <= 1, reason="Skip tests due to using Turing GPUs or single GPU"
)
def test_device_id():
precision = get_precision() # auto-detect your precision is 'int4' or 'fp4' based on your GPU
torch_dtype = torch.float16 if is_turing("cuda:1") else torch.float32
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/svdq-{precision}-flux.1-schnell", torch_dtype=torch_dtype, device="cuda:1"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch_dtype
).to("cuda:1")
pipeline(
"A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
)
...@@ -4,7 +4,7 @@ from nunchaku.utils import get_precision, is_turing ...@@ -4,7 +4,7 @@ from nunchaku.utils import get_precision, is_turing
from .utils import run_test from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests for Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips", "cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
[ [
......
...@@ -4,7 +4,7 @@ from nunchaku.utils import get_precision, is_turing ...@@ -4,7 +4,7 @@ from nunchaku.utils import get_precision, is_turing
from .utils import run_test from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests for Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips", "height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips",
[ [
......
...@@ -4,7 +4,7 @@ from nunchaku.utils import get_precision, is_turing ...@@ -4,7 +4,7 @@ from nunchaku.utils import get_precision, is_turing
from .utils import run_test from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"num_inference_steps,lora_name,lora_strength,cpu_offload,expected_lpips", "num_inference_steps,lora_name,lora_strength,cpu_offload,expected_lpips",
[ [
...@@ -35,7 +35,7 @@ def test_flux_dev_loras(num_inference_steps, lora_name, lora_strength, cpu_offlo ...@@ -35,7 +35,7 @@ def test_flux_dev_loras(num_inference_steps, lora_name, lora_strength, cpu_offlo
) )
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_hypersd8_1536x2048(): def test_flux_dev_hypersd8_1536x2048():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
...@@ -55,7 +55,7 @@ def test_flux_dev_hypersd8_1536x2048(): ...@@ -55,7 +55,7 @@ def test_flux_dev_hypersd8_1536x2048():
) )
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_turbo8_1024x1920(): def test_flux_dev_turbo8_1024x1920():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
...@@ -76,7 +76,7 @@ def test_flux_dev_turbo8_1024x1920(): ...@@ -76,7 +76,7 @@ def test_flux_dev_turbo8_1024x1920():
# lora composition # lora composition
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_turbo8_yarn_2048x1024(): def test_flux_dev_turbo8_yarn_2048x1024():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
...@@ -96,7 +96,7 @@ def test_flux_dev_turbo8_yarn_2048x1024(): ...@@ -96,7 +96,7 @@ def test_flux_dev_turbo8_yarn_2048x1024():
# large rank loras # large rank loras
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_turbo8_yarn_1024x1024(): def test_flux_dev_turbo8_yarn_1024x1024():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
......
...@@ -6,7 +6,7 @@ from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel ...@@ -6,7 +6,7 @@ from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from nunchaku.utils import get_precision, is_turing from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_qencoder,cpu_offload,memory_limit", "use_qencoder,cpu_offload,memory_limit",
[ [
......
...@@ -4,7 +4,7 @@ from nunchaku.utils import get_precision, is_turing ...@@ -4,7 +4,7 @@ from nunchaku.utils import get_precision, is_turing
from .utils import run_test from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"height,width,attention_impl,cpu_offload,expected_lpips", "height,width,attention_impl,cpu_offload,expected_lpips",
[ [
......
...@@ -5,7 +5,7 @@ from nunchaku.utils import get_precision, is_turing ...@@ -5,7 +5,7 @@ from nunchaku.utils import get_precision, is_turing
from .utils import run_test from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_canny_dev(): def test_flux_canny_dev():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
...@@ -24,7 +24,7 @@ def test_flux_canny_dev(): ...@@ -24,7 +24,7 @@ def test_flux_canny_dev():
) )
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_depth_dev(): def test_flux_depth_dev():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
...@@ -43,7 +43,7 @@ def test_flux_depth_dev(): ...@@ -43,7 +43,7 @@ def test_flux_depth_dev():
) )
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_fill_dev(): def test_flux_fill_dev():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
...@@ -62,7 +62,7 @@ def test_flux_fill_dev(): ...@@ -62,7 +62,7 @@ def test_flux_fill_dev():
) )
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_canny_lora(): def test_flux_dev_canny_lora():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
...@@ -83,7 +83,7 @@ def test_flux_dev_canny_lora(): ...@@ -83,7 +83,7 @@ def test_flux_dev_canny_lora():
) )
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_depth_lora(): def test_flux_dev_depth_lora():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
...@@ -104,7 +104,7 @@ def test_flux_dev_depth_lora(): ...@@ -104,7 +104,7 @@ def test_flux_dev_depth_lora():
) )
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_fill_dev_turbo(): def test_flux_fill_dev_turbo():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
...@@ -125,7 +125,7 @@ def test_flux_fill_dev_turbo(): ...@@ -125,7 +125,7 @@ def test_flux_fill_dev_turbo():
) )
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_redux(): def test_flux_dev_redux():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
......
...@@ -4,7 +4,7 @@ from .utils import run_test ...@@ -4,7 +4,7 @@ from .utils import run_test
from nunchaku.utils import get_precision, is_turing from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"height,width,attention_impl,cpu_offload,expected_lpips", [(1024, 1024, "nunchaku-fp16", False, 0.25)] "height,width,attention_impl,cpu_offload,expected_lpips", [(1024, 1024, "nunchaku-fp16", False, 0.25)]
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment