Unverified Commit 140b21e5 authored by Muyang Li, committed by GitHub

release: v0.3.0dev1 pre-release

parents 2eedc2cb f828be33
@@ -65,7 +65,11 @@ def save_image(img):
 def run(image, prompt: str, prompt_template: str, sketch_guidance: float, seed: int) -> tuple[Image, str]:
     print(f"Prompt: {prompt}")
-    image_numpy = np.array(image["composite"].convert("RGB"))
+    if image["composite"] is None:
+        image_numpy = np.array(blank_image.convert("RGB"))
+    else:
+        image_numpy = np.array(image["composite"].convert("RGB"))
     if prompt.strip() == "" and (np.sum(image_numpy == 255) >= 3145628 or np.sum(image_numpy == 0) >= 3145628):
         return blank_image, "Please input the prompt or draw something."
...
@@ -22,7 +22,7 @@ pipeline.transformer.forward = MethodType(pulid_forward, pipeline.transformer)
 id_image = load_image("https://github.com/ToTheBeginning/PuLID/blob/main/example_inputs/liuyifei.png?raw=true")
 image = pipeline(
-    "A woman holding a sign that says 'SVDQuant is fast!",
+    "A woman holding a sign that says 'SVDQuant is fast!'",
     id_image=id_image,
     id_weight=1,
     num_inference_steps=12,
...
@@ -390,19 +390,18 @@ class FluxCachedTransformerBlocks(nn.Module):
         original_dtype = hidden_states.dtype
         original_device = hidden_states.device
-        hidden_states = hidden_states.to(self.dtype).to(self.device)
-        encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(self.device)
-        temb = temb.to(self.dtype).to(self.device)
-        image_rotary_emb = image_rotary_emb.to(self.device)
+        hidden_states = hidden_states.to(self.dtype).to(original_device)
+        encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(original_device)
+        temb = temb.to(self.dtype).to(original_device)
+        image_rotary_emb = image_rotary_emb.to(original_device)
         if controlnet_block_samples is not None:
             controlnet_block_samples = (
-                torch.stack(controlnet_block_samples).to(self.device) if len(controlnet_block_samples) > 0 else None
+                torch.stack(controlnet_block_samples).to(original_device) if len(controlnet_block_samples) > 0 else None
             )
         if controlnet_single_block_samples is not None and len(controlnet_single_block_samples) > 0:
             controlnet_single_block_samples = (
-                torch.stack(controlnet_single_block_samples).to(self.device)
+                torch.stack(controlnet_single_block_samples).to(original_device)
                 if len(controlnet_single_block_samples) > 0
                 else None
             )
...
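A minimal sketch (not part of the commit) of the dtype/device round-trip pattern this hunk switches to: cast activations to the block's compute dtype while keeping them on their original device, then restore the caller's dtype afterwards. The names and the restore step below are illustrative assumptions, not the repo's code.

import torch

def run_block(hidden_states: torch.Tensor, compute_dtype=torch.bfloat16) -> torch.Tensor:
    original_dtype = hidden_states.dtype
    original_device = hidden_states.device
    x = hidden_states.to(compute_dtype).to(original_device)  # dtype cast only; device is unchanged
    x = x * 2.0  # stand-in for the transformer block computation
    return x.to(original_dtype)  # hand back the caller's dtype

out = run_block(torch.randn(2, 8, dtype=torch.float32))
assert out.dtype == torch.float32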
@@ -136,4 +136,4 @@ if __name__ == "__main__":
     parser.add_argument("-o", "--output-path", type=str, required=True, help="path to the output safetensors file")
     args = parser.parse_args()
     assert len(args.input_paths) == len(args.strengths)
-    composed = compose_lora(list(zip(args.input_paths, args.strengths)))
+    compose_lora(list(zip(args.input_paths, args.strengths)), args.output_path)
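A hedged usage sketch of calling compose_lora directly, mirroring the script's updated call (a list of (path, strength) pairs plus an output path). The import path and the LoRA file names below are assumptions, not taken from this diff.

from nunchaku.lora.flux.compose import compose_lora  # assumed module path

compose_lora(
    [("lora_a.safetensors", 0.8), ("lora_b.safetensors", 1.0)],  # hypothetical LoRA files and strengths
    "composed.safetensors",  # output file, matching the new output-path argument above
)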
@@ -117,6 +117,9 @@ def convert_to_nunchaku_transformer_block_lowrank_dict(  # noqa: C901
         if orig_lora[0] is None or orig_lora[1] is None:
             assert orig_lora[0] is None and orig_lora[1] is None
             orig_lora = None
+        elif orig_lora[0].numel() == 0 or orig_lora[1].numel() == 0:
+            assert orig_lora[0].numel() == 0 and orig_lora[1].numel() == 0
+            orig_lora = None
         else:
             assert orig_lora[0] is not None and orig_lora[1] is not None
             orig_lora = (
...
@@ -333,6 +333,17 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
             elif "qweight" in k:
                 # only the shape information of this tensor is needed
                 new_quantized_part_sd[k] = v.to("meta")
+                # if the tensor has qweight, but does not have low-rank branch, we need to add some artificial tensors
+                for t in ["lora_up", "lora_down"]:
+                    new_k = k.replace(".qweight", f".{t}")
+                    if new_k not in quantized_part_sd:
+                        oc, ic = v.shape
+                        ic = ic * 2  # v is packed into INT8, so we need to double the size
+                        new_quantized_part_sd[k.replace(".qweight", f".{t}")] = torch.zeros(
+                            (0, ic) if t == "lora_down" else (oc, 0), device=v.device, dtype=torch.bfloat16
+                        )
             elif "lora" in k:
                 new_quantized_part_sd[k] = v
         transformer._quantized_part_sd = new_quantized_part_sd
...
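An illustrative sketch (not part of the commit) of why the artificial tensors use shapes (0, ic) and (oc, 0): a rank-0 LoRA pair matches the layer's logical shape but contributes nothing to the weight. The concrete sizes below are hypothetical; the doubling mirrors the packed-INT8 comment above (two 4-bit values per int8 byte).

import torch

qweight = torch.empty((3072, 1536), dtype=torch.int8, device="meta")  # packed weight; only the shape matters
oc, packed_ic = qweight.shape
ic = packed_ic * 2  # unpacked input-channel count, as in the diff above

lora_down = torch.zeros((0, ic), dtype=torch.bfloat16)  # rank-0 "down" projection
lora_up = torch.zeros((oc, 0), dtype=torch.bfloat16)    # rank-0 "up" projection
delta = lora_up @ lora_down                             # (oc, ic) tensor of zeros
assert delta.shape == (oc, ic) and delta.abs().sum() == 0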
@@ -11,7 +11,7 @@ from .utils import run_test
     [
         (1024, 1024, "flashattn2", False, 0.126 if get_precision() == "int4" else 0.126),
         (1024, 1024, "nunchaku-fp16", False, 0.126 if get_precision() == "int4" else 0.126),
-        (1920, 1080, "nunchaku-fp16", False, 0.158 if get_precision() == "int4" else 0.138),
+        (1920, 1080, "nunchaku-fp16", False, 0.190 if get_precision() == "int4" else 0.138),
         (2048, 2048, "nunchaku-fp16", True, 0.166 if get_precision() == "int4" else 0.120),
     ],
 )
...
@@ -54,7 +54,7 @@ from .utils import already_generate, compute_lpips, offload_pipeline
             "waterfall",
             23,
             0.6,
-            0.226 if get_precision() == "int4" else 0.226,
+            0.253 if get_precision() == "int4" else 0.226,
         ),
     ],
 )
...
import os

import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision

from ..utils import compute_lpips


def test_lora_reset():
    precision = get_precision()  # auto-detects whether the precision is 'int4' or 'fp4' based on your GPU
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
        f"mit-han-lab/svdq-{precision}-flux.1-dev", offload=True
    )
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
    )
    pipeline.enable_sequential_cpu_offload()
    save_dir = os.path.join("test_results", "bf16", "flux", "lora_reset")
    os.makedirs(save_dir, exist_ok=True)
    image = pipeline(
        "cozy mountain cabin covered in snow, with smoke curling from the chimney and a warm, inviting light spilling through the windows",  # noqa: E501
        num_inference_steps=8,
        guidance_scale=3.5,
        generator=torch.Generator().manual_seed(23),
    ).images[0]
    image.save(os.path.join(save_dir, "before.png"))
    transformer.update_lora_params("alimama-creative/FLUX.1-Turbo-Alpha/diffusion_pytorch_model.safetensors")
    transformer.set_lora_strength(50)
    transformer.reset_lora()
    image = pipeline(
        "cozy mountain cabin covered in snow, with smoke curling from the chimney and a warm, inviting light spilling through the windows",  # noqa: E501
        num_inference_steps=8,
        guidance_scale=3.5,
        generator=torch.Generator().manual_seed(23),
    ).images[0]
    image.save(os.path.join(save_dir, "after.png"))
    lpips = compute_lpips(os.path.join(save_dir, "before.png"), os.path.join(save_dir, "after.png"))
    print(f"LPIPS: {lpips}")
    assert lpips < 0.158 * 1.1
@@ -11,7 +11,7 @@ from .utils import run_test
     "height,width,attention_impl,cpu_offload,expected_lpips,batch_size",
     [
         (1024, 1024, "nunchaku-fp16", False, 0.140 if get_precision() == "int4" else 0.135, 2),
-        (1920, 1080, "flashattn2", True, 0.160 if get_precision() == "int4" else 0.123, 4),
+        (1920, 1080, "flashattn2", True, 0.177 if get_precision() == "int4" else 0.123, 4),
     ],
 )
 def test_flux_schnell(
...
@@ -28,15 +28,32 @@ def already_generate(save_dir: str, num_images) -> bool:
 class MultiImageDataset(data.Dataset):
-    def __init__(self, gen_dirpath: str, ref_dirpath: str | datasets.Dataset):
+    def __init__(self, gen_dirpath_or_image_path: str, ref_dirpath_or_image_path: str | datasets.Dataset):
         super(data.Dataset, self).__init__()
-        self.gen_names = sorted(
-            [name for name in os.listdir(gen_dirpath) if name.endswith(".png") or name.endswith(".jpg")]
-        )
-        self.ref_names = sorted(
-            [name for name in os.listdir(ref_dirpath) if name.endswith(".png") or name.endswith(".jpg")]
-        )
-        self.gen_dirpath, self.ref_dirpath = gen_dirpath, ref_dirpath
+        if os.path.isdir(gen_dirpath_or_image_path):
+            self.gen_names = sorted(
+                [
+                    name
+                    for name in os.listdir(gen_dirpath_or_image_path)
+                    if name.endswith(".png") or name.endswith(".jpg")
+                ]
+            )
+            self.gen_dirpath = gen_dirpath_or_image_path
+        else:
+            self.gen_names = [os.path.basename(gen_dirpath_or_image_path)]
+            self.gen_dirpath = os.path.dirname(gen_dirpath_or_image_path)
+        if os.path.isdir(ref_dirpath_or_image_path):
+            self.ref_names = sorted(
+                [
+                    name
+                    for name in os.listdir(ref_dirpath_or_image_path)
+                    if name.endswith(".png") or name.endswith(".jpg")
+                ]
+            )
+            self.ref_dirpath = ref_dirpath_or_image_path
+        else:
+            self.ref_names = [os.path.basename(ref_dirpath_or_image_path)]
+            self.ref_dirpath = os.path.dirname(ref_dirpath_or_image_path)
         assert len(self.ref_names) == len(self.gen_names)
         self.transform = torchvision.transforms.ToTensor()
@@ -45,10 +62,8 @@ class MultiImageDataset(data.Dataset):
         return len(self.ref_names)

     def __getitem__(self, idx: int):
-        name = self.ref_names[idx]
-        assert name == self.gen_names[idx]
-        ref_image = Image.open(os.path.join(self.ref_dirpath, name)).convert("RGB")
-        gen_image = Image.open(os.path.join(self.gen_dirpath, name)).convert("RGB")
+        ref_image = Image.open(os.path.join(self.ref_dirpath, self.ref_names[idx])).convert("RGB")
+        gen_image = Image.open(os.path.join(self.gen_dirpath, self.gen_names[idx])).convert("RGB")
         gen_size = gen_image.size
         ref_size = ref_image.size
         if ref_size != gen_size:
...
@@ -59,16 +74,20 @@ class MultiImageDataset(data.Dataset):
 def compute_lpips(
-    ref_dirpath: str, gen_dirpath: str, batch_size: int = 4, num_workers: int = 0, device: str | torch.device = "cuda"
+    ref_dirpath_or_image_path: str,
+    gen_dirpath_or_image_path: str,
+    batch_size: int = 4,
+    num_workers: int = 0,
+    device: str | torch.device = "cuda",
 ) -> float:
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     metric = LearnedPerceptualImagePatchSimilarity(normalize=True).to(device)
-    dataset = MultiImageDataset(gen_dirpath, ref_dirpath)
+    dataset = MultiImageDataset(gen_dirpath_or_image_path, ref_dirpath_or_image_path)
     dataloader = data.DataLoader(
         dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, drop_last=False
     )
     with torch.no_grad():
-        desc = (os.path.basename(gen_dirpath)) + " LPIPS"
+        desc = (os.path.basename(gen_dirpath_or_image_path)) + " LPIPS"
         for i, batch in enumerate(tqdm(dataloader, desc=desc)):
             batch = [tensor.to(device) for tensor in batch]
             metric.update(batch[0], batch[1])
...
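A usage sketch of the updated helper: with this change, compute_lpips accepts either two directories of images or two single image files (the new test_lora_reset passes single files). The import path and the paths below are assumptions for illustration.

from .utils import compute_lpips  # assumed tests utils module, as imported elsewhere in the test suite

lpips_over_dirs = compute_lpips("ref_images", "gen_images", batch_size=4)        # two directories
lpips_single_pair = compute_lpips("results/before.png", "results/after.png")     # two single images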