Commit 698bc83a authored by muyangli

update test_flux_tools

parent 84df0933

@@ -5,11 +5,68 @@ from nunchaku.utils import get_precision, is_turing
from .utils import run_test

@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_canny_dev():
    run_test(
        precision=get_precision(),
        model_name="flux.1-canny-dev",
        dataset_name="MJHQ-control",
        task="canny",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.076 if get_precision() == "int4" else 0.164,
    )
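
# run_test is imported from the local .utils module; its body is not part of
# this diff. As a hedged sketch only: a harness with this call signature would
# plausibly load the quantized pipeline, generate over the control dataset,
# and assert an LPIPS score. Every helper name below is hypothetical, not
# nunchaku's actual code:
#
# def run_test(precision, model_name, dataset_name, task, dtype, height, width,
#              num_inference_steps, guidance_scale, attention_impl, cpu_offload,
#              cache_threshold, expected_lpips, lora_names=None,
#              lora_strengths=None, max_dataset_size=None):
#     pipe = load_quantized_pipeline(  # hypothetical loader
#         model_name, precision=precision, dtype=dtype,
#         attention_impl=attention_impl, cpu_offload=cpu_offload,
#         cache_threshold=cache_threshold,
#     )
#     if lora_names is not None:
#         apply_loras(pipe, lora_names, lora_strengths)  # hypothetical helper
#     score = evaluate_lpips(  # hypothetical metric loop over the dataset
#         pipe, dataset_name, task, height=height, width=width,
#         num_inference_steps=num_inference_steps, guidance_scale=guidance_scale,
#         max_dataset_size=max_dataset_size,
#     )
#     assert score <= expected_lpips * 1.10  # tolerance margin is an assumption
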
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_depth_dev():
    run_test(
        precision=get_precision(),
        model_name="flux.1-depth-dev",
        dataset_name="MJHQ-control",
        task="depth",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=10,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.137 if get_precision() == "int4" else 0.120,
    )
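
# Note: guidance_scale is task-dependent in these tests: the depth test runs
# at 10, while the canny and fill tests run at 30.
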
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_fill_dev():
    run_test(
        precision=get_precision(),
        model_name="flux.1-fill-dev",
        dataset_name="MJHQ-control",
        task="fill",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=0.046,
    )
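
# With the three tool tests enabled, they can be selected locally via pytest's
# -k expression filter (the file path below is an assumption about this repo's
# layout):
#
#   pytest tests/flux/test_flux_tools.py -k "canny_dev or depth_dev or fill_dev"
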
# @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
# def test_flux_canny_dev():
# def test_flux_dev_canny_lora():
#     run_test(
#         precision=get_precision(),
#         model_name="flux.1-canny-dev",
#         model_name="flux.1-dev",
#         dataset_name="MJHQ-control",
#         task="canny",
#         dtype=torch.bfloat16,
@@ -19,112 +76,55 @@ from .utils import run_test
#         guidance_scale=30,
#         attention_impl="nunchaku-fp16",
#         cpu_offload=False,
#         cache_threshold=0,
#         expected_lpips=0.076 if get_precision() == "int4" else 0.164,
#     )
#
#
# @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
# def test_flux_depth_dev():
#     run_test(
#         precision=get_precision(),
#         model_name="flux.1-depth-dev",
#         dataset_name="MJHQ-control",
#         task="depth",
#         dtype=torch.bfloat16,
#         height=1024,
#         width=1024,
#         num_inference_steps=30,
#         guidance_scale=10,
#         attention_impl="nunchaku-fp16",
#         cpu_offload=False,
#         cache_threshold=0,
#         expected_lpips=0.137 if get_precision() == "int4" else 0.120,
#     )
#
#
# @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
# def test_flux_fill_dev():
#     run_test(
#         precision=get_precision(),
#         model_name="flux.1-fill-dev",
#         dataset_name="MJHQ-control",
#         task="fill",
#         dtype=torch.bfloat16,
#         height=1024,
#         width=1024,
#         num_inference_steps=30,
#         guidance_scale=30,
#         attention_impl="nunchaku-fp16",
#         cpu_offload=False,
#         cache_threshold=0,
#         expected_lpips=0.046,
#     )
#
#
# # @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
# # def test_flux_dev_canny_lora():
# #     run_test(
# #         precision=get_precision(),
# #         model_name="flux.1-dev",
# #         dataset_name="MJHQ-control",
# #         task="canny",
# #         dtype=torch.bfloat16,
# #         height=1024,
# #         width=1024,
# #         num_inference_steps=30,
# #         guidance_scale=30,
# #         attention_impl="nunchaku-fp16",
# #         cpu_offload=False,
# #         lora_names="canny",
# #         lora_strengths=0.85,
# #         cache_threshold=0,
# #         expected_lpips=0.081,
# #     )
#
#
# @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
# def test_flux_dev_depth_lora():
#     run_test(
#         precision=get_precision(),
#         model_name="flux.1-dev",
#         dataset_name="MJHQ-control",
#         task="depth",
#         dtype=torch.bfloat16,
#         height=1024,
#         width=1024,
#         num_inference_steps=30,
#         guidance_scale=10,
#         attention_impl="nunchaku-fp16",
#         cpu_offload=False,
#         cache_threshold=0,
#         lora_names="depth",
#         lora_names="canny",
#         lora_strengths=0.85,
#         expected_lpips=0.181,
#     )
#
#
# @pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
# def test_flux_fill_dev_turbo():
#     run_test(
#         precision=get_precision(),
#         model_name="flux.1-fill-dev",
#         dataset_name="MJHQ-control",
#         task="fill",
#         dtype=torch.bfloat16,
#         height=1024,
#         width=1024,
#         num_inference_steps=8,
#         guidance_scale=30,
#         attention_impl="nunchaku-fp16",
#         cpu_offload=False,
#         cache_threshold=0,
#         lora_names="turbo8",
#         lora_strengths=1,
#         expected_lpips=0.036,
#         expected_lpips=0.081,
#     )
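
# The test_flux_dev_canny_lora variant above is left disabled by this commit;
# its depth-LoRA and turbo-fill counterparts below are active.
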
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_depth_lora():
    run_test(
        precision=get_precision(),
        model_name="flux.1-dev",
        dataset_name="MJHQ-control",
        task="depth",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=10,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        lora_names="depth",
        lora_strengths=0.85,
        expected_lpips=0.181,
    )

@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_fill_dev_turbo():
    run_test(
        precision=get_precision(),
        model_name="flux.1-fill-dev",
        dataset_name="MJHQ-control",
        task="fill",
        dtype=torch.bfloat16,
        height=1024,
        width=1024,
        num_inference_steps=8,
        guidance_scale=30,
        attention_impl="nunchaku-fp16",
        cpu_offload=False,
        cache_threshold=0,
        lora_names="turbo8",
        lora_strengths=1,
        expected_lpips=0.036,
    )
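
# The turbo8 LoRA at full strength runs with num_inference_steps=8 instead of
# the 30 used by the plain fill test, presumably a step-distilled LoRA that
# holds quality (expected_lpips=0.036) at far fewer steps.
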
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_redux():
    run_test(
@@ -141,5 +141,4 @@ def test_flux_dev_redux():
        cpu_offload=False,
        cache_threshold=0,
        expected_lpips=(0.162 if get_precision() == "int4" else 0.5),  # not sure why the fp4 model is so different
        max_dataset_size=16,
    )