"vscode:/vscode.git/clone" did not exist on "61e4433cafd6115a22cf2143f7f38b157ab39b53"
Commit 30ba84c5 authored by muyangli's avatar muyangli
Browse files

update tests

parent 748be0ab
import pytest
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(
is_turing() or torch.cuda.device_count() <= 1, reason="Skip tests due to using Turing GPUs or single GPU"
)
def test_device_id():
precision = get_precision() # auto-detect your precision is 'int4' or 'fp4' based on your GPU
torch_dtype = torch.float16 if is_turing("cuda:1") else torch.float32
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/svdq-{precision}-flux.1-schnell", torch_dtype=torch_dtype, device="cuda:1"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch_dtype
).to("cuda:1")
pipeline(
"A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
)
...@@ -4,7 +4,7 @@ from nunchaku.utils import get_precision, is_turing ...@@ -4,7 +4,7 @@ from nunchaku.utils import get_precision, is_turing
from .utils import run_test from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests for Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips", "cache_threshold,height,width,num_inference_steps,lora_name,lora_strength,expected_lpips",
[ [
......
...@@ -4,7 +4,7 @@ from nunchaku.utils import get_precision, is_turing ...@@ -4,7 +4,7 @@ from nunchaku.utils import get_precision, is_turing
from .utils import run_test from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests for Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips", "height,width,num_inference_steps,attention_impl,cpu_offload,expected_lpips",
[ [
......
...@@ -4,7 +4,7 @@ from nunchaku.utils import get_precision, is_turing ...@@ -4,7 +4,7 @@ from nunchaku.utils import get_precision, is_turing
from .utils import run_test from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"num_inference_steps,lora_name,lora_strength,cpu_offload,expected_lpips", "num_inference_steps,lora_name,lora_strength,cpu_offload,expected_lpips",
[ [
...@@ -35,7 +35,7 @@ def test_flux_dev_loras(num_inference_steps, lora_name, lora_strength, cpu_offlo ...@@ -35,7 +35,7 @@ def test_flux_dev_loras(num_inference_steps, lora_name, lora_strength, cpu_offlo
) )
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_hypersd8_1536x2048(): def test_flux_dev_hypersd8_1536x2048():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
...@@ -55,7 +55,7 @@ def test_flux_dev_hypersd8_1536x2048(): ...@@ -55,7 +55,7 @@ def test_flux_dev_hypersd8_1536x2048():
) )
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_turbo8_1024x1920(): def test_flux_dev_turbo8_1024x1920():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
...@@ -76,7 +76,7 @@ def test_flux_dev_turbo8_1024x1920(): ...@@ -76,7 +76,7 @@ def test_flux_dev_turbo8_1024x1920():
# lora composition # lora composition
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_turbo8_yarn_2048x1024(): def test_flux_dev_turbo8_yarn_2048x1024():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
...@@ -96,7 +96,7 @@ def test_flux_dev_turbo8_yarn_2048x1024(): ...@@ -96,7 +96,7 @@ def test_flux_dev_turbo8_yarn_2048x1024():
# large rank loras # large rank loras
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_turbo8_yarn_1024x1024(): def test_flux_dev_turbo8_yarn_1024x1024():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
......
...@@ -6,7 +6,7 @@ from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel ...@@ -6,7 +6,7 @@ from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from nunchaku.utils import get_precision, is_turing from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_qencoder,cpu_offload,memory_limit", "use_qencoder,cpu_offload,memory_limit",
[ [
......
...@@ -4,7 +4,7 @@ from nunchaku.utils import get_precision, is_turing ...@@ -4,7 +4,7 @@ from nunchaku.utils import get_precision, is_turing
from .utils import run_test from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"height,width,attention_impl,cpu_offload,expected_lpips", "height,width,attention_impl,cpu_offload,expected_lpips",
[ [
......
...@@ -5,7 +5,7 @@ from nunchaku.utils import get_precision, is_turing ...@@ -5,7 +5,7 @@ from nunchaku.utils import get_precision, is_turing
from .utils import run_test from .utils import run_test
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_canny_dev(): def test_flux_canny_dev():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
...@@ -24,7 +24,7 @@ def test_flux_canny_dev(): ...@@ -24,7 +24,7 @@ def test_flux_canny_dev():
) )
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_depth_dev(): def test_flux_depth_dev():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
...@@ -43,7 +43,7 @@ def test_flux_depth_dev(): ...@@ -43,7 +43,7 @@ def test_flux_depth_dev():
) )
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_fill_dev(): def test_flux_fill_dev():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
...@@ -62,7 +62,7 @@ def test_flux_fill_dev(): ...@@ -62,7 +62,7 @@ def test_flux_fill_dev():
) )
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_canny_lora(): def test_flux_dev_canny_lora():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
...@@ -83,7 +83,7 @@ def test_flux_dev_canny_lora(): ...@@ -83,7 +83,7 @@ def test_flux_dev_canny_lora():
) )
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_depth_lora(): def test_flux_dev_depth_lora():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
...@@ -104,7 +104,7 @@ def test_flux_dev_depth_lora(): ...@@ -104,7 +104,7 @@ def test_flux_dev_depth_lora():
) )
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_fill_dev_turbo(): def test_flux_fill_dev_turbo():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
...@@ -125,7 +125,7 @@ def test_flux_fill_dev_turbo(): ...@@ -125,7 +125,7 @@ def test_flux_fill_dev_turbo():
) )
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_redux(): def test_flux_dev_redux():
run_test( run_test(
precision=get_precision(), precision=get_precision(),
......
...@@ -4,7 +4,7 @@ from .utils import run_test ...@@ -4,7 +4,7 @@ from .utils import run_test
from nunchaku.utils import get_precision, is_turing from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing(), reason="Skip tests due to Turing GPUs") @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"height,width,attention_impl,cpu_offload,expected_lpips", [(1024, 1024, "nunchaku-fp16", False, 0.25)] "height,width,attention_impl,cpu_offload,expected_lpips", [(1024, 1024, "nunchaku-fp16", False, 0.25)]
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment