Unverified Commit 37a27712 authored by Muyang Li's avatar Muyang Li Committed by GitHub
Browse files

Merge pull request #340 from mit-han-lab/dev

feat: support PuLID, Double FBCache and TeaCache; better linter
parents c1d6fc84 760ab022
...@@ -9,13 +9,13 @@ import GPUtil ...@@ -9,13 +9,13 @@ import GPUtil
import spaces import spaces
import torch import torch
from peft.tuners import lora from peft.tuners import lora
from nunchaku.models.safety_checker import SafetyChecker
from utils import get_pipeline from utils import get_pipeline
from vars import DEFAULT_HEIGHT, DEFAULT_WIDTH, EXAMPLES, MAX_SEED, PROMPT_TEMPLATES, SVDQ_LORA_PATHS from vars import DEFAULT_HEIGHT, DEFAULT_WIDTH, EXAMPLES, MAX_SEED, PROMPT_TEMPLATES, SVDQ_LORA_PATHS
from nunchaku.models.safety_checker import SafetyChecker
# import gradio last to avoid conflicts with other imports # import gradio last to avoid conflicts with other imports
import gradio as gr import gradio as gr # noqa: isort: skip
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
...@@ -84,7 +84,7 @@ def generate( ...@@ -84,7 +84,7 @@ def generate(
images, latency_strs = [], [] images, latency_strs = [], []
for i, pipeline in enumerate(pipelines): for i, pipeline in enumerate(pipelines):
precision = args.precisions[i] precision = args.precisions[i]
progress = gr.Progress(track_tqdm=True) gr.Progress(track_tqdm=True)
if pipeline.cur_lora_name != lora_name: if pipeline.cur_lora_name != lora_name:
if precision == "bf16": if precision == "bf16":
for m in pipeline.transformer.modules(): for m in pipeline.transformer.modules():
...@@ -164,7 +164,7 @@ if len(gpus) > 0: ...@@ -164,7 +164,7 @@ if len(gpus) > 0:
device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory." device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
else: else:
device_info = "Running on CPU 🥶 This demo does not work on CPU." device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = f'<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."' notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
with gr.Blocks( with gr.Blocks(
css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"], css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"],
......
import torch import torch
from diffusers import FluxPipeline from diffusers import FluxPipeline
from peft.tuners import lora from peft.tuners import lora
from vars import LORA_PATHS, SVDQ_LORA_PATHS
from nunchaku import NunchakuFluxTransformer2dModel from nunchaku import NunchakuFluxTransformer2dModel
from vars import LORA_PATHS, SVDQ_LORA_PATHS
def hash_str_to_int(s: str) -> int: def hash_str_to_int(s: str) -> int:
......
...@@ -37,4 +37,4 @@ python latency.py ...@@ -37,4 +37,4 @@ python latency.py
* Adjust the number of inference steps and the guidance scale using `-t` and `-g`, respectively. The defaults are 20 steps and a guidance scale of 5. * Adjust the number of inference steps and the guidance scale using `-t` and `-g`, respectively. The defaults are 20 steps and a guidance scale of 5.
* You can also adjust the [PAG guidance](https://arxiv.org/abs/2403.17377) scale with `--pag-scale`. The default is 2. * You can also adjust the [PAG guidance](https://arxiv.org/abs/2403.17377) scale with `--pag-scale`. The default is 2.
* By default, the script measures the end-to-end latency for generating a single image. To measure the latency of a single DiT forward step instead, use the `--mode step` flag. * By default, the script measures the end-to-end latency for generating a single image. To measure the latency of a single DiT forward step instead, use the `--mode step` flag.
* Specify the number of warmup and test runs using `--warmup-times` and `--test-times`. The defaults are 2 warmup runs and 10 test runs. * Specify the number of warmup and test runs using `--warmup-times` and `--test-times`. The defaults are 2 warmup runs and 10 test runs.
\ No newline at end of file
...@@ -6,4 +6,4 @@ h2{text-align:center} ...@@ -6,4 +6,4 @@ h2{text-align:center}
#accessibility { #accessibility {
text-align: center; /* Center-aligns the text */ text-align: center; /* Center-aligns the text */
margin: auto; /* Centers the element horizontally */ margin: auto; /* Centers the element horizontally */
} }
\ No newline at end of file
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div> <div>
<h1> <h1>
<img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/logo.svg" <img src="https://github.com/mit-han-lab/nunchaku/raw/refs/heads/main/assets/svdquant.svg"
alt="logo" alt="logo"
style="height: 40px; width: auto; display: block; margin: auto;"/> style="height: 40px; width: auto; display: block; margin: auto;"/>
<a href='https://nvlabs.github.io/Sana/' target="_blank">SANA-1.6B</a> Demo <a href='https://nvlabs.github.io/Sana/' target="_blank">SANA-1.6B</a> Demo
...@@ -50,4 +50,4 @@ ...@@ -50,4 +50,4 @@
</div> </div>
{count_info} {count_info}
</div> </div>
</div> </div>
\ No newline at end of file
...@@ -2,7 +2,6 @@ import argparse ...@@ -2,7 +2,6 @@ import argparse
import os import os
import torch import torch
from utils import get_pipeline from utils import get_pipeline
......
...@@ -4,7 +4,6 @@ import time ...@@ -4,7 +4,6 @@ import time
import torch import torch
from torch import nn from torch import nn
from tqdm import trange from tqdm import trange
from utils import get_pipeline from utils import get_pipeline
......
...@@ -8,13 +8,13 @@ from datetime import datetime ...@@ -8,13 +8,13 @@ from datetime import datetime
import GPUtil import GPUtil
import spaces import spaces
import torch import torch
from nunchaku.models.safety_checker import SafetyChecker
from utils import get_pipeline from utils import get_pipeline
from vars import EXAMPLES, MAX_SEED from vars import EXAMPLES, MAX_SEED
from nunchaku.models.safety_checker import SafetyChecker
# import gradio last to avoid conflicts with other imports # import gradio last to avoid conflicts with other imports
import gradio as gr import gradio as gr # noqa: isort: skip
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
...@@ -73,7 +73,7 @@ def generate( ...@@ -73,7 +73,7 @@ def generate(
prompt = "A peaceful world." prompt = "A peaceful world."
images, latency_strs = [], [] images, latency_strs = [], []
for i, pipeline in enumerate(pipelines): for i, pipeline in enumerate(pipelines):
progress = gr.Progress(track_tqdm=True) gr.Progress(track_tqdm=True)
start_time = time.time() start_time = time.time()
image = pipeline( image = pipeline(
prompt=prompt, prompt=prompt,
...@@ -124,11 +124,11 @@ if len(gpus) > 0: ...@@ -124,11 +124,11 @@ if len(gpus) > 0:
device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory." device_info = f"Running on {gpu.name} with {memory:.0f} GiB memory."
else: else:
device_info = "Running on CPU 🥶 This demo does not work on CPU." device_info = "Running on CPU 🥶 This demo does not work on CPU."
notice = f'<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."' notice = '<strong>Notice:</strong>&nbsp;We will replace unsafe prompts with a default prompt: "A peaceful world."'
with gr.Blocks( with gr.Blocks(
css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"], css_paths=[f"assets/frame{len(args.precisions)}.css", "assets/common.css"],
title=f"SVDQuant SANA-1600M Demo", title="SVDQuant SANA-1600M Demo",
) as demo: ) as demo:
def get_header_str(): def get_header_str():
......
...@@ -46,4 +46,4 @@ ...@@ -46,4 +46,4 @@
</g> </g>
<path d="M418.39,22.56c-.9-2.12-3.08-3.99-2.86-6.3.6-6.24-1.96-9.26-5.87-10.8-5.59-2.76-10.79-2.48-15.59.89-5.16,3.63-6.9,8.92-5.88,15.06-3.44,1.79-6.77,3.46-10.03,5.27-1.04.58-1.67.45-2.57-.24-4.36-3.31-9.77-3.35-14.45-.38-2.92,1.85-5.92,3.61-8.99,5.2-4.67,2.41-8.51,5.37-9.23,11.06-.06.44-.81,1.01-1.34,1.15-2.64.72-5.32,1.29-7.97,1.98-1.09.28-1.8-.03-2.5-.87-3.33-4.01-7.59-5.28-12.62-4.14-3.55.8-7.1,1.63-10.65,2.41-4.53.99-8.9,2.23-11.5,6.61-.14.23-.76.32-1.12.26-3.14-.54-6.26-1.14-9.44-1.73-.4-4.66-2.91-7.77-6.66-10.13-3.81-2.39-7.54-4.92-11.29-7.41-2.5-1.65-5.47-2.9-8.14-1.91-3.92,1.46-5.66-.68-7.62-3.11-.53-.65-1.1-1.28-1.71-1.87-.91-.89-1.15-1.7-.63-3.04,2.56-6.58-1.25-14.13-8-16.06-4.78-1.36-9.57-2.67-14.37-3.94-6.58-1.74-12.14.91-14.99,7.05-.24.51-.79,1.18-1.25,1.23-1.63.18-3.26.33-4.89.46.01.52.01,1.04.01,1.56,4.44-1,8.77-1.17,13.19-.6-1.82,1.27-8.29,2.27-13.22,2.36-.04,1.47-.13,2.95-.23,4.43,4.6-.4,9.19-.79,13.79-1.19.01.08.02.15.03.23-2.2.7-4.39,1.39-6.62,2.09,1.3,2.68,3.69,4.83,6.67,5.69,5.33,1.55,10.69,3.06,16.09,4.37,1.72.42,3.61.13,5.84.18-1.34-2.39-2.39-4.26-3.44-6.13l.3-.23c5.72,6.3,11.43,12.61,17.15,18.91-.06.07-.12.13-.18.2-2.04-1.41-4.09-2.82-6.2-4.27-1.71,5.48.04,10.66,4.66,13.84,4.3,2.96,8.67,5.81,13.05,8.64,5.02,3.25,12.27,1.96,15.19-2.14-2.16-.92-4.3-1.83-6.44-2.74.05-.15.11-.3.16-.45,6.02,1.12,12.04,2.21,18.04,3.4.43.09.91.85,1.05,1.39,1.65,6.24,7.78,10.23,14.06,8.93,4.97-1.03,9.89-2.3,14.84-3.41,4.98-1.12,8.06-4.16,9.57-9.25-2.61.09-5,.18-7.4.27l-.02-.24,27-6.51c.05.15.09.31.14.46l-6.85,3.18c3.69,3.77,9.13,4.98,13.57,2.64,5.32-2.8,10.5-5.87,15.62-9.01,2.83-1.74,5.21-6.46,4.49-8.99-2.38.52-4.76,1.04-7.15,1.57-.01-.08-.03-.16-.04-.24l24.55-13.02.16.19c-1.43,1.36-2.86,2.72-4.35,4.14,4.09,3.31,8.57,4.15,13.26,2.79,5.85-1.7,9.32-5.87,10.62-12.29.39.9.81,1.74,1.2,2.55ZM240.66,6.17c2.19-1.05,6.89,2.57,6.7,5.28-2.92-.11-5.18-1.48-7-3.61-.24-.3-.01-1.52.3-1.67ZM236.31,14.54c-1.54,1.54-1.21,3.32.9,6.16-5.49-1.54-10.72-3-15.95-4.46.03-.17.07-.35.1-.52,2.43-.24,5.06-.28,5.67-3.36.39-1.94-.51-3.39-2.17-4.55,2.51.68,5.01,1.35,7.52,2.03,2.26.62,4.57,1.13,6.77,1.94,1.26.46,2.34,1.39,3.48,1.83-1.1-.18-2.23-.61-3.28-.46-1.08.15-2.29.64-3.04,1.39ZM243.02,19.76c3.02.35,11.2,8.77,12.25,12.7-4.84-3.4-8.69-7.74-12.25-12.7ZM271.35,48.21c-.99,2.02-.01,3.61,1.22,5.22-5.37-3.34-10.84-6.47-15.54-10.72.94.54,1.85,1.43,2.84,1.53,1.04.11,2.39-.23,3.21-.87,1.98-1.55,1.71-3.13-.61-7.24,4.91,3.25,9.83,6.5,14.74,9.76-2.44-.05-4.65-.17-5.86,2.32ZM267.38,32.23c4.46,2.84,9.48,4.89,13.41,9.32-2.49.4-12.99-7.11-13.41-9.32ZM284.99,50.83c3.61-1.39,15.07.42,17.7,2.77-5.94.19-11.65-.91-17.7-2.77ZM322.43,48.01c-2.55,1.22-3.64,2.83-3.16,4.68.58,2.26,2.21,3.21,5.16,3.2-6.25,1.93-12.54,3.69-19.16,4.1,2.4-.49,4.56-1.22,4.65-4.09.1-2.89-1.86-4.04-4.44-4.56,5.59-1.28,11.18-2.56,16.76-3.83.06.16.13.33.19.5ZM315.23,43.15c2.4-2.34,6.44-2.95,8.44-1.33-1.16,2.42-6.21,3.29-8.44,1.33ZM333.09,48.29c5.19-3.09,10.81-4.61,16.85-4.57-5.26,2.89-10.96,4.09-16.85,4.57ZM371.58,39.47l-15.81,9.08c-.12-.12-.24-.24-.36-.36,2.07-1.36,3.17-3.17,2.04-5.48-1.15-2.36-3.34-2.39-5.68-1.99,5.35-3.33,10.55-6.82,16.39-9.16-1.98,1.91-2.68,3.81-1.86,5.56.82,1.73,2.46,2.39,5.28,2.35ZM370.85,27.31c-2,.5-4.03.9-6.07,1.18-.43.06-1.37-.52-1.35-.76.03-.55.45-1.12.83-1.59.23-.28.67-.38,1.02-.57v-.42c1.79,0,3.58-.04,5.36.07.42.02.8.55,1.2.84-.33.43-.58,1.15-.99,1.25ZM378.71,29.44c4.29-4.26,9.38-7.12,15.26-8.59-4.37,4.11-9.65,6.64-15.26,8.59ZM391.92,14.77c-.33.39-1.13.37-1.71.54-.13-.58-.44-1.19-.34-1.73.4-2.33,2.42-4.9,4.89-6.03.17,0,.77.02,1.38.03-.03.62.17,1.4-.12,1.83-1.28,1.85-2.65,3.64-4.1,5.36ZM407.84,23.73c-1.86,1.82-5.89,3.26-8.87,1.19.94-1.27,2.06-2.44,2.73-3.83.31-.64-.06-1.82-.47-2.57-1.06-1.94-3.17-2.19-6.12-.83.01-3.35,2.27-5.98,5.73-6.88,3.25-.84,6.83.81,8.56,3.94,1.53,2.76.85,6.6-1.56,8.98Z"/> <path d="M418.39,22.56c-.9-2.12-3.08-3.99-2.86-6.3.6-6.24-1.96-9.26-5.87-10.8-5.59-2.76-10.79-2.48-15.59.89-5.16,3.63-6.9,8.92-5.88,15.06-3.44,1.79-6.77,3.46-10.03,5.27-1.04.58-1.67.45-2.57-.24-4.36-3.31-9.77-3.35-14.45-.38-2.92,1.85-5.92,3.61-8.99,5.2-4.67,2.41-8.51,5.37-9.23,11.06-.06.44-.81,1.01-1.34,1.15-2.64.72-5.32,1.29-7.97,1.98-1.09.28-1.8-.03-2.5-.87-3.33-4.01-7.59-5.28-12.62-4.14-3.55.8-7.1,1.63-10.65,2.41-4.53.99-8.9,2.23-11.5,6.61-.14.23-.76.32-1.12.26-3.14-.54-6.26-1.14-9.44-1.73-.4-4.66-2.91-7.77-6.66-10.13-3.81-2.39-7.54-4.92-11.29-7.41-2.5-1.65-5.47-2.9-8.14-1.91-3.92,1.46-5.66-.68-7.62-3.11-.53-.65-1.1-1.28-1.71-1.87-.91-.89-1.15-1.7-.63-3.04,2.56-6.58-1.25-14.13-8-16.06-4.78-1.36-9.57-2.67-14.37-3.94-6.58-1.74-12.14.91-14.99,7.05-.24.51-.79,1.18-1.25,1.23-1.63.18-3.26.33-4.89.46.01.52.01,1.04.01,1.56,4.44-1,8.77-1.17,13.19-.6-1.82,1.27-8.29,2.27-13.22,2.36-.04,1.47-.13,2.95-.23,4.43,4.6-.4,9.19-.79,13.79-1.19.01.08.02.15.03.23-2.2.7-4.39,1.39-6.62,2.09,1.3,2.68,3.69,4.83,6.67,5.69,5.33,1.55,10.69,3.06,16.09,4.37,1.72.42,3.61.13,5.84.18-1.34-2.39-2.39-4.26-3.44-6.13l.3-.23c5.72,6.3,11.43,12.61,17.15,18.91-.06.07-.12.13-.18.2-2.04-1.41-4.09-2.82-6.2-4.27-1.71,5.48.04,10.66,4.66,13.84,4.3,2.96,8.67,5.81,13.05,8.64,5.02,3.25,12.27,1.96,15.19-2.14-2.16-.92-4.3-1.83-6.44-2.74.05-.15.11-.3.16-.45,6.02,1.12,12.04,2.21,18.04,3.4.43.09.91.85,1.05,1.39,1.65,6.24,7.78,10.23,14.06,8.93,4.97-1.03,9.89-2.3,14.84-3.41,4.98-1.12,8.06-4.16,9.57-9.25-2.61.09-5,.18-7.4.27l-.02-.24,27-6.51c.05.15.09.31.14.46l-6.85,3.18c3.69,3.77,9.13,4.98,13.57,2.64,5.32-2.8,10.5-5.87,15.62-9.01,2.83-1.74,5.21-6.46,4.49-8.99-2.38.52-4.76,1.04-7.15,1.57-.01-.08-.03-.16-.04-.24l24.55-13.02.16.19c-1.43,1.36-2.86,2.72-4.35,4.14,4.09,3.31,8.57,4.15,13.26,2.79,5.85-1.7,9.32-5.87,10.62-12.29.39.9.81,1.74,1.2,2.55ZM240.66,6.17c2.19-1.05,6.89,2.57,6.7,5.28-2.92-.11-5.18-1.48-7-3.61-.24-.3-.01-1.52.3-1.67ZM236.31,14.54c-1.54,1.54-1.21,3.32.9,6.16-5.49-1.54-10.72-3-15.95-4.46.03-.17.07-.35.1-.52,2.43-.24,5.06-.28,5.67-3.36.39-1.94-.51-3.39-2.17-4.55,2.51.68,5.01,1.35,7.52,2.03,2.26.62,4.57,1.13,6.77,1.94,1.26.46,2.34,1.39,3.48,1.83-1.1-.18-2.23-.61-3.28-.46-1.08.15-2.29.64-3.04,1.39ZM243.02,19.76c3.02.35,11.2,8.77,12.25,12.7-4.84-3.4-8.69-7.74-12.25-12.7ZM271.35,48.21c-.99,2.02-.01,3.61,1.22,5.22-5.37-3.34-10.84-6.47-15.54-10.72.94.54,1.85,1.43,2.84,1.53,1.04.11,2.39-.23,3.21-.87,1.98-1.55,1.71-3.13-.61-7.24,4.91,3.25,9.83,6.5,14.74,9.76-2.44-.05-4.65-.17-5.86,2.32ZM267.38,32.23c4.46,2.84,9.48,4.89,13.41,9.32-2.49.4-12.99-7.11-13.41-9.32ZM284.99,50.83c3.61-1.39,15.07.42,17.7,2.77-5.94.19-11.65-.91-17.7-2.77ZM322.43,48.01c-2.55,1.22-3.64,2.83-3.16,4.68.58,2.26,2.21,3.21,5.16,3.2-6.25,1.93-12.54,3.69-19.16,4.1,2.4-.49,4.56-1.22,4.65-4.09.1-2.89-1.86-4.04-4.44-4.56,5.59-1.28,11.18-2.56,16.76-3.83.06.16.13.33.19.5ZM315.23,43.15c2.4-2.34,6.44-2.95,8.44-1.33-1.16,2.42-6.21,3.29-8.44,1.33ZM333.09,48.29c5.19-3.09,10.81-4.61,16.85-4.57-5.26,2.89-10.96,4.09-16.85,4.57ZM371.58,39.47l-15.81,9.08c-.12-.12-.24-.24-.36-.36,2.07-1.36,3.17-3.17,2.04-5.48-1.15-2.36-3.34-2.39-5.68-1.99,5.35-3.33,10.55-6.82,16.39-9.16-1.98,1.91-2.68,3.81-1.86,5.56.82,1.73,2.46,2.39,5.28,2.35ZM370.85,27.31c-2,.5-4.03.9-6.07,1.18-.43.06-1.37-.52-1.35-.76.03-.55.45-1.12.83-1.59.23-.28.67-.38,1.02-.57v-.42c1.79,0,3.58-.04,5.36.07.42.02.8.55,1.2.84-.33.43-.58,1.15-.99,1.25ZM378.71,29.44c4.29-4.26,9.38-7.12,15.26-8.59-4.37,4.11-9.65,6.64-15.26,8.59ZM391.92,14.77c-.33.39-1.13.37-1.71.54-.13-.58-.44-1.19-.34-1.73.4-2.33,2.42-4.9,4.89-6.03.17,0,.77.02,1.38.03-.03.62.17,1.4-.12,1.83-1.28,1.85-2.65,3.64-4.1,5.36ZM407.84,23.73c-1.86,1.82-5.89,3.26-8.87,1.19.94-1.27,2.06-2.44,2.73-3.83.31-.64-.06-1.82-.47-2.57-1.06-1.94-3.17-2.19-6.12-.83.01-3.35,2.27-5.98,5.73-6.88,3.25-.84,6.83.81,8.56,3.94,1.53,2.76.85,6.6-1.56,8.98Z"/>
<circle class="cls-1" cx="206.14" cy="15.03" r="8.22"/> <circle class="cls-1" cx="206.14" cy="15.03" r="8.22"/>
</svg> </svg>
\ No newline at end of file
...@@ -23,4 +23,4 @@ ...@@ -23,4 +23,4 @@
<polygon class="cls-1" points="538.46 21.84 538.46 39.47 503.18 0 484.49 0 484.49 63.62 502.88 63.62 502.88 24.15 538.26 63.62 556.85 63.62 556.85 21.84 538.46 21.84"/> <polygon class="cls-1" points="538.46 21.84 538.46 39.47 503.18 0 484.49 0 484.49 63.62 502.88 63.62 502.88 24.15 538.26 63.62 556.85 63.62 556.85 21.84 538.46 21.84"/>
<rect class="cls-2" x="538.46" width="18.39" height="14.25"/> <rect class="cls-2" x="538.46" width="18.39" height="14.25"/>
<path class="cls-1" d="M565.55,14.12V0h67.25v14.12h-23.48v49.5h-18.39V14.12h-25.38Z"/> <path class="cls-1" d="M565.55,14.12V0h67.25v14.12h-23.48v49.5h-18.39V14.12h-25.38Z"/>
</svg> </svg>
\ No newline at end of file
...@@ -40,4 +40,4 @@ For detailed guidance on testing, refer to the [`tests/README.md`](../tests/READ ...@@ -40,4 +40,4 @@ For detailed guidance on testing, refer to the [`tests/README.md`](../tests/READ
## Acknowledgments ## Acknowledgments
This contribution guide is adapted from [SGLang](https://docs.sglang.ai/references/contribution_guide.html). We thank them for the inspiration. This contribution guide is adapted from [SGLang](https://docs.sglang.ai/references/contribution_guide.html). We thank them for the inspiration.
\ No newline at end of file
...@@ -62,17 +62,17 @@ Then verify the Python version and installed PyTorch version: ...@@ -62,17 +62,17 @@ Then verify the Python version and installed PyTorch version:
Install PyTorch appropriate for your setup Install PyTorch appropriate for your setup
- **For most users**: - **For most users**:
```bash ```bash
"G:\ComfyuI\python\python.exe" -m pip install torch==2.6 torchvision==0.21 torchaudio==2.6 "G:\ComfyuI\python\python.exe" -m pip install torch==2.6 torchvision==0.21 torchaudio==2.6
``` ```
- **For RTX 50-series GPUs** (requires PyTorch ≥2.7 with CUDA 12.8): - **For RTX 50-series GPUs** (requires PyTorch ≥2.7 with CUDA 12.8):
```bash ```bash
"G:\ComfyuI\python\python.exe" -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 "G:\ComfyuI\python\python.exe" -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
``` ```
## Step 3: Install Nunchaku ## Step 3: Install Nunchaku
...@@ -108,55 +108,55 @@ You can also run a test (requires a Hugging Face token for downloading the model ...@@ -108,55 +108,55 @@ You can also run a test (requires a Hugging Face token for downloading the model
Please use CMD instead of PowerShell for building. Please use CMD instead of PowerShell for building.
- Step 1: Install Build Tools - Step 1: Install Build Tools
```bash ```bash
C:\Users\muyang\miniconda3\envs\comfyui\python.exe C:\Users\muyang\miniconda3\envs\comfyui\python.exe
"G:\ComfyuI\python\python.exe" -m pip install ninja setuptools wheel build "G:\ComfyuI\python\python.exe" -m pip install ninja setuptools wheel build
``` ```
- Step 2: Clone the Repository - Step 2: Clone the Repository
```bash ```bash
git clone https://github.com/mit-han-lab/nunchaku.git git clone https://github.com/mit-han-lab/nunchaku.git
cd nunchaku cd nunchaku
git submodule init git submodule init
git submodule update git submodule update
``` ```
- Step 3: Set Up Visual Studio Environment - Step 3: Set Up Visual Studio Environment
Locate the `VsDevCmd.bat` script on your system. Example path: Locate the `VsDevCmd.bat` script on your system. Example path:
``` ```
C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\Common7\Tools\VsDevCmd.bat C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\Common7\Tools\VsDevCmd.bat
``` ```
Then run: Then run:
```bash ```bash
"C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\Common7\Tools\VsDevCmd.bat" -startdir=none -arch=x64 -host_arch=x64 "C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\Common7\Tools\VsDevCmd.bat" -startdir=none -arch=x64 -host_arch=x64
set DISTUTILS_USE_SDK=1 set DISTUTILS_USE_SDK=1
``` ```
- Step 4: Build Nunchaku - Step 4: Build Nunchaku
```bash ```bash
"G:\ComfyuI\python\python.exe" setup.py develop "G:\ComfyuI\python\python.exe" setup.py develop
``` ```
Verify with: Verify with:
```bash ```bash
"G:\ComfyuI\python\python.exe" -c "import nunchaku" "G:\ComfyuI\python\python.exe" -c "import nunchaku"
``` ```
You can also run a test (requires a Hugging Face token for downloading the models): You can also run a test (requires a Hugging Face token for downloading the models):
```bash ```bash
"G:\ComfyuI\python\python.exe" -m huggingface-cli login "G:\ComfyuI\python\python.exe" -m huggingface-cli login
"G:\ComfyuI\python\python.exe" -m nunchaku.test "G:\ComfyuI\python\python.exe" -m nunchaku.test
``` ```
- (Optional) Step 5: Building wheel for Portable Python - (Optional) Step 5: Building wheel for Portable Python
If building directly with portable Python fails, you can first build the wheel in a working Conda environment, then install the `.whl` file using your portable Python: If building directly with portable Python fails, you can first build the wheel in a working Conda environment, then install the `.whl` file using your portable Python:
...@@ -182,42 +182,42 @@ Alternatively, install using [ComfyUI-Manager](https://github.com/Comfy-Org/Comf ...@@ -182,42 +182,42 @@ Alternatively, install using [ComfyUI-Manager](https://github.com/Comfy-Org/Comf
## 2. Download Models ## 2. Download Models
- **Standard FLUX.1-dev Models** - **Standard FLUX.1-dev Models**
Start by downloading the standard [FLUX.1-dev text encoders](https://huggingface.co/comfyanonymous/flux_text_encoders/tree/main) and [VAE](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/ae.safetensors). You can also optionally download the original [BF16 FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors) model. An example command: Start by downloading the standard [FLUX.1-dev text encoders](https://huggingface.co/comfyanonymous/flux_text_encoders/tree/main) and [VAE](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/ae.safetensors). You can also optionally download the original [BF16 FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors) model. An example command:
```bash ```bash
huggingface-cli download comfyanonymous/flux_text_encoders clip_l.safetensors --local-dir models/text_encoders huggingface-cli download comfyanonymous/flux_text_encoders clip_l.safetensors --local-dir models/text_encoders
huggingface-cli download comfyanonymous/flux_text_encoders t5xxl_fp16.safetensors --local-dir models/text_encoders huggingface-cli download comfyanonymous/flux_text_encoders t5xxl_fp16.safetensors --local-dir models/text_encoders
huggingface-cli download black-forest-labs/FLUX.1-schnell ae.safetensors --local-dir models/vae huggingface-cli download black-forest-labs/FLUX.1-schnell ae.safetensors --local-dir models/vae
huggingface-cli download black-forest-labs/FLUX.1-dev flux1-dev.safetensors --local-dir models/diffusion_models huggingface-cli download black-forest-labs/FLUX.1-dev flux1-dev.safetensors --local-dir models/diffusion_models
``` ```
- **SVDQuant 4-bit FLUX.1-dev Models** - **SVDQuant 4-bit FLUX.1-dev Models**
Next, download the SVDQuant 4-bit models: Next, download the SVDQuant 4-bit models:
- For **50-series GPUs**, use the [FP4 model](https://huggingface.co/mit-han-lab/svdq-fp4-flux.1-dev). - For **50-series GPUs**, use the [FP4 model](https://huggingface.co/mit-han-lab/svdq-fp4-flux.1-dev).
- For **other GPUs**, use the [INT4 model](https://huggingface.co/mit-han-lab/svdq-int4-flux.1-dev). - For **other GPUs**, use the [INT4 model](https://huggingface.co/mit-han-lab/svdq-int4-flux.1-dev).
Make sure to place the **entire downloaded folder** into `models/diffusion_models`. For example: Make sure to place the **entire downloaded folder** into `models/diffusion_models`. For example:
```bash ```bash
huggingface-cli download mit-han-lab/svdq-int4-flux.1-dev --local-dir models/diffusion_models/svdq-int4-flux.1-dev huggingface-cli download mit-han-lab/svdq-int4-flux.1-dev --local-dir models/diffusion_models/svdq-int4-flux.1-dev
``` ```
- **(Optional): Download Sample LoRAs** - **(Optional): Download Sample LoRAs**
You can test with some sample LoRAs like [FLUX.1-Turbo](https://huggingface.co/alimama-creative/FLUX.1-Turbo-Alpha/blob/main/diffusion_pytorch_model.safetensors) and [Ghibsky](https://huggingface.co/aleksa-codes/flux-ghibsky-illustration/blob/main/lora.safetensors). Place these files in the `models/loras` directory: You can test with some sample LoRAs like [FLUX.1-Turbo](https://huggingface.co/alimama-creative/FLUX.1-Turbo-Alpha/blob/main/diffusion_pytorch_model.safetensors) and [Ghibsky](https://huggingface.co/aleksa-codes/flux-ghibsky-illustration/blob/main/lora.safetensors). Place these files in the `models/loras` directory:
```bash ```bash
huggingface-cli download alimama-creative/FLUX.1-Turbo-Alpha diffusion_pytorch_model.safetensors --local-dir models/loras huggingface-cli download alimama-creative/FLUX.1-Turbo-Alpha diffusion_pytorch_model.safetensors --local-dir models/loras
huggingface-cli download aleksa-codes/flux-ghibsky-illustration lora.safetensors --local-dir models/loras huggingface-cli download aleksa-codes/flux-ghibsky-illustration lora.safetensors --local-dir models/loras
``` ```
## 3. Set Up Workflows ## 3. Set Up Workflows
To use the official workflows, download them from the [ComfyUI-nunchaku](https://github.com/mit-han-lab/ComfyUI-nunchaku/tree/main/workflows) and place them in your `ComfyUI/user/default/workflows` directory. The command can be To use the official workflows, download them from the [ComfyUI-nunchaku](https://github.com/mit-han-lab/ComfyUI-nunchaku/tree/main/workflows) and place them in your `ComfyUI/user/default/workflows` directory. The command can be
```bash ```bash
# From the root of your ComfyUI folder # From the root of your ComfyUI folder
...@@ -231,4 +231,4 @@ You can now launch ComfyUI and try running the example workflows. ...@@ -231,4 +231,4 @@ You can now launch ComfyUI and try running the example workflows.
If you encounter issues, refer to our: If you encounter issues, refer to our:
- [FAQs](https://github.com/mit-han-lab/nunchaku/discussions/262) - [FAQs](https://github.com/mit-han-lab/nunchaku/discussions/262)
- [GitHub Issues](https://github.com/mit-han-lab/nunchaku/issues) - [GitHub Issues](https://github.com/mit-han-lab/nunchaku/issues)
\ No newline at end of file
...@@ -4,7 +4,6 @@ from diffusers.models import FluxMultiControlNetModel ...@@ -4,7 +4,6 @@ from diffusers.models import FluxMultiControlNetModel
from diffusers.utils import load_image from diffusers.utils import load_image
from nunchaku import NunchakuFluxTransformer2dModel from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters.flux import apply_cache_on_pipe
from nunchaku.utils import get_gpu_memory, get_precision from nunchaku.utils import get_gpu_memory, get_precision
base_model = "black-forest-labs/FLUX.1-dev" base_model = "black-forest-labs/FLUX.1-dev"
...@@ -29,11 +28,6 @@ if need_offload: ...@@ -29,11 +28,6 @@ if need_offload:
else: else:
pipeline = pipeline.to("cuda") pipeline = pipeline.to("cuda")
# apply_cache_on_pipe(
# pipeline, residual_diff_threshold=0.1
# ) # Uncomment this line to enable first-block cache to speedup generation
prompt = "A anime style girl with messy beach waves." prompt = "A anime style girl with messy beach waves."
control_image_depth = load_image( control_image_depth = load_image(
"https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg" "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg"
......
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.utils import get_precision
precision = get_precision()
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
apply_cache_on_pipe(
pipeline,
use_double_fb_cache=True,
residual_diff_threshold_multi=0.09,
residual_diff_threshold_single=0.12,
)
image = pipeline(["A cat holding a sign that says hello world"], num_inference_steps=50).images[0]
image.save(f"flux.1-dev-cache-{precision}.png")
from types import MethodType
import torch
from diffusers.utils import load_image
from nunchaku.models.pulid.pulid_forward import pulid_forward
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from nunchaku.pipeline.pipeline_flux_pulid import PuLIDFluxPipeline
from nunchaku.utils import get_precision
precision = get_precision()
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
pipeline = PuLIDFluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
torch_dtype=torch.bfloat16,
).to("cuda")
pipeline.transformer.forward = MethodType(pulid_forward, pipeline.transformer)
id_image = load_image("https://github.com/ToTheBeginning/PuLID/blob/main/example_inputs/liuyifei.png?raw=true")
image = pipeline(
"A woman holding a sign that says 'SVDQuant is fast!",
id_image=id_image,
id_weight=1,
num_inference_steps=12,
guidance_scale=3.5,
).images[0]
image.save("flux.1-dev-pulid.png")
import time
import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.teacache import TeaCache
from nunchaku.utils import get_precision
precision = get_precision() # auto-detect your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
start_time = time.time()
with TeaCache(model=transformer, num_steps=50, rel_l1_thresh=0.3, enabled=True):
image = pipeline(
"A cat holding a sign that says hello world",
num_inference_steps=50,
guidance_scale=3.5,
height=1024,
width=1024,
generator=torch.Generator(device="cuda").manual_seed(0),
).images[0]
end_time = time.time()
print(f"Time taken: {(end_time - start_time)} seconds")
image.save(f"flux.1-dev-{precision}-tc.png")
from .models import NunchakuFluxTransformer2dModel, NunchakuSanaTransformer2DModel, NunchakuT5EncoderModel from .models import NunchakuFluxTransformer2dModel, NunchakuSanaTransformer2DModel, NunchakuT5EncoderModel
__all__ = ["NunchakuFluxTransformer2dModel", "NunchakuSanaTransformer2DModel", "NunchakuT5EncoderModel"]
__version__ = "0.3.0dev0" __version__ = "0.3.0dev1"
...@@ -7,16 +7,30 @@ from torch import nn ...@@ -7,16 +7,30 @@ from torch import nn
from ...caching import utils from ...caching import utils
def apply_cache_on_transformer(transformer: FluxTransformer2DModel, *, residual_diff_threshold=0.12): def apply_cache_on_transformer(
transformer: FluxTransformer2DModel,
*,
use_double_fb_cache: bool = False,
residual_diff_threshold: float = 0.12,
residual_diff_threshold_multi: float | None = None,
residual_diff_threshold_single: float = 0.1,
):
if residual_diff_threshold_multi is None:
residual_diff_threshold_multi = residual_diff_threshold
if getattr(transformer, "_is_cached", False): if getattr(transformer, "_is_cached", False):
transformer.cached_transformer_blocks[0].update_threshold(residual_diff_threshold) transformer.cached_transformer_blocks[0].update_residual_diff_threshold(
use_double_fb_cache, residual_diff_threshold_multi, residual_diff_threshold_single
)
return transformer return transformer
cached_transformer_blocks = nn.ModuleList( cached_transformer_blocks = nn.ModuleList(
[ [
utils.FluxCachedTransformerBlocks( utils.FluxCachedTransformerBlocks(
transformer=transformer, transformer=transformer,
residual_diff_threshold=residual_diff_threshold, use_double_fb_cache=use_double_fb_cache,
residual_diff_threshold_multi=residual_diff_threshold_multi,
residual_diff_threshold_single=residual_diff_threshold_single,
return_hidden_states_first=False, return_hidden_states_first=False,
) )
] ]
......
from types import MethodType
from typing import Any, Callable, Optional, Union
import numpy as np
import torch
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.utils import logging
from diffusers.utils.constants import USE_PEFT_BACKEND
from diffusers.utils.import_utils import is_torch_version
from diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers
from ..models.transformers import NunchakuFluxTransformer2dModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def make_teacache_forward(num_steps: int = 50, rel_l1_thresh: float = 0.6, skip_steps: int = 0) -> Callable:
def teacache_forward(
self: Union[FluxTransformer2DModel, NunchakuFluxTransformer2dModel],
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
pooled_projections: torch.Tensor,
timestep: torch.LongTensor,
img_ids: torch.Tensor,
txt_ids: torch.Tensor,
guidance: torch.Tensor,
joint_attention_kwargs: Optional[dict[str, Any]] = None,
controlnet_block_samples: Optional[torch.Tensor] = None,
controlnet_single_block_samples: Optional[torch.Tensor] = None,
return_dict: bool = True,
controlnet_blocks_repeat: bool = False,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype) * 1000 # type: ignore
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) # type: ignore
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
inp = hidden_states.clone()
temb_ = temb.clone()
modulated_inp, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.transformer_blocks[0].norm1(inp, emb=temb_) # type: ignore
if self.cnt == 0 or self.cnt == num_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = [
4.98651651e02,
-2.83781631e02,
5.58554382e01,
-3.82021401e00,
2.64230861e-01,
]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(
(
(modulated_inp - self.previous_modulated_input).abs().mean()
/ self.previous_modulated_input.abs().mean()
)
.cpu()
.item()
)
if self.accumulated_rel_l1_distance < rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.cnt += 1
if self.cnt == num_steps:
self.cnt = 0
ckpt_kwargs: dict[str, Any]
if self.cnt > skip_steps:
if not should_calc:
hidden_states += self.previous_residual
else:
ori_hidden_states = hidden_states.clone()
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None): # type: ignore
def custom_forward(*inputs): # type: ignore
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
# For Xlabs ControlNet.
if controlnet_blocks_repeat:
hidden_states = (
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None): # type: ignore
def custom_forward(*inputs): # type: ignore
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
self.previous_residual = hidden_states - ori_hidden_states
else:
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None): # type: ignore
def custom_forward(*inputs): # type: ignore
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
# For Xlabs ControlNet.
if controlnet_blocks_repeat:
hidden_states = (
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None): # type: ignore
def custom_forward(*inputs): # type: ignore
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
output: torch.FloatTensor = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return output
return Transformer2DModelOutput(sample=output)
return teacache_forward
# A context manager to add teacache support to a block of code
# When the context manager is applied, the model passed to the context manager is modified
# to support teacache
class TeaCache:
def __init__(
self,
model: Union[FluxTransformer2DModel, NunchakuFluxTransformer2dModel],
num_steps: int = 50,
rel_l1_thresh: float = 0.6,
skip_steps: int = 0,
enabled: bool = True,
) -> None:
self.model = model
self.num_steps = num_steps
self.rel_l1_thresh = rel_l1_thresh
self.skip_steps = skip_steps
self.enabled = enabled
self.previous_model_forward = self.model.forward
def __enter__(self) -> "TeaCache":
if self.enabled:
# self.model.__class__.forward = make_teacache_forward(self.num_steps, self.rel_l1_thresh, self.skip_steps) # type: ignore
self.model.forward = MethodType(
make_teacache_forward(self.num_steps, self.rel_l1_thresh, self.skip_steps), self.model
)
self.model.cnt = 0
self.model.accumulated_rel_l1_distance = 0
self.model.previous_modulated_input = None
self.model.previous_residual = None
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
if self.enabled:
self.model.forward = self.previous_model_forward
del self.model.cnt
del self.model.accumulated_rel_l1_distance
del self.model.previous_modulated_input
del self.model.previous_residual
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment