Commit 2ede5f01 authored by Muyang Li's avatar Muyang Li Committed by Zhekai Zhang
Browse files

Clean some codes and refract the tests

parent 83b7542d
# additional requirements for testing
pytest pytest
datasets datasets
torchmetrics torchmetrics
......
import pytest
import torch import torch
from diffusers import SanaPAGPipeline, SanaPipeline from diffusers import SanaPAGPipeline, SanaPipeline
from nunchaku import NunchakuSanaTransformer2DModel from nunchaku import NunchakuSanaTransformer2DModel
from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing() or get_precision() == "fp4", reason="Skip tests due to Turing GPUs")
def test_sana(): def test_sana():
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m") transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
pipe = SanaPipeline.from_pretrained( pipe = SanaPipeline.from_pretrained(
...@@ -28,6 +31,7 @@ def test_sana(): ...@@ -28,6 +31,7 @@ def test_sana():
image.save("sana_1600m.png") image.save("sana_1600m.png")
@pytest.mark.skipif(is_turing() or get_precision() == "fp4", reason="Skip tests due to Turing GPUs")
def test_sana_pag(): def test_sana_pag():
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m", pag_layers=8) transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m", pag_layers=8)
pipe = SanaPAGPipeline.from_pretrained( pipe = SanaPAGPipeline.from_pretrained(
......
...@@ -59,7 +59,7 @@ class MultiImageDataset(data.Dataset): ...@@ -59,7 +59,7 @@ class MultiImageDataset(data.Dataset):
def compute_lpips( def compute_lpips(
ref_dirpath: str, gen_dirpath: str, batch_size: int = 64, num_workers: int = 8, device: str | torch.device = "cuda" ref_dirpath: str, gen_dirpath: str, batch_size: int = 4, num_workers: int = 8, device: str | torch.device = "cuda"
) -> float: ) -> float:
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
metric = LearnedPerceptualImagePatchSimilarity(normalize=True).to(device) metric = LearnedPerceptualImagePatchSimilarity(normalize=True).to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment