Commit 6b329882 authored by muyangli

[major] Optimize RAM usage when loading the model; Update the demo

parent df981d24
@@ -14,7 +14,7 @@ SVDQuant is a post-training quantization technique for 4-bit weights and activat
 *MIT, NVIDIA, CMU, Princeton, UC Berkeley, SJTU, and Pika Labs* <br>
 <p align="center">
-  <img src="assets/demo.gif" width="70%"/>
+  <img src="assets/demo.gif" width="100%"/>
 </p>
 
 ## Method
...
@@ -20,5 +20,10 @@ h1{text-align:center}
     width: 400px;
 }
 
+#accessibility {
+    text-align: center; /* Center-aligns the text */
+    margin: auto; /* Centers the element horizontally */
+}
+
 #random_seed {height: 71px;}
 #run_button {height: 87px;}
\ No newline at end of file
@@ -2,14 +2,16 @@ import os
 from typing import Any, Callable, Optional, Union
 
 import torch
+import torchvision.transforms.functional as F
 import torchvision.utils
+from diffusers import __version__
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, FluxPipelineOutput, FluxTransformer2DModel
 from einops import rearrange
 from huggingface_hub import hf_hub_download, snapshot_download
 from peft.tuners import lora
 from PIL import Image
+from safetensors.torch import load_file
 from torch import nn
-from torchvision.transforms import functional as F
 
 from nunchaku.models.flux import inject_pipeline, load_quantized_model
 from nunchaku.pipelines.flux import quantize_t5
@@ -223,13 +225,42 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
         qmodel_path = kwargs.pop("qmodel_path", None)
         qencoder_path = kwargs.pop("qencoder_path", None)
 
-        pipeline = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
-        pipeline.precision = "bf16"
-        if qmodel_path is not None:
+        if qmodel_path is None:
+            pipeline = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
+            pipeline.precision = "bf16"
+        else:
+            assert kwargs.pop("transformer", None) is None
             assert isinstance(qmodel_path, str)
             if not os.path.exists(qmodel_path):
                 qmodel_path = snapshot_download(qmodel_path)
+            config, unused_kwargs, commit_hash = FluxTransformer2DModel.load_config(
+                pretrained_model_name_or_path,
+                subfolder="transformer",
+                cache_dir=kwargs.get("cache_dir", None),
+                return_unused_kwargs=True,
+                return_commit_hash=True,
+                force_download=kwargs.get("force_download", False),
+                proxies=kwargs.get("proxies", None),
+                local_files_only=kwargs.get("local_files_only", None),
+                token=kwargs.get("token", None),
+                revision=kwargs.get("revision", None),
+                user_agent={"diffusers": __version__, "file_type": "model", "framework": "pytorch"},
+                **kwargs,
+            )
+            new_config = {k: v for k, v in config.items()}
+            new_config.update({"num_layers": 0, "num_single_layers": 0})
+            transformer: nn.Module = FluxTransformer2DModel.from_config(new_config).to(
+                kwargs.get("torch_dtype", torch.bfloat16)
+            )
+            state_dict = load_file(os.path.join(qmodel_path, "unquantized_layers.safetensors"))
+            transformer.load_state_dict(state_dict, strict=False)
+            pipeline = super().from_pretrained(pretrained_model_name_or_path, transformer=transformer, **kwargs)
             m = load_quantized_model(
                 os.path.join(qmodel_path, "transformer_blocks.safetensors"),
                 0 if qmodel_device.index is None else qmodel_device.index,
@@ -237,6 +268,9 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
             )
             inject_pipeline(pipeline, m, qmodel_device)
             pipeline.precision = "int4"
+            transformer.config["num_layers"] = config["num_layers"]
+            transformer.config["num_single_layers"] = config["num_single_layers"]
+
         if qencoder_path is not None:
             assert isinstance(qencoder_path, str)
             if not os.path.exists(qencoder_path):
...
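This hunk is where the RAM saving comes from: rather than materializing the full BF16 transformer and then discarding its blocks, the pipeline now builds a shell transformer whose block counts are zeroed out, loads only the small unquantized layers from `unquantized_layers.safetensors`, and lets `load_quantized_model` supply the transformer blocks. Below is a minimal standalone sketch of that pattern, using only the diffusers/safetensors calls visible in the diff; the helper name `load_transformer_shell` is ours, not part of the nunchaku API.

```python
# Sketch of the low-RAM loading pattern above (illustrative, not the repo's code).
import torch
from diffusers import FluxTransformer2DModel
from safetensors.torch import load_file


def load_transformer_shell(model_id: str, qmodel_dir: str) -> FluxTransformer2DModel:
    config = FluxTransformer2DModel.load_config(model_id, subfolder="transformer")
    # Zero the block counts so from_config() never allocates the large
    # double-/single-stream blocks that the quantized engine will replace.
    shell_config = dict(config)
    shell_config.update({"num_layers": 0, "num_single_layers": 0})
    transformer = FluxTransformer2DModel.from_config(shell_config).to(torch.bfloat16)
    # Only the layers kept in BF16 (e.g. embedders, norms, projections) are
    # loaded from disk; strict=False because the shell has no transformer blocks.
    state_dict = load_file(f"{qmodel_dir}/unquantized_layers.safetensors")
    transformer.load_state_dict(state_dict, strict=False)
    return transformer
```

Passing this shell as `transformer=` to `FluxPipeline.from_pretrained` then stops diffusers from loading the original transformer weights at all, which is the bulk of the saving.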
@@ -205,6 +205,8 @@ with gr.Blocks(css_paths="assets/style.css", title=f"SVDQuant Sketch-to-Image De
     download_sketch.click(fn=save_image, inputs=canvas, outputs=download_sketch)
     download_result.click(fn=save_image, inputs=result, outputs=download_result)
 
+    gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
+
 if __name__ == "__main__":
     demo.queue().launch(debug=True, share=True)
 h1{text-align:center}
 h2{text-align:center}
-#random_seed {height: 72px;}
\ No newline at end of file
+#random_seed {height: 72px;}
+
+#accessibility {
+    text-align: center; /* Center-aligns the text */
+    margin: auto; /* Centers the element horizontally */
+}
\ No newline at end of file
@@ -247,6 +247,8 @@ with gr.Blocks(
     ).then(
         fn=generate_func, inputs=input_args, outputs=[*image_results, *latency_results], api_name=False, queue=False
     )
 
+    gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")
+
 if __name__ == "__main__":
     demo.queue(max_size=20).launch(server_name="0.0.0.0", debug=True, share=True)
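Both demos add the same footer line; its `elem_id="accessibility"` is what ties the Markdown component to the `#accessibility` rule added to each `style.css` above. A self-contained sketch of the mechanism (inline CSS instead of a stylesheet file, for brevity):

```python
# Minimal Gradio sketch: elem_id becomes the component's HTML id, so a CSS
# "#accessibility" selector styles exactly this element.
import gradio as gr

with gr.Blocks(css="#accessibility {text-align: center; margin: auto;}") as demo:
    gr.Markdown("MIT Accessibility: https://accessibility.mit.edu/", elem_id="accessibility")

if __name__ == "__main__":
    demo.queue().launch()
```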
[image diff too large to display]
@@ -5,7 +5,7 @@ from nunchaku.pipelines import flux as nunchaku_flux
 pipeline = nunchaku_flux.from_pretrained(
     "black-forest-labs/FLUX.1-schnell",
     torch_dtype=torch.bfloat16,
-    qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-schnell.safetensors",  # download from Huggingface
+    qmodel_path="mit-han-lab/svdq-int4-flux.1-schnell",  # download from Huggingface
 ).to("cuda")
 image = pipeline("A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0).images[0]
 image.save("example.png")
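Note the interface change in this example: `qmodel_path` is now a Hugging Face repo id (or local directory) holding `unquantized_layers.safetensors` and `transformer_blocks.safetensors`, not a single `.safetensors` file. As in the pipeline code above, a path that does not exist locally is treated as a repo id and fetched, roughly:

```python
# Resolution logic as in the diff: use a local directory as-is; otherwise
# treat the string as a Hugging Face repo id and download it to the cache.
import os

from huggingface_hub import snapshot_download

qmodel_path = "mit-han-lab/svdq-int4-flux.1-schnell"
if not os.path.exists(qmodel_path):
    qmodel_path = snapshot_download(qmodel_path)  # returns the local cache folder

print(os.path.join(qmodel_path, "transformer_blocks.safetensors"))
```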
-__version__ = "0.0.0beta0"
+__version__ = "0.0.1beta1"
\ No newline at end of file
 import os
 
 import torch
-from diffusers import __version__
-from diffusers import FluxPipeline, FluxTransformer2DModel
+from diffusers import __version__, FluxPipeline, FluxTransformer2DModel
 from huggingface_hub import hf_hub_download, snapshot_download
 from safetensors.torch import load_file
 from torch import nn
@@ -54,6 +53,7 @@ def from_pretrained(pretrained_model_name_or_path: str | os.PathLike, **kwargs)
         config, unused_kwargs, commit_hash = FluxTransformer2DModel.load_config(
             pretrained_model_name_or_path,
+            subfolder="transformer",
             cache_dir=kwargs.get("cache_dir", None),
             return_unused_kwargs=True,
             return_commit_hash=True,
@@ -62,11 +62,17 @@ def from_pretrained(pretrained_model_name_or_path: str | os.PathLike, **kwargs)
             local_files_only=kwargs.get("local_files_only", None),
             token=kwargs.get("token", None),
             revision=kwargs.get("revision", None),
-            subfolder="transformer",
             user_agent={"diffusers": __version__, "file_type": "model", "framework": "pytorch"},
             **kwargs,
         )
-        transformer: nn.Module = FluxTransformer2DModel.from_config(config).to(kwargs.get("torch_dtype", torch.bfloat16))
+        new_config = {k: v for k, v in config.items()}
+        new_config.update({"num_layers": 0, "num_single_layers": 0})
+        transformer: nn.Module = FluxTransformer2DModel.from_config(new_config).to(
+            kwargs.get("torch_dtype", torch.bfloat16)
+        )
+
         state_dict = load_file(os.path.join(qmodel_path, "unquantized_layers.safetensors"))
         transformer.load_state_dict(state_dict, strict=False)
@@ -77,6 +83,9 @@ def from_pretrained(pretrained_model_name_or_path: str | os.PathLike, **kwargs)
         )
         inject_pipeline(pipeline, m, qmodel_device)
+        transformer.config["num_layers"] = config["num_layers"]
+        transformer.config["num_single_layers"] = config["num_single_layers"]
+
     if qencoder_path is not None:
         assert isinstance(qencoder_path, str)
         if not os.path.exists(qencoder_path):
...
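One subtlety, present in both copies of this logic: the shell transformer was created with `num_layers=0` and `num_single_layers=0`, so once the quantized blocks have been injected, its config is patched back to the original counts. Anything that later reads `transformer.config` then reports the real architecture rather than the empty shell. A toy illustration (the 19/38 values are FLUX.1's published block counts, used here only as an example; no real model is loaded):

```python
# Toy illustration of the config patch-up above.
original = {"num_layers": 19, "num_single_layers": 38}  # FLUX.1 block counts

# The shell is configured with zero blocks, so no block weights are allocated.
shell_config = dict(original, num_layers=0, num_single_layers=0)

# ... quantized blocks are attached by load_quantized_model / inject_pipeline ...

# Restore the true counts so consumers of the config see the real architecture.
shell_config["num_layers"] = original["num_layers"]
shell_config["num_single_layers"] = original["num_single_layers"]
assert shell_config == original
```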
@@ -17,4 +17,5 @@ dependencies = [
     "accelerate",
     "sentencepiece",
     "protobuf",
+    "huggingface_hub",
 ]
\ No newline at end of file