Commit 45e055ce authored by Muyang Li, committed by muyangli

Merge pull request #86 from mit-han-lab/dev/muyang

[minor] fix the pix2pix-turbo demo
parent 48b2dc3c
 from typing import Any, Callable
 import torch
-import torchvision.utils
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, FluxPipelineOutput, FluxTransformer2DModel
 from einops import rearrange
 from peft.tuners import lora
@@ -9,6 +8,8 @@ from PIL import Image
 from torch import nn
 from torchvision.transforms import functional as F
+from nunchaku.utils import load_state_dict_in_safetensors

 class FluxPix2pixTurboPipeline(FluxPipeline):
     def update_alpha(self, alpha: float) -> None:
@@ -55,7 +56,9 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
             self.load_lora_into_transformer(state_dict, {}, transformer=transformer)
         else:
             assert svdq_lora_path is not None
-            self.transformer.update_lora_params(svdq_lora_path)
+            sd = load_state_dict_in_safetensors(svdq_lora_path)
+            sd = {k: v for k, v in sd.items() if not k.startswith("transformer.")}
+            self.transformer.update_lora_params(sd)
             self.update_alpha(alpha)

     @torch.no_grad()
...
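Note on the change: instead of passing a file path to update_lora_params, the patched pipeline now loads the SVDQuant LoRA weights from the safetensors file itself and drops any keys namespaced under "transformer." before handing the flat state dict to the quantized transformer. A minimal sketch of that loading step, assuming the safetensors package is installed; the helper below is a stand-in for nunchaku.utils.load_state_dict_in_safetensors (not its actual implementation), and the checkpoint path is hypothetical:

import torch
from safetensors.torch import load_file

def load_state_dict_in_safetensors(path: str) -> dict[str, torch.Tensor]:
    # Read every tensor in the .safetensors checkpoint onto CPU.
    return load_file(path, device="cpu")

# Mirror the diff: keep only keys that are NOT under the "transformer."
# namespace, then pass the filtered dict to update_lora_params().
sd = load_state_dict_in_safetensors("svdq_pix2pix_turbo.safetensors")  # hypothetical path
sd = {k: v for k, v in sd.items() if not k.startswith("transformer.")}
# pipeline.transformer.update_lora_params(sd)  # as in the patched pipeline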