Commit 2ede5f01 authored by Muyang Li's avatar Muyang Li Committed by Zhekai Zhang
Browse files

Clean some codes and refract the tests

parent 83b7542d
# additional requirements for testing
pytest
datasets
torchmetrics
......
import pytest
import torch
from diffusers import SanaPAGPipeline, SanaPipeline
from nunchaku import NunchakuSanaTransformer2DModel
from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing() or get_precision() == "fp4", reason="Skip tests due to Turing GPUs")
def test_sana():
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
pipe = SanaPipeline.from_pretrained(
......@@ -28,6 +31,7 @@ def test_sana():
image.save("sana_1600m.png")
@pytest.mark.skipif(is_turing() or get_precision() == "fp4", reason="Skip tests due to Turing GPUs")
def test_sana_pag():
transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m", pag_layers=8)
pipe = SanaPAGPipeline.from_pretrained(
......
......@@ -59,7 +59,7 @@ class MultiImageDataset(data.Dataset):
def compute_lpips(
ref_dirpath: str, gen_dirpath: str, batch_size: int = 64, num_workers: int = 8, device: str | torch.device = "cuda"
ref_dirpath: str, gen_dirpath: str, batch_size: int = 4, num_workers: int = 8, device: str | torch.device = "cuda"
) -> float:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
metric = LearnedPerceptualImagePatchSimilarity(normalize=True).to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment