Commit 8c3e3c0b authored by muyangli

[major] Support ComfyUI; Improve model loading;

parent dc5660f0
@@ -72,13 +72,13 @@ In [example.py](example.py), we provide a minimal script for running INT4 FLUX.1
```python
import torch
from diffusers import FluxPipeline

from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel

transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0).images[0]
image.save("example.png")
```
@@ -86,6 +86,10 @@ image.save("example.png")
Specifically, `nunchaku` shares the same APIs as [diffusers](https://github.com/huggingface/diffusers) and can be used in a similar way. The FLUX.1-dev model can be loaded in the same way by replacing all `schnell` with `dev`, as in the sketch below.
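For instance, a minimal sketch of the dev variant (only the repository IDs change relative to the schnell example above; the sampler settings and output filename are illustrative values, not taken from this commit):

```python
import torch
from diffusers import FluxPipeline

from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel

# INT4 FLUX.1-dev transformer; weights are fetched from Hugging Face on first use.
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
# FLUX.1-dev expects a nonzero guidance scale and more sampling steps than schnell;
# 50 steps / 3.5 guidance are common illustrative values, not tuned ones.
image = pipeline(
    "A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5
).images[0]
image.save("example-dev.png")
```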
## ComfyUI
Please refer to [comfyui/README.md](comfyui/README.md) for instructions on using `nunchaku` in [ComfyUI](https://github.com/comfyanonymous/ComfyUI).
## Gradio Demos
### Text-to-Image
@@ -118,7 +122,7 @@ Please refer to [app/t2i/README.md](app/t2i/README.md) for instructions on repro
## Roadmap
- [ ] Easy installation
- [x] Comfy UI node
- [ ] Customized LoRA conversion instructions
- [ ] Customized model quantization instructions
- [ ] Modularization
...
from typing import Any, Callable
import torch
import torchvision.utils
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, FluxPipelineOutput, FluxTransformer2DModel
from einops import rearrange
from peft.tuners import lora
from PIL import Image
from torch import nn
from torchvision.transforms import functional as F
class FluxPix2pixTurboPipeline(FluxPipeline):
def update_alpha(self, alpha: float) -> None:
@@ -33,7 +26,7 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
m.scaling["default_0"] = alpha
else:
assert self.precision == "int4"
transformer.set_lora_strength(alpha)
def load_control_module(
self,
@@ -62,7 +55,7 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
self.load_lora_into_transformer(state_dict, {}, transformer=transformer)
else:
assert svdq_lora_path is not None
self.transformer.update_lora_params(svdq_lora_path)
self.update_alpha(alpha)
@torch.no_grad()
@@ -214,69 +207,3 @@ class FluxPix2pixTurboPipeline(FluxPipeline):
return (image,)
return FluxPipelineOutput(images=image)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
qmodel_device = kwargs.pop("qmodel_device", "cuda:0")
qmodel_device = torch.device(qmodel_device)
if qmodel_device.type != "cuda":
raise ValueError(f"qmodel_device = {qmodel_device} is not a CUDA device")
qmodel_path = kwargs.pop("qmodel_path", None)
qencoder_path = kwargs.pop("qencoder_path", None)
if qmodel_path is None:
pipeline = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
pipeline.precision = "bf16"
else:
assert kwargs.pop("transformer", None) is None
assert isinstance(qmodel_path, str)
if not os.path.exists(qmodel_path):
qmodel_path = snapshot_download(qmodel_path)
config, unused_kwargs, commit_hash = FluxTransformer2DModel.load_config(
pretrained_model_name_or_path,
subfolder="transformer",
cache_dir=kwargs.get("cache_dir", None),
return_unused_kwargs=True,
return_commit_hash=True,
force_download=kwargs.get("force_download", False),
proxies=kwargs.get("proxies", None),
local_files_only=kwargs.get("local_files_only", None),
token=kwargs.get("token", None),
revision=kwargs.get("revision", None),
user_agent={"diffusers": __version__, "file_type": "model", "framework": "pytorch"},
**kwargs,
)
new_config = {k: v for k, v in config.items()}
new_config.update({"num_layers": 0, "num_single_layers": 0})
transformer: nn.Module = FluxTransformer2DModel.from_config(new_config).to(
kwargs.get("torch_dtype", torch.bfloat16)
)
state_dict = load_file(os.path.join(qmodel_path, "unquantized_layers.safetensors"))
transformer.load_state_dict(state_dict, strict=False)
pipeline = super().from_pretrained(pretrained_model_name_or_path, transformer=transformer, **kwargs)
m = load_quantized_model(
os.path.join(qmodel_path, "transformer_blocks.safetensors"),
0 if qmodel_device.index is None else qmodel_device.index,
)
inject_pipeline(pipeline, m, qmodel_device)
pipeline.precision = "int4"
transformer.config["num_layers"] = config["num_layers"]
transformer.config["num_single_layers"] = config["num_single_layers"]
if qencoder_path is not None:
assert isinstance(qencoder_path, str)
if not os.path.exists(qencoder_path):
hf_repo_id = os.path.dirname(qencoder_path)
filename = os.path.basename(qencoder_path)
qencoder_path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
quantize_t5(pipeline, qencoder_path)
return pipeline
@@ -11,6 +11,7 @@ from PIL import Image
from flux_pix2pix_pipeline import FluxPix2pixTurboPipeline
from nunchaku.models.safety_checker import SafetyChecker
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
from utils import get_args
from vars import DEFAULT_SKETCH_GUIDANCE, DEFAULT_STYLE_NAME, MAX_SEED, STYLE_NAMES, STYLES
@@ -24,18 +25,26 @@ args = get_args()
if args.precision == "bf16":
pipeline = FluxPix2pixTurboPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipeline = pipeline.to("cuda")
pipeline.precision = "bf16"
pipeline.load_control_module(
"mit-han-lab/svdquant-models", "flux.1-pix2pix-turbo-sketch2image.safetensors", alpha=DEFAULT_SKETCH_GUIDANCE
)
else:
assert args.precision == "int4"
pipeline_init_kwargs = {}
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
pipeline_init_kwargs["transformer"] = transformer
if args.use_qencoder:
from nunchaku.models.text_encoder import NunchakuT5EncoderModel
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxPix2pixTurboPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
pipeline = pipeline.to("cuda")
pipeline.precision = "int4"
pipeline.load_control_module(
"mit-han-lab/svdquant-models",
"flux.1-pix2pix-turbo-sketch2image.safetensors",
...
@@ -89,10 +89,10 @@ def generate(
else:
assert precision == "int4"
if lora_name != "None":
pipeline.transformer.update_lora_params(SVDQ_LORA_PATHS[lora_name])
pipeline.transformer.set_lora_strength(lora_weight)
else:
pipeline.transformer.set_lora_strength(0)
elif lora_name != "None":
if precision == "bf16":
if pipeline.cur_lora_weight != lora_weight:
@@ -102,7 +102,7 @@ def generate(
m.scaling[lora_name] = lora_weight
else:
assert precision == "int4"
pipeline.transformer.set_lora_strength(lora_weight)
pipeline.cur_lora_name = lora_name
pipeline.cur_lora_weight = lora_weight
...
@@ -2,7 +2,7 @@ import torch
from diffusers import FluxPipeline
from peft.tuners import lora
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
from vars import LORA_PATHS, SVDQ_LORA_PATHS
@@ -23,31 +23,37 @@ def get_pipeline(
lora_weight: float = 1,
device: str | torch.device = "cuda",
) -> FluxPipeline:
pipeline_init_kwargs = {}
if model_name == "schnell":
if precision == "int4":
assert torch.device(device).type == "cuda", "int4 only supported on CUDA devices"
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
pipeline_init_kwargs["transformer"] = transformer
if use_qencoder:
from nunchaku.models.text_encoder import NunchakuT5EncoderModel
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
else:
assert precision == "bf16"
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
elif model_name == "dev":
if precision == "int4":
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")
if lora_name not in ["All", "None"]:
transformer.update_lora_params(SVDQ_LORA_PATHS[lora_name])
transformer.set_lora_strength(lora_weight)
pipeline_init_kwargs["transformer"] = transformer
if use_qencoder:
from nunchaku.models.text_encoder import NunchakuT5EncoderModel
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
pipeline_init_kwargs["text_encoder_2"] = text_encoder_2
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
else:
assert precision == "bf16"
pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
...
# SVDQuant ComfyUI Node
![comfyui](../assets/comfyui.jpg)
## Installation
1. Install `nunchaku` following [README.md](https://github.com/mit-han-lab/nunchaku?tab=readme-ov-file#installation).
2. Set up the dependencies for [ComfyUI](https://github.com/comfyanonymous/ComfyUI/tree/master) with the following commands:
```shell
git clone https://github.com/comfyanonymous/ComfyUI.git
cd ComfyUI
pip install -r requirements.txt
```
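Optionally, a quick sanity check that the environment is usable (a minimal sketch; it only assumes `nunchaku` exposes `__version__` from its package `__init__`):

```python
# Quick sanity check after installation: both packages import and a CUDA GPU is visible.
import torch

import nunchaku

print("nunchaku version:", nunchaku.__version__)
print("CUDA available:", torch.cuda.is_available())
```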
## Usage
1. **Set Up ComfyUI and SVDQuant**:
* Navigate to the root directory of ComfyUI and link (or copy) the [`nunchaku/comfyui`](./) folder to `custom_nodes/svdquant`.
* Place the SVDQuant workflow configurations from [`workflows`](./workflows) into `user/default/workflows`.
   * For example:
```shell
# Clone repositories (skip if already cloned)
git clone https://github.com/comfyanonymous/ComfyUI.git
git clone https://github.com/mit-han-lab/nunchaku.git
cd ComfyUI
# Copy workflow configurations
mkdir -p user/default/workflows
cp ../nunchaku/comfyui/workflows/* user/default/workflows/
# Add SVDQuant nodes
cd custom_nodes
ln -s ../../nunchaku/comfyui svdquant
```
2. **Download Required Models**: Follow [this tutorial](https://comfyanonymous.github.io/ComfyUI_examples/flux/) and download the required models into the appropriate directories using the commands below (a Python alternative is sketched after this list):
```shell
huggingface-cli download comfyanonymous/flux_text_encoders clip_l.safetensors --local-dir models/clip
huggingface-cli download comfyanonymous/flux_text_encoders t5xxl_fp16.safetensors --local-dir models/clip
huggingface-cli download black-forest-labs/FLUX.1-schnell ae.safetensors --local-dir models/vae
```
3. **Run ComfyUI**: From ComfyUI’s root directory, execute the following command to start the application:
```shell
python main.py
```
4. **Select the SVDQuant Workflow**: Choose one of the SVDQuant workflows (`flux.1-dev-svdquant.json` or `flux.1-schnell-svdquant.json`) to get started.
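As noted in step 2, the same files can also be fetched from Python with `huggingface_hub` (a sketch that assumes it is run from ComfyUI's root directory):

```python
# Download the FLUX text encoders and VAE into ComfyUI's model folders.
from huggingface_hub import hf_hub_download

hf_hub_download("comfyanonymous/flux_text_encoders", "clip_l.safetensors", local_dir="models/clip")
hf_hub_download("comfyanonymous/flux_text_encoders", "t5xxl_fp16.safetensors", local_dir="models/clip")
hf_hub_download("black-forest-labs/FLUX.1-schnell", "ae.safetensors", local_dir="models/vae")
```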
## SVDQuant Nodes
* **SVDQuant Flux DiT Loader**: A node for loading the FLUX diffusion model.
  * `model_path`: Specifies the model location. It can be set to either `mit-han-lab/svdq-int4-flux.1-schnell` or `mit-han-lab/svdq-int4-flux.1-dev`. The model will automatically download from our Hugging Face repository (a Python sketch of the underlying loading API follows this list).
* `device_id`: Indicates the GPU ID for running the model.
* **SVDQuant LoRA Loader**: A node for loading LoRA modules for SVDQuant diffusion models.
  * Place your LoRA checkpoints in the `models/loras` directory. These will appear as selectable options under `lora_name`. **Ensure your LoRA checkpoints conform to the SVDQuant format.** A LoRA conversion script will be released soon. Meanwhile, example LoRAs are included and will automatically download from our Hugging Face repository when used.
* **Note**: Currently, only **one LoRA** can be loaded at a time.
* **SVDQuant Text Encoder Loader**: A node for loading the text encoders.
* For FLUX, use the following files:
- `text_encoder1`: `t5xxl_fp16.safetensors`
- `text_encoder2`: `clip_l.safetensors`
* **`t5_min_length`**: Sets the minimum sequence length for T5 text embeddings. The default in `DualCLIPLoader` is hardcoded to 256, but for better image quality in SVDQuant, use 512 here.
* **`t5_precision`**: Specifies the precision of the T5 text encoder. Choose `INT4` to use the INT4 text encoder, which reduces GPU memory usage by approximately 15GB. Please install [`deepcompressor`](https://github.com/mit-han-lab/deepcompressor) when using it:
```shell
git clone https://github.com/mit-han-lab/deepcompressor
cd deepcompressor
pip install poetry
poetry install
```
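For reference, the nodes above are thin wrappers around the `nunchaku` Python API introduced in this commit. The following is a minimal sketch of the equivalent diffusers-side calls, not the nodes' literal code; the LoRA path is one of the example LoRAs listed by the LoRA loader:

```python
import torch
from diffusers import FluxPipeline

from nunchaku.models.text_encoder import NunchakuT5EncoderModel
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel

# SVDQuant Flux DiT Loader: load the INT4 FLUX transformer.
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-dev")

# SVDQuant LoRA Loader: load a single SVDQuant-format LoRA and set its strength.
transformer.update_lora_params("mit-han-lab/svdquant-models/svdq-flux.1-dev-lora-anime.safetensors")
transformer.set_lora_strength(1.0)

# SVDQuant Text Encoder Loader with t5_precision=INT4: the INT4 T5 text encoder.
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    text_encoder_2=text_encoder_2,
    torch_dtype=torch.bfloat16,
).to("cuda")
```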
# only import if running as a custom node
try:
import comfy.utils
except ImportError:
pass
else:
from .nodes import NODE_CLASS_MAPPINGS
NODE_DISPLAY_NAME_MAPPINGS = {k: v.TITLE for k, v in NODE_CLASS_MAPPINGS.items()}
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
import os
import types
import comfy.model_base
import comfy.model_patcher
import comfy.sd
import folder_paths
import GPUtil
import torch
from comfy.ldm.common_dit import pad_to_patch_size
from comfy.supported_models import Flux, FluxSchnell
from diffusers import FluxTransformer2DModel
from einops import rearrange, repeat
from torch import nn
from transformers import T5EncoderModel
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
class ComfyUIFluxForwardWrapper(nn.Module):
def __init__(self, model: NunchakuFluxTransformer2dModel, config):
super(ComfyUIFluxForwardWrapper, self).__init__()
self.model = model
self.dtype = next(model.parameters()).dtype
self.config = config
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
assert control is None # for now
bs, c, h, w = x.shape
patch_size = self.config["patch_size"]
x = pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = (h + (patch_size // 2)) // patch_size
w_len = (w + (patch_size // 2)) // patch_size
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(
0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype
).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(
0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype
).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.model(
hidden_states=img,
encoder_hidden_states=context,
pooled_projections=y,
timestep=timestep,
img_ids=img_ids,
txt_ids=txt_ids,
guidance=guidance if self.config["guidance_embed"] else None,
).sample
out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w]
return out
class SVDQuantFluxDiTLoader:
@classmethod
def INPUT_TYPES(s):
model_paths = ["mit-han-lab/svdq-int4-flux.1-schnell", "mit-han-lab/svdq-int4-flux.1-dev"]
ngpus = len(GPUtil.getGPUs())
return {
"required": {
"model_path": (model_paths,),
"device_id": (
"INT",
{"default": 0, "min": 0, "max": ngpus, "step": 1, "display": "number", "lazy": True},
),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_model"
CATEGORY = "SVDQuant"
TITLE = "SVDQuant Flux DiT Loader"
def load_model(self, model_path: str, device_id: int, **kwargs) -> tuple[FluxTransformer2DModel]:
device = f"cuda:{device_id}"
transformer = NunchakuFluxTransformer2dModel.from_pretrained(model_path).to(device)
dit_config = {
"image_model": "flux",
"in_channels": 16,
"patch_size": 2,
"out_channels": 16,
"vec_in_dim": 768,
"context_in_dim": 4096,
"hidden_size": 3072,
"mlp_ratio": 4.0,
"num_heads": 24,
"depth": 19,
"depth_single_blocks": 38,
"axes_dim": [16, 56, 56],
"theta": 10000,
"qkv_bias": True,
"disable_unet_model_creation": True,
}
if "schnell" in model_path:
dit_config["guidance_embed"] = False
model_config = FluxSchnell(dit_config)
else:
assert "dev" in model_path
dit_config["guidance_embed"] = True
model_config = Flux(dit_config)
model_config.set_inference_dtype(torch.bfloat16, None)
model_config.custom_operations = None
model = model_config.get_model({})
model.diffusion_model = ComfyUIFluxForwardWrapper(transformer, config=dit_config)
model = comfy.model_patcher.ModelPatcher(model, device, device_id)
return (model,)
def svdquant_t5_forward(
self: T5EncoderModel,
input_ids: torch.LongTensor,
attention_mask,
intermediate_output=None,
final_layer_norm_intermediate=True,
dtype: str | torch.dtype = torch.bfloat16,
):
assert attention_mask is None
assert intermediate_output is None
assert final_layer_norm_intermediate
outputs = self.encoder(input_ids, attention_mask=attention_mask)
hidden_states = outputs["last_hidden_state"]
hidden_states = hidden_states.to(dtype=dtype)
return hidden_states, None
class SVDQuantTextEncoderLoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_type": (["flux"],),
"text_encoder1": (folder_paths.get_filename_list("text_encoders"),),
"text_encoder2": (folder_paths.get_filename_list("text_encoders"),),
"t5_min_length": (
"INT",
{"default": 512, "min": 256, "max": 1024, "step": 128, "display": "number", "lazy": True},
),
"t5_precision": (["BF16", "INT4"],),
}
}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_text_encoder"
CATEGORY = "SVDQuant"
TITLE = "SVDQuant Text Encoder Loader"
def load_text_encoder(
self, model_type: str, text_encoder1: str, text_encoder2: str, t5_min_length: int, t5_precision: str
):
text_encoder_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder1)
text_encoder_path2 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder2)
if model_type == "flux":
clip_type = comfy.sd.CLIPType.FLUX
else:
raise ValueError(f"Unknown type {model_type}")
clip = comfy.sd.load_clip(
ckpt_paths=[text_encoder_path1, text_encoder_path2],
embedding_directory=folder_paths.get_folder_paths("embeddings"),
clip_type=clip_type,
)
if model_type == "flux":
clip.tokenizer.t5xxl.min_length = t5_min_length
if t5_precision == "INT4":
from nunchaku.models.text_encoder import NunchakuT5EncoderModel
transformer = clip.cond_stage_model.t5xxl.transformer
param = next(transformer.parameters())
dtype = param.dtype
device = param.device
transformer = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
transformer.forward = types.MethodType(svdquant_t5_forward, transformer)
clip.cond_stage_model.t5xxl.transformer = (
transformer.to(device=device, dtype=dtype) if device.type == "cuda" else transformer
)
return (clip,)
class SVDQuantLoraLoader:
def __init__(self):
self.cur_lora_name = "None"
@classmethod
def INPUT_TYPES(s):
hf_lora_names = ["anime", "ghibsky", "realism", "yarn", "sketch"]
lora_name_list = [
"None",
*folder_paths.get_filename_list("loras"),
*[f"mit-han-lab/svdquant-models/svdq-flux.1-dev-lora-{n}.safetensors" for n in hf_lora_names],
]
return {
"required": {
"model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}),
"lora_name": (lora_name_list, {"tooltip": "The name of the LoRA."}),
"lora_strength": (
"FLOAT",
{
"default": 1.0,
"min": -100.0,
"max": 100.0,
"step": 0.01,
"tooltip": "How strongly to modify the diffusion model. This value can be negative.",
},
),
}
}
RETURN_TYPES = ("MODEL",)
OUTPUT_TOOLTIPS = ("The modified diffusion model.",)
FUNCTION = "load_lora"
TITLE = "SVDQuant LoRA Loader"
CATEGORY = "SVDQuant"
DESCRIPTION = (
"LoRAs are used to modify the diffusion model, "
"altering the way in which latents are denoised such as applying styles. "
"Currently, only one LoRA nodes can be applied."
)
def load_lora(self, model, lora_name: str, lora_strength: float):
if self.cur_lora_name == lora_name:
if self.cur_lora_name == "None":
pass # Do nothing since the lora is None
else:
model.model.diffusion_model.model.set_lora_strength(lora_strength)
else:
if lora_name == "None":
model.model.diffusion_model.model.set_lora_strength(0)
else:
try:
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
except FileNotFoundError:
lora_path = lora_name
model.model.diffusion_model.model.update_lora_params(lora_path)
model.model.diffusion_model.model.set_lora_strength(lora_strength)
self.cur_lora_name = lora_name
return (model,)
NODE_CLASS_MAPPINGS = {
"SVDQuantFluxDiTLoader": SVDQuantFluxDiTLoader,
"SVDQuantTextEncoderLoader": SVDQuantTextEncoderLoader,
"SVDQuantLoRALoader": SVDQuantLoraLoader,
}
{
"last_node_id": 29,
"last_link_id": 43,
"nodes": [
{
"id": 8,
"type": "VAEDecode",
"pos": [
1248,
192
],
"size": [
210,
46
],
"flags": {},
"order": 12,
"mode": 0,
"inputs": [
{
"name": "samples",
"type": "LATENT",
"link": 24
},
{
"name": "vae",
"type": "VAE",
"link": 12
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
9
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "VAEDecode"
},
"widgets_values": []
},
{
"id": 5,
"type": "EmptyLatentImage",
"pos": [
480,
432
],
"size": [
315,
106
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "LATENT",
"type": "LATENT",
"links": [
23
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "EmptyLatentImage"
},
"widgets_values": [
1024,
1024,
1
],
"color": "#323",
"bgcolor": "#535"
},
{
"id": 16,
"type": "KSamplerSelect",
"pos": [
480,
720
],
"size": [
315,
58
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "SAMPLER",
"type": "SAMPLER",
"links": [
19
],
"shape": 3
}
],
"properties": {
"Node name for S&R": "KSamplerSelect"
},
"widgets_values": [
"euler"
]
},
{
"id": 17,
"type": "BasicScheduler",
"pos": [
480,
816
],
"size": [
315,
106
],
"flags": {},
"order": 8,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 41,
"slot_index": 0
}
],
"outputs": [
{
"name": "SIGMAS",
"type": "SIGMAS",
"links": [
20
],
"shape": 3
}
],
"properties": {
"Node name for S&R": "BasicScheduler"
},
"widgets_values": [
"simple",
4,
1
]
},
{
"id": 27,
"type": "Note",
"pos": [
480,
960
],
"size": [
311.3529052734375,
131.16229248046875
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [],
"outputs": [],
"properties": {
"text": ""
},
"widgets_values": [
"The schnell model is a distilled model that can generate a good image with only 4 steps."
],
"color": "#432",
"bgcolor": "#653"
},
{
"id": 22,
"type": "BasicGuider",
"pos": [
552.8497924804688,
128.6840362548828
],
"size": [
241.79998779296875,
46
],
"flags": {},
"order": 10,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 42,
"slot_index": 0
},
{
"name": "conditioning",
"type": "CONDITIONING",
"link": 40,
"slot_index": 1
}
],
"outputs": [
{
"name": "GUIDER",
"type": "GUIDER",
"links": [
30
],
"slot_index": 0,
"shape": 3
}
],
"properties": {
"Node name for S&R": "BasicGuider"
},
"widgets_values": []
},
{
"id": 6,
"type": "CLIPTextEncode",
"pos": [
375,
221
],
"size": [
422.84503173828125,
164.31304931640625
],
"flags": {},
"order": 9,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 43
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
40
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
"Pirate ship trapped in a cosmic maelstrom nebula, rendered in cosmic beach whirlpool engine, volumetric lighting, spectacular, ambient lights, light pollution, cinematic atmosphere, art nouveau style, illustration art artwork by SenseiJaye, intricate detail."
],
"color": "#232",
"bgcolor": "#353"
},
{
"id": 13,
"type": "SamplerCustomAdvanced",
"pos": [
842,
215
],
"size": [
355.20001220703125,
106
],
"flags": {},
"order": 11,
"mode": 0,
"inputs": [
{
"name": "noise",
"type": "NOISE",
"link": 37,
"slot_index": 0
},
{
"name": "guider",
"type": "GUIDER",
"link": 30,
"slot_index": 1
},
{
"name": "sampler",
"type": "SAMPLER",
"link": 19,
"slot_index": 2
},
{
"name": "sigmas",
"type": "SIGMAS",
"link": 20,
"slot_index": 3
},
{
"name": "latent_image",
"type": "LATENT",
"link": 23,
"slot_index": 4
}
],
"outputs": [
{
"name": "output",
"type": "LATENT",
"links": [
24
],
"slot_index": 0,
"shape": 3
},
{
"name": "denoised_output",
"type": "LATENT",
"links": null,
"shape": 3
}
],
"properties": {
"Node name for S&R": "SamplerCustomAdvanced"
},
"widgets_values": []
},
{
"id": 9,
"type": "SaveImage",
"pos": [
1569.9610595703125,
199.1280517578125
],
"size": [
985.3012084960938,
1060.3828125
],
"flags": {},
"order": 13,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 9
}
],
"outputs": [],
"properties": {},
"widgets_values": [
"ComfyUI"
]
},
{
"id": 25,
"type": "RandomNoise",
"pos": [
479.2310485839844,
589.0120239257812
],
"size": [
315,
82
],
"flags": {},
"order": 6,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "NOISE",
"type": "NOISE",
"links": [
37
],
"shape": 3
}
],
"properties": {
"Node name for S&R": "RandomNoise"
},
"widgets_values": [
45,
"fixed"
],
"color": "#2a363b",
"bgcolor": "#3f5159"
},
{
"id": 29,
"type": "SVDQuantTextEncoderLoader",
"pos": [
-40.45120620727539,
185.42774963378906
],
"size": [
352.79998779296875,
130
],
"flags": {},
"order": 7,
"mode": 0,
"inputs": [
{
"name": "t5_min_length",
"type": 0,
"link": null
}
],
"outputs": [
{
"name": "CLIP",
"type": "CLIP",
"links": [
43
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "SVDQuantTextEncoderLoader"
},
"widgets_values": [
"flux",
"t5xxl_fp16.safetensors",
"clip_l.safetensors",
512
]
},
{
"id": 10,
"type": "VAELoader",
"pos": [
-31.617252349853516,
377.54791259765625
],
"size": [
315,
58
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "VAE",
"type": "VAE",
"links": [
12
],
"slot_index": 0,
"shape": 3
}
],
"properties": {
"Node name for S&R": "VAELoader"
},
"widgets_values": [
"ae.safetensors"
]
},
{
"id": 26,
"type": "Note",
"pos": [
-28.286691665649414,
511.4660339355469
],
"size": [
336,
288
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [],
"outputs": [],
"properties": {
"text": ""
},
"widgets_values": [
"If you get an error in any of the nodes above make sure the files are in the correct directories.\n\nSee the top of the examples page for the links : https://comfyanonymous.github.io/ComfyUI_examples/flux/\n\nflux1-schnell.safetensors goes in: ComfyUI/models/unet/\n\nt5xxl_fp16.safetensors and clip_l.safetensors go in: ComfyUI/models/clip/\n\nae.safetensors goes in: ComfyUI/models/vae/\n\n\nTip: You can set the weight_dtype above to one of the fp8 types if you have memory issues."
],
"color": "#432",
"bgcolor": "#653"
},
{
"id": 28,
"type": "SVDQuantFluxDiTLoader",
"pos": [
-10.846628189086914,
890.9998779296875
],
"size": [
315,
82
],
"flags": {},
"order": 5,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "MODEL",
"type": "MODEL",
"links": [
41,
42
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "SVDQuantFluxDiTLoader"
},
"widgets_values": [
"mit-han-lab/svdq-int4-flux.1-schnell",
0
]
}
],
"links": [
[
9,
8,
0,
9,
0,
"IMAGE"
],
[
12,
10,
0,
8,
1,
"VAE"
],
[
19,
16,
0,
13,
2,
"SAMPLER"
],
[
20,
17,
0,
13,
3,
"SIGMAS"
],
[
23,
5,
0,
13,
4,
"LATENT"
],
[
24,
13,
0,
8,
0,
"LATENT"
],
[
30,
22,
0,
13,
1,
"GUIDER"
],
[
37,
25,
0,
13,
0,
"NOISE"
],
[
40,
6,
0,
22,
1,
"CONDITIONING"
],
[
41,
28,
0,
17,
0,
"MODEL"
],
[
42,
28,
0,
22,
0,
"MODEL"
],
[
43,
29,
0,
6,
0,
"CLIP"
]
],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 0.6727499949325652,
"offset": [
405.6825017392191,
29.738440474209906
]
}
},
"version": 0.4
}
\ No newline at end of file
import torch
from diffusers import FluxPipeline
from nunchaku.models.transformer_flux import NunchakuFluxTransformer2dModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-schnell")
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0).images[0]
image.save("example.png")
__version__ = "0.0.1beta1" __version__ = "0.0.2beta0"
import os
import torch
from deepcompressor.backend.tinychat.linear import W4Linear
from huggingface_hub import constants, hf_hub_download
from safetensors.torch import load_file
from torch import nn
from transformers import PretrainedConfig, T5EncoderModel
def quantize_t5_encoder(
t5_encoder: nn.Module,
pretrained_model_name_or_path: str | os.PathLike,
cache_dir: str | os.PathLike | None = None,
force_download: bool = False,
local_files_only: bool = False,
token: str | bool | None = None,
revision: str = "main",
**kwargs,
):
subfolder = kwargs.get("subfolder", None)
if os.path.exists(pretrained_model_name_or_path):
dirname = (
pretrained_model_name_or_path
if subfolder is None
else os.path.join(pretrained_model_name_or_path, subfolder)
)
qmodel_path = os.path.join(dirname, "svdq-t5.safetensors")
else:
qmodel_path = hf_hub_download(
repo_id=pretrained_model_name_or_path,
filename="svdq-t5.safetensors",
subfolder=subfolder,
repo_type="model",
revision=revision,
library_name=kwargs.get("library_name", None),
library_version=kwargs.get("library_version", None),
cache_dir=cache_dir,
local_dir=kwargs.get("local_dir", None),
user_agent=kwargs.get("user_agent", None),
force_download=force_download,
proxies=kwargs.get("proxies", None),
etag_timeout=kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT),
token=token,
local_files_only=local_files_only,
headers=kwargs.get("headers", None),
endpoint=kwargs.get("endpoint", None),
resume_download=kwargs.get("resume_download", None),
force_filename=kwargs.get("force_filename", None),
local_dir_use_symlinks=kwargs.get("local_dir_use_symlinks", "auto"),
)
state_dict = load_file(qmodel_path)
qlayer_suffix = tuple(kwargs.get("qlayer_suffix", (".q", ".k", ".v", ".o", ".wi_0")))
named_modules = {}
for name, module in t5_encoder.named_modules():
assert isinstance(name, str)
if isinstance(module, nn.Linear):
if f"{name}.qweight" in state_dict and name.endswith(qlayer_suffix):
print(f"Switching {name} to W4Linear")
qmodule = W4Linear.from_linear(module, group_size=128, init_only=True)
qmodule.qweight.data.copy_(state_dict[f"{name}.qweight"])
if qmodule.bias is not None:
qmodule.bias.data.copy_(state_dict[f"{name}.bias"])
qmodule.scales.data.copy_(state_dict[f"{name}.scales"])
qmodule.scaled_zeros.data.copy_(state_dict[f"{name}.scaled_zeros"])
# modeling_t5.py: T5DenseGatedActDense needs dtype of weight
qmodule.weight = torch.empty([1], dtype=module.weight.dtype, device=module.weight.device)
parent_name, child_name = name.rsplit(".", 1)
setattr(named_modules[parent_name], child_name, qmodule)
else:
named_modules[name] = module
return t5_encoder
class NunchakuT5EncoderModel(T5EncoderModel):
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str | os.PathLike,
*model_args,
config: PretrainedConfig | str | os.PathLike | None = None,
cache_dir: str | os.PathLike | None = None,
ignore_mismatched_sizes: bool = False,
force_download: bool = False,
local_files_only: bool = False,
token: str | bool | None = None,
revision: str = "main",
use_safetensors: bool = None,
weights_only: bool = True,
**kwargs,
):
t5_encoder = (
super(NunchakuT5EncoderModel, cls)
.from_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
cache_dir=cache_dir,
ignore_mismatched_sizes=ignore_mismatched_sizes,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
use_safetensors=use_safetensors,
weights_only=weights_only,
**kwargs,
)
.to(kwargs.get("torch_dtype", torch.bfloat16))
)
t5_encoder = quantize_t5_encoder(
t5_encoder=t5_encoder,
pretrained_model_name_or_path=pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
**kwargs,
)
return t5_encoder
import os
import diffusers
import torch
from diffusers import __version__, FluxTransformer2DModel
from diffusers.configuration_utils import register_to_config
from huggingface_hub import hf_hub_download, utils, constants
from packaging.version import Version
from safetensors.torch import load_file
from torch import nn
from .._C import QuantizedFluxModel
@@ -13,9 +15,9 @@ from .._C import QuantizedFluxModel
SVD_RANK = 32
class NunchakuFluxTransformerBlocks(nn.Module):
def __init__(self, m: QuantizedFluxModel, device: str | torch.device):
super(NunchakuFluxTransformerBlocks, self).__init__()
self.m = m
self.dtype = torch.bfloat16
self.device = device
@@ -87,9 +89,9 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
return out.float()
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
super(EmbedND, self).__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
@@ -102,7 +104,7 @@ class EmbedND(torch.nn.Module):
return emb.unsqueeze(1)
def load_quantized_module(path: str, device: str | torch.device = "cuda") -> QuantizedFluxModel:
device = torch.device(device)
assert device.type == "cuda"
@@ -113,28 +115,118 @@ def load_quantized_model(path: str, device: str | torch.device) -> QuantizedFlux
return m
class NunchakuFluxTransformer2dModel(FluxTransformer2DModel):
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: tuple[int] = (16, 56, 56),
):
super(NunchakuFluxTransformer2dModel, self).__init__(
patch_size=patch_size,
in_channels=in_channels,
num_layers=0,
num_single_layers=0,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
joint_attention_dim=joint_attention_dim,
pooled_projection_dim=pooled_projection_dim,
guidance_embeds=guidance_embeds,
axes_dims_rope=axes_dims_rope,
)
@classmethod
@utils.validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs):
subfolder = kwargs.get("subfolder", None)
if os.path.exists(pretrained_model_name_or_path):
dirname = (
pretrained_model_name_or_path
if subfolder is None
else os.path.join(pretrained_model_name_or_path, subfolder)
)
unquantized_part_path = os.path.join(dirname, "unquantized_layers.safetensors")
transformer_block_path = os.path.join(dirname, "transformer_blocks.safetensors")
else:
download_kwargs = {
"subfolder": subfolder,
"repo_type": "model",
"revision": kwargs.get("revision", None),
"cache_dir": kwargs.get("cache_dir", None),
"local_dir": kwargs.get("local_dir", None),
"user_agent": kwargs.get("user_agent", None),
"force_download": kwargs.get("force_download", False),
"proxies": kwargs.get("proxies", None),
"etag_timeout": kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT),
"token": kwargs.get("token", None),
"local_files_only": kwargs.get("local_files_only", None),
"headers": kwargs.get("headers", None),
"endpoint": kwargs.get("endpoint", None),
"resume_download": kwargs.get("resume_download", None),
"force_filename": kwargs.get("force_filename", None),
"local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
}
unquantized_part_path = hf_hub_download(
repo_id=pretrained_model_name_or_path, filename="unquantized_layers.safetensors", **download_kwargs
)
transformer_block_path = hf_hub_download(
repo_id=pretrained_model_name_or_path, filename="transformer_blocks.safetensors", **download_kwargs
)
config, _, _ = cls.load_config(
pretrained_model_name_or_path,
subfolder=subfolder,
cache_dir=kwargs.get("cache_dir", None),
return_unused_kwargs=True,
return_commit_hash=True,
force_download=kwargs.get("force_download", False),
proxies=kwargs.get("proxies", None),
local_files_only=kwargs.get("local_files_only", None),
token=kwargs.get("token", None),
revision=kwargs.get("revision", None),
user_agent={"diffusers": __version__, "file_type": "model", "framework": "pytorch"},
**kwargs,
)
device = kwargs.get("device", "cuda")
transformer: NunchakuFluxTransformer2dModel = cls.from_config(config).to(
kwargs.get("torch_dtype", torch.bfloat16)
)
state_dict = load_file(unquantized_part_path)
transformer.load_state_dict(state_dict, strict=False)
m = load_quantized_module(transformer_block_path, device=device)
transformer.inject_quantized_module(m, device)
return transformer
def update_lora_params(self, path: str):
if not os.path.exists(path):
hf_repo_id = os.path.dirname(path)
filename = os.path.basename(path)
path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
block = self.transformer_blocks[0]
assert isinstance(block, NunchakuFluxTransformerBlocks)
block.m.load(path, True)
def set_lora_strength(self, strength: float = 1):
block = self.transformer_blocks[0]
assert isinstance(block, NunchakuFluxTransformerBlocks)
block.m.setLoraScale(SVD_RANK, strength)
def inject_quantized_module(self, m: QuantizedFluxModel, device: str | torch.device = "cuda"):
self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=[16, 56, 56])
### Compatible with the original forward method
self.transformer_blocks = nn.ModuleList([NunchakuFluxTransformerBlocks(m, device)])
self.single_transformer_blocks = nn.ModuleList([])
return self
import os
import torch
from diffusers import __version__, FluxPipeline, FluxTransformer2DModel
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
from torch import nn
from ..models.flux import inject_pipeline, load_quantized_model
def quantize_t5(pipe: FluxPipeline, qencoder_path: str):
assert os.path.exists(qencoder_path), f"qencoder_path {qencoder_path} does not exist"
from deepcompressor.backend.tinychat.linear import W4Linear
named_modules = {}
qencoder_state_dict = torch.load(qencoder_path, map_location="cpu")
for name, module in pipe.text_encoder_2.named_modules():
assert isinstance(name, str)
if isinstance(module, torch.nn.Linear):
suffix = [".q", ".k", ".v", ".o", ".wi_0"]
if f"{name}.qweight" in qencoder_state_dict and name.endswith(tuple(suffix)):
print(f"Switching {name} to W4Linear")
qmodule = W4Linear.from_linear(module, group_size=128, init_only=True)
qmodule.qweight.data.copy_(qencoder_state_dict[f"{name}.qweight"])
if qmodule.bias is not None:
qmodule.bias.data.copy_(qencoder_state_dict[f"{name}.bias"])
qmodule.scales.data.copy_(qencoder_state_dict[f"{name}.scales"])
qmodule.scaled_zeros.data.copy_(qencoder_state_dict[f"{name}.scaled_zeros"])
# modeling_t5.py: T5DenseGatedActDense needs dtype of weight
qmodule.weight = torch.empty([1], dtype=module.weight.dtype, device=module.weight.device)
parent_name, child_name = name.rsplit(".", 1)
setattr(named_modules[parent_name], child_name, qmodule)
else:
named_modules[name] = module
def from_pretrained(pretrained_model_name_or_path: str | os.PathLike, **kwargs) -> FluxPipeline:
qmodel_device = kwargs.pop("qmodel_device", "cuda:0")
qmodel_device = torch.device(qmodel_device)
if qmodel_device.type != "cuda":
raise ValueError(f"qmodel_device = {qmodel_device} is not a CUDA device")
qmodel_path = kwargs.pop("qmodel_path")
qencoder_path = kwargs.pop("qencoder_path", None)
if not os.path.exists(qmodel_path):
qmodel_path = snapshot_download(qmodel_path)
assert kwargs.pop("transformer", None) is None
config, unused_kwargs, commit_hash = FluxTransformer2DModel.load_config(
pretrained_model_name_or_path,
subfolder="transformer",
cache_dir=kwargs.get("cache_dir", None),
return_unused_kwargs=True,
return_commit_hash=True,
force_download=kwargs.get("force_download", False),
proxies=kwargs.get("proxies", None),
local_files_only=kwargs.get("local_files_only", None),
token=kwargs.get("token", None),
revision=kwargs.get("revision", None),
user_agent={"diffusers": __version__, "file_type": "model", "framework": "pytorch"},
**kwargs,
)
new_config = {k: v for k, v in config.items()}
new_config.update({"num_layers": 0, "num_single_layers": 0})
transformer: nn.Module = FluxTransformer2DModel.from_config(new_config).to(
kwargs.get("torch_dtype", torch.bfloat16)
)
state_dict = load_file(os.path.join(qmodel_path, "unquantized_layers.safetensors"))
transformer.load_state_dict(state_dict, strict=False)
pipeline = FluxPipeline.from_pretrained(pretrained_model_name_or_path, transformer=transformer, **kwargs)
m = load_quantized_model(
os.path.join(qmodel_path, "transformer_blocks.safetensors"),
0 if qmodel_device.index is None else qmodel_device.index,
)
inject_pipeline(pipeline, m, qmodel_device)
transformer.config["num_layers"] = config["num_layers"]
transformer.config["num_single_layers"] = config["num_single_layers"]
if qencoder_path is not None:
assert isinstance(qencoder_path, str)
if not os.path.exists(qencoder_path):
hf_repo_id = os.path.dirname(qencoder_path)
filename = os.path.basename(qencoder_path)
qencoder_path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
quantize_t5(pipeline, qencoder_path)
return pipeline