Commit 6b329882 authored by muyangli

[major] Optimize RAM usage when loading the model; Update the demo

parent df981d24
@@ -14,7 +14,7 @@ SVDQuant is a post-training quantization technique for 4-bit weights and activat
 *MIT, NVIDIA, CMU, Princeton, UC Berkeley, SJTU, and Pika Labs* <br>
 <p align="center">
-  <img src="assets/demo.gif" width="70%"/>
+  <img src="assets/demo.gif" width="100%"/>
 </p>
 ## Method
......
@@ -20,5 +20,10 @@ h1{text-align:center}
     width: 400px;
 }
+#accessibility {
+    text-align: center; /* Center-aligns the text */
+    margin: auto; /* Centers the element horizontally */
+}
 #random_seed {height: 71px;}
+#run_button {height: 87px;}
\ No newline at end of file
@@ -2,14 +2,16 @@ import os
 from typing import Any, Callable, Optional, Union
 import torch
-import torchvision.transforms.functional as F
 import torchvision.utils
+from diffusers import __version__
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, FluxPipelineOutput, FluxTransformer2DModel
 from einops import rearrange
 from huggingface_hub import hf_hub_download, snapshot_download
 from peft.tuners import lora
 from PIL import Image
 from safetensors.torch import load_file
 from torch import nn
+from torchvision.transforms import functional as F
 from nunchaku.models.flux import inject_pipeline, load_quantized_model
 from nunchaku.pipelines.flux import quantize_t5
@@ -223,13 +225,42 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
         qmodel_path = kwargs.pop("qmodel_path", None)
         qencoder_path = kwargs.pop("qencoder_path", None)
-        pipeline = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
-        pipeline.precision = "bf16"
-        if qmodel_path is not None:
+        if qmodel_path is None:
+            pipeline = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
+            pipeline.precision = "bf16"
+        else:
+            assert kwargs.pop("transformer", None) is None
             assert isinstance(qmodel_path, str)
             if not os.path.exists(qmodel_path):
                 qmodel_path = snapshot_download(qmodel_path)
+            config, unused_kwargs, commit_hash = FluxTransformer2DModel.load_config(
+                pretrained_model_name_or_path,
+                subfolder="transformer",
+                cache_dir=kwargs.get("cache_dir", None),
+                return_unused_kwargs=True,
+                return_commit_hash=True,
+                force_download=kwargs.get("force_download", False),
+                proxies=kwargs.get("proxies", None),
+                local_files_only=kwargs.get("local_files_only", None),
+                token=kwargs.get("token", None),
+                revision=kwargs.get("revision", None),
+                user_agent={"diffusers": __version__, "file_type": "model", "framework": "pytorch"},
+                **kwargs,
+            )
+            new_config = {k: v for k, v in config.items()}
+            new_config.update({"num_layers": 0, "num_single_layers": 0})
+            transformer: nn.Module = FluxTransformer2DModel.from_config(new_config).to(
+                kwargs.get("torch_dtype", torch.bfloat16)
+            )
+            state_dict = load_file(os.path.join(qmodel_path, "unquantized_layers.safetensors"))
+            transformer.load_state_dict(state_dict, strict=False)
+            pipeline = super().from_pretrained(pretrained_model_name_or_path, transformer=transformer, **kwargs)
             m = load_quantized_model(
                 os.path.join(qmodel_path, "transformer_blocks.safetensors"),
                 0 if qmodel_device.index is None else qmodel_device.index,
@@ -237,6 +268,9 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
             inject_pipeline(pipeline, m, qmodel_device)
             pipeline.precision = "int4"
+            transformer.config["num_layers"] = config["num_layers"]
+            transformer.config["num_single_layers"] = config["num_single_layers"]
         if qencoder_path is not None:
             assert isinstance(qencoder_path, str)
             if not os.path.exists(qencoder_path):
......
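The hunks above are the core of the RAM optimization: instead of letting `super().from_pretrained()` materialize the full BF16 transformer and only then swapping in quantized blocks, the pipeline builds a skeleton transformer whose layer counts are zeroed out, loads just the small unquantized tensors, and passes that skeleton to `from_pretrained`. A minimal sketch of the idea, using the APIs as they appear in this commit; the helper name and `repo_dir` (a local snapshot of the quantized model) are hypothetical:

```python
import os

import torch
from diffusers import FluxTransformer2DModel
from safetensors.torch import load_file


def build_skeleton_transformer(base_repo: str, repo_dir: str) -> FluxTransformer2DModel:
    """Hypothetical helper mirroring the diff above (not part of nunchaku's API)."""
    config = dict(FluxTransformer2DModel.load_config(base_repo, subfolder="transformer"))
    # Zero the layer counts so from_config() allocates no transformer blocks:
    # the multi-billion-parameter BF16 block weights are never materialized in host RAM.
    config.update({"num_layers": 0, "num_single_layers": 0})
    transformer = FluxTransformer2DModel.from_config(config).to(torch.bfloat16)
    # Load only the small unquantized tensors kept outside the quantized blocks;
    # strict=False because the transformer blocks are intentionally absent.
    state_dict = load_file(os.path.join(repo_dir, "unquantized_layers.safetensors"))
    transformer.load_state_dict(state_dict, strict=False)
    return transformer
```

The quantized blocks are then loaded separately from `transformer_blocks.safetensors` via `load_quantized_model` and attached with `inject_pipeline`, so peak RAM is roughly the unquantized remainder plus the INT4 weights rather than the full BF16 model.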
@@ -205,6 +205,8 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Sketch-to-Image De
     download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
     download_result.click(fn=save_image, inputs=result, outputs=download_result)
+    gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
 if __name__ == "__main__":
     demo.queue().launch(debug=True, share=True)
 h1{text-align:center}
 h2{text-align:center}
-#random_seed {height: 72px;}
\ No newline at end of file
+#random_seed {height: 72px;}
+#accessibility {
+    text-align: center; /* Center-aligns the text */
+    margin: auto; /* Centers the element horizontally */
+}
\ No newline at end of file
@@ -247,6 +247,8 @@ with gr.Blocks(
     ).then(
         fn=generate_func, inputs=input_args, outputs=[*image_results, *latency_results], api_name=False, queue=False
     )
+    gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
 if __name__ == "__main__":
     demo.queue(max_size=20).launch(server_name="0.0.0.0", debug=True, share=True)
(binary image diff omitted: too large to display)
@@ -5,7 +5,7 @@ from nunchaku.pipelines import flux as nunchaku_flux
 pipeline = nunchaku_flux.from_pretrained(
     "black-forest-labs/FLUX.1-schnell",
     torch_dtype=torch.bfloat16,
-    qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-schnell.safetensors",  # download from Huggingface
+    qmodel_path="mit-han-lab/svdq-int4-flux.1-schnell",  # download from Huggingface
 ).to("cuda")
 image = pipeline("A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0).images[0]
 image.save("example.png")
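Note that `qmodel_path` now points at a model repository rather than a single `.safetensors` file, since the quantized model ships as two files (`unquantized_layers.safetensors` and `transformer_blocks.safetensors`). A sketch of the resolution logic the loader applies, using the repo id from the updated README:

```python
import os

from huggingface_hub import snapshot_download

qmodel_path = "mit-han-lab/svdq-int4-flux.1-schnell"
if not os.path.exists(qmodel_path):
    # Not a local directory, so treat it as a Hugging Face repo id and
    # download the full snapshot (both safetensors files) into the cache.
    qmodel_path = snapshot_download(qmodel_path)
```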
__version__ = "0.0.0beta0"
\ No newline at end of file
__version__ = "0.0.1beta1"
 import os
 import torch
-from diffusers import __version__
-from diffusers import FluxPipeline, FluxTransformer2DModel
+from diffusers import __version__, FluxPipeline, FluxTransformer2DModel
 from huggingface_hub import hf_hub_download, snapshot_download
 from safetensors.torch import load_file
 from torch import nn
@@ -54,6 +53,7 @@ def from_pretrained(pretrained_model_name_or_path: str | os.PathLike, **kwargs)
     config, unused_kwargs, commit_hash = FluxTransformer2DModel.load_config(
         pretrained_model_name_or_path,
         subfolder="transformer",
+        cache_dir=kwargs.get("cache_dir", None),
         return_unused_kwargs=True,
         return_commit_hash=True,
@@ -62,11 +62,17 @@ def from_pretrained(pretrained_model_name_or_path: str | os.PathLike, **kwargs)
         local_files_only=kwargs.get("local_files_only", None),
         token=kwargs.get("token", None),
         revision=kwargs.get("revision", None),
-        subfolder="transformer",
         user_agent={"diffusers": __version__, "file_type": "model", "framework": "pytorch"},
         **kwargs,
     )
-    transformer: nn.Module = FluxTransformer2DModel.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
+    new_config = {k: v for k, v in config.items()}
+    new_config.update({"num_layers": 0, "num_single_layers": 0})
+    transformer: nn.Module = FluxTransformer2DModel.from_config(new_config).to(
+        kwargs.get("torch_dtype", torch.bfloat16)
+    )
+    state_dict = load_file(os.path.join(qmodel_path, "unquantized_layers.safetensors"))
+    transformer.load_state_dict(state_dict, strict=False)
@@ -77,6 +83,9 @@ def from_pretrained(pretrained_model_name_or_path: str | os.PathLike, **kwargs)
     )
     inject_pipeline(pipeline, m, qmodel_device)
+    transformer.config["num_layers"] = config["num_layers"]
+    transformer.config["num_single_layers"] = config["num_single_layers"]
     if qencoder_path is not None:
         assert isinstance(qencoder_path, str)
         if not os.path.exists(qencoder_path):
......
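One subtlety shared by both files: after `inject_pipeline` attaches the quantized blocks, the skeleton's config still claims zero layers, so the commit writes the original counts back. A short annotated restatement of those two lines; the example layer counts are assumptions about FLUX.1, not taken from this diff:

```python
# The skeleton was created with num_layers = num_single_layers = 0; restore the
# real values so anything reading transformer.config (serialization, model
# inspection) sees the true architecture rather than an empty transformer.
transformer.config["num_layers"] = config["num_layers"]                 # e.g. 19 (assumed)
transformer.config["num_single_layers"] = config["num_single_layers"]   # e.g. 38 (assumed)
```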
@@ -17,4 +17,5 @@ dependencies = [
"accelerate",
"sentencepiece",
"protobuf",
"huggingface_hub",
]
\ No newline at end of file