Unverified Commit f060b8da authored by Muyang Li's avatar Muyang Li Committed by GitHub
Browse files

[Major] Release v0.1.4

Support 4-bit text encoder and per-layer CPU offloading, reducing FLUX's minimum memory requirement to just 4 GiB while maintaining a 2–3× speedup. Fix various issues related to resolution, LoRA, pin memory, and runtime stability. Check out the release notes for full details!
parents f549dfc6 873a35be
import torch
from diffusers import SanaPAGPipeline, SanaPipeline
from nunchaku import NunchakuSanaTransformer2DModel
def test_sana():
    """Smoke test: run the INT4 SVDQuant Sana 1600M pipeline end to end and save one image."""
    quantized = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
    pipeline = SanaPipeline.from_pretrained(
        "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
        transformer=quantized,
        variant="bf16",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    # Cast the VAE and text encoder to bf16 so every submodule shares the transformer's dtype.
    for submodule in (pipeline.vae, pipeline.text_encoder):
        submodule.to(torch.bfloat16)
    result = pipeline(
        prompt="A cute 🐼 eating 🎋, ink drawing style",
        height=1024,
        width=1024,
        guidance_scale=4.5,
        num_inference_steps=20,
        generator=torch.Generator().manual_seed(42),  # fixed seed keeps the output reproducible
    )
    result.images[0].save("sana_1600m.png")
def test_sana_pag():
    """Smoke test: run the INT4 Sana 1600M pipeline with perturbed-attention guidance and save one image."""
    quantized = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m", pag_layers=8)
    pag_pipeline = SanaPAGPipeline.from_pretrained(
        "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
        transformer=quantized,
        variant="bf16",
        torch_dtype=torch.bfloat16,
        pag_applied_layers="transformer_blocks.8",
    ).to("cuda")
    # NOTE(review): diffusers' PAG attention-processor swap is stubbed out here — presumably
    # because the quantized Nunchaku transformer handles PAG internally (pag_layers=8 above).
    pag_pipeline._set_pag_attn_processor = lambda *args, **kwargs: None
    # Cast the remaining submodules to bf16 to match the transformer's dtype.
    pag_pipeline.text_encoder.to(torch.bfloat16)
    pag_pipeline.vae.to(torch.bfloat16)
    output = pag_pipeline(
        prompt="A cute 🐼 eating 🎋, ink drawing style",
        height=1024,
        width=1024,
        guidance_scale=5.0,
        pag_scale=2.0,
        num_inference_steps=20,
    )
    output.images[0].save("sana_1600m_pag.png")
import os
import datasets
import torch
import torchvision
from PIL import Image
from torch.utils import data
from torchmetrics.image import LearnedPerceptualImagePatchSimilarity
from tqdm import tqdm
def hash_str_to_int(s: str) -> int:
    """Deterministically map *s* to an integer.

    Polynomial rolling hash: base 31, reduced modulo the large prime 1e9+7
    after every step, so the result stays bounded for arbitrarily long input.
    """
    prime_modulus = 10**9 + 7
    accumulator = 0
    for character in s:
        accumulator = (accumulator * 31 + ord(character)) % prime_modulus
    return accumulator
def already_generate(save_dir: str, num_images) -> bool:
    """Return True iff *save_dir* exists and contains exactly *num_images* .png files.

    Used to skip regeneration when a previous run already produced the full set.
    """
    if not os.path.exists(save_dir):
        return False
    png_count = sum(1 for entry in os.listdir(save_dir) if entry.endswith(".png"))
    return png_count == num_images
class MultiImageDataset(data.Dataset):
    """Paired image dataset yielding (generated, reference) tensors with matching filenames.

    Both directories must contain the same sorted set of .png/.jpg filenames.
    When a pair's sizes differ, the reference image is resized (bicubic) to the
    generated image's size so per-pair metrics such as LPIPS can be computed.
    """

    def __init__(self, gen_dirpath: str, ref_dirpath: str):
        # Fixed: was `super(data.Dataset, self).__init__()`, which starts the MRO
        # lookup *after* data.Dataset and therefore skips its initializer.
        super().__init__()
        # Fixed annotation: the original also advertised `datasets.Dataset` for
        # ref_dirpath, but the body unconditionally calls os.listdir(), so only a
        # directory path was ever supported.
        self.gen_dirpath = gen_dirpath
        self.ref_dirpath = ref_dirpath
        self.gen_names = self._list_images(gen_dirpath)
        self.ref_names = self._list_images(ref_dirpath)
        # Raise (not assert, which is stripped under -O) on mismatched folder contents.
        if len(self.ref_names) != len(self.gen_names):
            raise ValueError(
                f"image count mismatch: {len(self.gen_names)} generated vs {len(self.ref_names)} reference"
            )
        self.transform = torchvision.transforms.ToTensor()

    @staticmethod
    def _list_images(dirpath: str) -> list[str]:
        """Sorted .png/.jpg filenames directly inside *dirpath*."""
        return sorted(name for name in os.listdir(dirpath) if name.endswith((".png", ".jpg")))

    def __len__(self) -> int:
        return len(self.ref_names)

    def __getitem__(self, idx: int):
        name = self.ref_names[idx]
        # Pairing relies on identical filenames at the same sorted position.
        if name != self.gen_names[idx]:
            raise ValueError(f"filename mismatch at index {idx}: {self.gen_names[idx]!r} vs {name!r}")
        ref_image = Image.open(os.path.join(self.ref_dirpath, name)).convert("RGB")
        gen_image = Image.open(os.path.join(self.gen_dirpath, name)).convert("RGB")
        if ref_image.size != gen_image.size:
            # Align the reference to the generated resolution so tensors stack in a batch.
            ref_image = ref_image.resize(gen_image.size, Image.Resampling.BICUBIC)
        return [self.transform(gen_image), self.transform(ref_image)]
def compute_lpips(
    ref_dirpath: str, gen_dirpath: str, batch_size: int = 64, num_workers: int = 8, device: str | torch.device = "cuda"
) -> float:
    """Compute the mean LPIPS score between paired generated/reference image folders.

    Images are paired by filename via MultiImageDataset and streamed through the
    torchmetrics LPIPS metric in batches; the aggregated score is returned as a float.
    """
    # NOTE(review): presumably set to silence/avoid HF tokenizer fork issues in
    # DataLoader worker processes — confirm this is still needed.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    lpips_metric = LearnedPerceptualImagePatchSimilarity(normalize=True).to(device)
    paired_loader = data.DataLoader(
        MultiImageDataset(gen_dirpath, ref_dirpath),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        drop_last=False,
    )
    progress_label = os.path.basename(gen_dirpath) + " LPIPS"
    with torch.no_grad():
        for gen_batch, ref_batch in tqdm(paired_loader, desc=progress_label):
            lpips_metric.update(gen_batch.to(device), ref_batch.to(device))
    return lpips_metric.compute().item()
Subproject commit 0d23f715690c5171fd93679de8afd149376db167
Subproject commit a75b4ac483166189a45290783cb0a18af5ff0ea5
Subproject commit 63258397761b3dd96dd171e5a5ad5aa915834c35
Subproject commit 8b6b7d878c89e81614d05edca7936de41ccdd2da
Subproject commit 27cb4c76708608465c413f6d0e6b8d99a4d84302
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment