Unverified Commit f86ad470 authored by Muyang Li, committed by GitHub

feat: pythonized model and QwenImage Support (#593)

* start refactoring the codebase

* update

* update

* start to implement ops

* add gemm

* write the docstrings

* define the w4a4 svdq linear

* update

* make the linter happy

* finished the SVDQW4A4Linear

* finished the SVDQW4A4Linear

* update

* update

* add a patcher to the model

* update

* add adanormsinglezero

* update

* update

* finished the naive implementation of nunchaku flux

* add ff

* finished the naive forward

* update

* svdq linear

* start debugging

* fix some issues

* successfully built the model

* update

* successfully load the model

* update

* update

* update

* try to make it runnable

* debugging

* debugging

* debugging

* add bias to awq linear

* run through

* fix the normalization

* update

* update

* update

* fix the attention

* fix the non-fused nvfp4 models

* update

* finished the fused ff

* make linter happy

* make linter happy

* make linter happy

* debugging the fp16 attn

* nunchaku fp16 is buggy

* finish the fp16 attn

* fp4 done

* fix the lora scales

* add a default value for alpha; need to debug int4

* fix int4

* update

* update

* ff does not work

* specialize the processors

* qwen transformer done. start debugging

* make linter happy

* add schnell v2 for metrics eval

* chore: schnellv2 eval

* update

* ff and attention correct

* need to check what happened to module

* fp4 done

* make linter happy

* update an example script

* reformat

* add an example script

* add the announcement

* remove misleading info

* ready to release
parent 954c7af9
...@@ -212,3 +212,7 @@ cython_debug/
.gitattributes
nunchaku-models/
*.png
dev/
.tmp/
metrics.json
...@@ -15,17 +15,19 @@ Join our user groups on [**Slack**](https://join.slack.com/t/nunchaku/shared_inv
## News
- **[2025-08-15]** 🔥 Our **4-bit Qwen-Image** models are now live on [Hugging Face](https://huggingface.co/nunchaku-tech/nunchaku-qwen-image)! Get started with our [example script](examples/v1/qwen-image.py). *ComfyUI, LoRA, and CPU offloading support are coming soon!*
- **[2025-08-15]** 🚀 The **Python backend** is now available! Explore our Pythonic FLUX models [here](nunchaku/models/transformers/transformer_flux_v2.py) and see the modular **4-bit linear layer** [here](nunchaku/models/linear.py).
- **[2025-07-31]** 🚀 **[FLUX.1-Krea-dev](https://www.krea.ai/blog/flux-krea-open-source-release) is now supported!** Check out our new [example script](./examples/flux.1-krea-dev.py) to get started.
- **[2025-07-13]** 🚀 The official [**Nunchaku documentation**](https://nunchaku.tech/docs/nunchaku/) is now live! Explore comprehensive guides and resources to help you get started.
- **[2025-06-29]** 🔥 Support **FLUX.1-Kontext**! Try out our [example script](./examples/flux.1-kontext-dev.py) to see it in action! Our demo is available at this [link](https://svdquant.mit.edu/kontext/)!
- **[2025-06-01]** 🚀 **Release v0.3.0!** This update adds support for multiple-batch inference, [**ControlNet-Union-Pro 2.0**](https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0), initial integration of [**PuLID**](https://github.com/ToTheBeginning/PuLID), and introduces [**Double FB Cache**](examples/flux.1-dev-double_cache.py). You can now load Nunchaku FLUX models as a single file, and our upgraded [**4-bit T5 encoder**](https://huggingface.co/nunchaku-tech/nunchaku-t5) now matches **FP8 T5** in quality!
- **[2025-04-16]** 🎥 Released tutorial videos in both [**English**](https://youtu.be/YHAVe-oM7U8?si=cM9zaby_aEHiFXk0) and [**Chinese**](https://www.bilibili.com/video/BV1BTocYjEk5/?share_source=copy_web&vd_source=8926212fef622f25cc95380515ac74ee) to assist installation and usage.
<details>
<summary>More</summary>
- **[2025-04-09]** 📢 Published the [April roadmap](https://github.com/nunchaku-tech/nunchaku/issues/266) and an [FAQ](https://github.com/nunchaku-tech/nunchaku/discussions/262) to help the community get started and stay up to date with Nunchaku’s development.
- **[2025-04-05]** 🚀 **Nunchaku v0.2.0 released!** This release brings [**multi-LoRA**](examples/flux.1-dev-multiple-lora.py) and [**ControlNet**](examples/flux.1-dev-controlnet-union-pro.py) support with even faster performance powered by [**FP16 attention**](#fp16-attention) and [**First-Block Cache**](#first-block-cache). We've also added compatibility for [**20-series GPUs**](examples/flux.1-dev-turing.py) — Nunchaku is now more accessible than ever!
- **[2025-03-07]** 🚀 **Nunchaku v0.1.4 Released!** We've supported [4-bit text encoder and per-layer CPU offloading](#Low-Memory-Inference), reducing FLUX's minimum memory requirement to just **4 GiB** while maintaining a **2–3× speedup**. This update also fixes various issues related to resolution, LoRA, pin memory, and runtime stability. Check out the release notes for full details!
- **[2025-02-20]** 🚀 **Support NVFP4 precision on NVIDIA RTX 5090!** NVFP4 delivers superior image quality compared to INT4, offering **~3× speedup** on the RTX 5090 over BF16. Learn more in our [blog](https://hanlab.mit.edu/blog/svdquant-nvfp4), checkout [`examples`](./examples) for usage and try [our demo](https://svdquant.mit.edu/flux1-schnell/) online!
- **[2025-02-18]** 🔥 [**Customized LoRA conversion**](#Customized-LoRA) and [**model quantization**](#Customized-Model-Quantization) instructions are now available! **[ComfyUI](./comfyui)** workflows now support **customized LoRA**, along with **FLUX.1-Tools**!
...
...@@ -10,7 +10,12 @@ from utils import get_pipeline, hash_str_to_int
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model",
type=str,
default="schnell",
choices=["schnell", "schnell_v2", "dev"],
help="Which FLUX.1 model to use",
)
parser.add_argument(
"-p", "--precision", type=str, default="int4", choices=["int4", "fp4", "bf16"], help="Which precision to use"
...@@ -33,6 +38,9 @@ def get_args():
default=0,
help="You will generate images for the subset specified by [chunk-start::chunk-step].",
)
parser.add_argument(
"--max-dataset-size", type=int, default=5000, help="Maximum number of images to generate for each dataset"
)
known_args, _ = parser.parse_known_args()
if known_args.model == "dev":
...@@ -56,7 +64,7 @@ def main():
for dataset_name in args.datasets:
output_dirname = os.path.join(output_root, dataset_name)
os.makedirs(output_dirname, exist_ok=True)
dataset = get_dataset(name=dataset_name, max_dataset_size=args.max_dataset_size)
if args.chunk_step > 1:
dataset = dataset.select(range(args.chunk_start, len(dataset), args.chunk_step))
for row in tqdm(dataset):
...
...@@ -13,6 +13,12 @@ def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("input_roots", type=str, nargs="*")
parser.add_argument("-o", "--output-path", type=str, default="metrics.json", help="Image output path")
parser.add_argument(
"--max-dataset-size",
type=int,
default=1024,
help="Maximum number of images to compute metrics for each dataset",
)
args = parser.parse_args()
return args
...@@ -35,7 +41,7 @@ def main():
continue
print("Results for dataset:", dataset_name)
results[dataset_name] = {}
dataset = get_dataset(name=dataset_name, return_gt=True, max_dataset_size=args.max_dataset_size)
fid = compute_fid(ref_dirpath_or_dataset=dataset, gen_dirpath=os.path.join(image_root1, dataset_name))
results[dataset_name]["fid"] = fid
print("FID:", fid)
...
...@@ -54,7 +54,7 @@ def main():
prompt=dummy_prompt, num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale
)
torch.cuda.synchronize()
for _ in trange(args.test_times, desc="Test", position=0, leave=False):
start_time = time.time()
pipeline(
prompt=dummy_prompt, num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale
...
...@@ -4,6 +4,7 @@ from peft.tuners import lora
from vars import LORA_PATHS, SVDQ_LORA_PATHS
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.models.transformers.transformer_flux_v2 import NunchakuFluxTransformer2DModelV2
def hash_str_to_int(s: str) -> int:
...@@ -49,6 +50,16 @@ def get_pipeline(
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, **pipeline_init_kwargs
)
elif model_name == "schnell_v2":
transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
transformer=transformer,
torch_dtype=torch.bfloat16,
**pipeline_init_kwargs,
)
elif model_name == "dev": elif model_name == "dev":
if precision == "int4": if precision == "int4":
transformer = NunchakuFluxTransformer2dModel.from_pretrained( transformer = NunchakuFluxTransformer2dModel.from_pretrained(
...@@ -93,6 +104,9 @@ def get_pipeline( ...@@ -93,6 +104,9 @@ def get_pipeline(
m.scaling[name] = lora_weight m.scaling[name] = lora_weight
else: else:
raise NotImplementedError(f"Model {model_name} not implemented") raise NotImplementedError(f"Model {model_name} not implemented")
if precision == "bf16":
pipeline.enable_model_cpu_offload()
else:
pipeline = pipeline.to(device)
return pipeline
import torch
from nunchaku.models.transformers.transformer_qwenimage import NunchakuQwenImageTransformer2DModel
from nunchaku.pipeline.pipeline_qwenimage import NunchakuQwenImagePipeline
from nunchaku.utils import get_precision
model_name = "Qwen/Qwen-Image"
# Load the model
transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(
f"nunchaku-tech/nunchaku-qwen-image/svdq-{get_precision()}_r32-qwen-image.safetensors"
)  # you can also use the r128 model to improve quality
# currently, you need to use this pipeline to offload the model to the CPU
pipe = NunchakuQwenImagePipeline.from_pretrained("Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16)
positive_magic = {
"en": "Ultra HD, 4K, cinematic composition.", # for english prompt,
"zh": "超清,4K,电影级构图", # for chinese prompt,
}
# Generate image
prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition"""
negative_prompt = " " # using an empty string if you do not have specific concept to remove
image = pipe(
prompt=prompt + positive_magic["en"],
negative_prompt=negative_prompt,
width=1328,
height=1328,
num_inference_steps=50,
true_cfg_scale=4.0,
generator=torch.Generator().manual_seed(2333),
).images[0]
image.save("qwen-image.png")
...@@ -79,6 +79,36 @@ void gemm_w4a4(std::optional<torch::Tensor> act, // packed act [M, K
// Tensor::synchronizeDevice();
}
void quantize_w4a4_act_fuse_lora(std::optional<torch::Tensor> input,
std::optional<torch::Tensor> output,
std::optional<torch::Tensor> oscales,
std::optional<torch::Tensor> lora_down,
std::optional<torch::Tensor> lora_act_out,
std::optional<torch::Tensor> smooth,
bool fuse_glu,
bool fp4) {
spdlog::trace("running quantize_w4a4_act_fuse_lora: ");
auto getTensor = [](std::optional<torch::Tensor> &t) {
Tensor ret = t.has_value() ? from_torch(t.value()) : Tensor{};
if (ret.valid()) {
spdlog::trace(" {}", ret.shape.str());
} else {
spdlog::trace(" <invalid>");
}
return ret;
};
nunchaku::kernels::quantize_w4a4_act_fuse_lora(getTensor(input),
getTensor(output),
getTensor(oscales),
getTensor(lora_down),
getTensor(lora_act_out),
getTensor(smooth),
fuse_glu,
fp4);
}
void attention_fp16(torch::Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
torch::Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
torch::Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
...
...@@ -107,6 +107,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def_submodule("ops")
.def("gemm_w4a4", nunchaku::ops::gemm_w4a4)
.def("quantize_w4a4_act_fuse_lora", nunchaku::ops::quantize_w4a4_act_fuse_lora)
.def("attention_fp16", nunchaku::ops::attention_fp16) .def("attention_fp16", nunchaku::ops::attention_fp16)
.def("gemm_awq", nunchaku::ops::gemm_awq) .def("gemm_awq", nunchaku::ops::gemm_awq)
.def("gemv_awq", nunchaku::ops::gemv_awq) .def("gemv_awq", nunchaku::ops::gemv_awq)
......
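With the binding above registered on the `ops` submodule, the new quantization kernel becomes reachable from Python through the compiled extension, the same way the attention processors later in this diff import `attention_fp16`. A minimal sketch (argument order follows the C++ signature; the tensors must already use the packed layouts the kernel expects, so callers normally go through the higher-level wrapper in `nunchaku.ops.quantize` instead):

```python
# Sketch only: checks that the new op is exposed next to the existing ones.
from nunchaku._C.ops import attention_fp16, gemm_w4a4, quantize_w4a4_act_fuse_lora

# quantize_w4a4_act_fuse_lora(input, output, oscales, lora_down,
#                             lora_act_out, smooth, fuse_glu, fp4)
# All tensor arguments are optional on the C++ side; in this codebase the call
# is wrapped by svdq_quantize_w4a4_act_fuse_lora_cuda (see SVDQW4A4Linear.quantize below).
```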
import torch
from diffusers.models.activations import GELU
from diffusers.models.attention import FeedForward
from torch import nn
from ..ops.fused import fused_gelu_mlp
from .linear import SVDQW4A4Linear
class NunchakuBaseAttention(nn.Module):
def __init__(self, processor: str = "flashattn2", *args, **kwargs):
super(NunchakuBaseAttention, self).__init__()
self.processor = None
self.set_processor(processor)
def set_processor(self, processor: str):
raise NotImplementedError("Subclass must implement this method")
def _patch_linear(module: nn.Module, linear_cls, **kwargs) -> nn.Module:
for name, child in module.named_children():
if isinstance(child, nn.Linear):
setattr(module, name, linear_cls.from_linear(child, **kwargs))
else:
_patch_linear(child, linear_cls, **kwargs)
return module
class NunchakuFeedForward(FeedForward):
def __init__(self, ff: FeedForward, **kwargs):
super(FeedForward, self).__init__()
self.net = _patch_linear(ff.net, SVDQW4A4Linear, **kwargs)
# for int4, we shift the activation of mlp_fc2 to make it unsigned
self.net[2].act_unsigned = self.net[2].precision != "nvfp4"
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if isinstance(self.net[0], GELU):
return fused_gelu_mlp(hidden_states, self.net[0].proj, self.net[2])
else:
# fallback to original implementation
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
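To illustrate how these wrappers are meant to be used, here is a minimal sketch that converts a stock diffusers `FeedForward` into a `NunchakuFeedForward`; the `rank`/`precision` keyword arguments are forwarded to `SVDQW4A4Linear.from_linear`. The import path is assumed from this diff, and the quantized parameters stay uninitialized until a checkpoint is loaded, so the forward call is left commented out:

```python
import torch
from diffusers.models.attention import FeedForward

from nunchaku.models.attention import NunchakuFeedForward  # module path assumed from this diff

# A bf16 FeedForward like the ones inside the FLUX / Qwen-Image blocks.
ff = FeedForward(dim=3072, dim_out=3072, activation_fn="gelu-approximate").to(torch.bfloat16)

# Every nn.Linear inside ff.net (including GELU.proj) is replaced by an SVDQW4A4Linear.
quant_ff = NunchakuFeedForward(ff, rank=32, precision="int4")
print(quant_ff.net[0].proj)  # SVDQW4A4Linear(..., rank=32, precision=int4, ...)
print(quant_ff.net[2])       # output projection; act_unsigned=True for int4

# forward() dispatches to fused_gelu_mlp once real quantized weights are loaded:
# x = torch.randn(1, 4096, 3072, dtype=torch.bfloat16, device="cuda")
# out = quant_ff(x)
```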
import math
from typing import Optional, Tuple
import torch
from torch.nn import functional as F
from ..._C.ops import attention_fp16
from ...ops.fused import fused_qkv_norm_rottary
class NunchakuFluxFA2Processor:
def __call__(
self,
attn,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor = None,
**kwargs,
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
# Adapted from https://github.com/huggingface/diffusers/blob/50dea89dc6036e71a00bc3d57ac062a80206d9eb/src/diffusers/models/attention_processor.py#L2275
if attention_mask is not None:
raise NotImplementedError("attention_mask is not supported")
batch_size, _, channels = hidden_states.shape
assert channels == attn.heads * attn.head_dim
qkv = fused_qkv_norm_rottary(
hidden_states,
attn.to_qkv,
attn.norm_q,
attn.norm_k,
image_rotary_emb[0] if isinstance(image_rotary_emb, tuple) else image_rotary_emb,
)
if attn.added_kv_proj_dim is not None:
assert encoder_hidden_states is not None
assert isinstance(image_rotary_emb, tuple)
qkv_context = fused_qkv_norm_rottary(
encoder_hidden_states, attn.add_qkv_proj, attn.norm_added_q, attn.norm_added_k, image_rotary_emb[1]
)
qkv = torch.cat([qkv_context, qkv], dim=1)
query, key, value = qkv.chunk(3, dim=-1)
query = query.view(batch_size, -1, attn.heads, attn.head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, attn.head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, attn.head_dim).transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * attn.head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
# for single transformer block, we split the proj_out into two linear layers
hidden_states = attn.to_out(hidden_states)
return hidden_states
class NunchakuFluxFP16AttnProcessor:
def __init__(self, pad_size: int = 256):
self.pad_size = pad_size
def __call__(
self,
attn,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor = None,
**kwargs,
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
pad_size = self.pad_size
batch_size, _, channels = hidden_states.shape
assert channels == attn.heads * attn.head_dim
if encoder_hidden_states is None:
# single transformer block
assert attn.added_kv_proj_dim is None
num_tokens = hidden_states.shape[1]
num_tokens_pad = math.ceil(num_tokens / pad_size) * pad_size
query = torch.empty(
batch_size,
attn.heads,
num_tokens_pad,
attn.head_dim,
dtype=torch.float16,
device=hidden_states.device,
)
key = torch.empty_like(query)
value = torch.empty_like(query)
assert torch.is_tensor(image_rotary_emb)
fused_qkv_norm_rottary(
hidden_states,
attn.to_qkv,
attn.norm_q,
attn.norm_k,
image_rotary_emb,
output=(query, key, value),
attn_tokens=num_tokens,
)
else:
# joint transformer block
assert attn.added_kv_proj_dim is not None
num_txt_tokens = encoder_hidden_states.shape[1]
num_img_tokens = hidden_states.shape[1]
num_txt_tokens_pad = math.ceil(num_txt_tokens / pad_size) * pad_size
num_img_tokens_pad = math.ceil(num_img_tokens / pad_size) * pad_size
num_tokens_pad = num_txt_tokens_pad + num_img_tokens_pad
query = torch.empty(
batch_size,
attn.heads,
num_tokens_pad,
attn.head_dim,
dtype=torch.float16,
device=hidden_states.device,
)
key = torch.empty_like(query)
value = torch.empty_like(query)
assert isinstance(image_rotary_emb, tuple)
fused_qkv_norm_rottary(
hidden_states,
attn.to_qkv,
attn.norm_q,
attn.norm_k,
image_rotary_emb[0],
output=(
query[:, :, num_txt_tokens_pad:],
key[:, :, num_txt_tokens_pad:],
value[:, :, num_txt_tokens_pad:],
),
attn_tokens=num_img_tokens,
)
fused_qkv_norm_rottary(
encoder_hidden_states,
attn.add_qkv_proj,
attn.norm_added_q,
attn.norm_added_k,
image_rotary_emb[1],
output=(
query[:, :, :num_txt_tokens_pad],
key[:, :, :num_txt_tokens_pad],
value[:, :, :num_txt_tokens_pad],
),
attn_tokens=num_txt_tokens,
)
attention_output = torch.empty(
batch_size,
num_tokens_pad,
attn.heads * attn.head_dim,
dtype=hidden_states.dtype,
device=hidden_states.device,
)
attention_fp16(query, key, value, attention_output, attn.head_dim ** (-0.5))
hidden_states = attention_output
if encoder_hidden_states is None:
# for single transformer block, we split the proj_out into two linear layers
hidden_states = hidden_states[:, :num_tokens]
hidden_states = attn.to_out(hidden_states)
return hidden_states
else:
encoder_hidden_states, hidden_states = (
hidden_states[:, :num_txt_tokens],
hidden_states[:, num_txt_tokens_pad : num_txt_tokens_pad + num_img_tokens],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
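The FP16 processor above pads both streams to multiples of `pad_size`, packs them as `[text | image]` along the token axis, runs `attention_fp16` on the padded buffer, and then slices the valid rows back out. A small sketch of that index arithmetic, with made-up token counts:

```python
import math

pad_size = 256
num_txt_tokens, num_img_tokens = 512, 4150  # illustrative lengths, not from the diff

num_txt_tokens_pad = math.ceil(num_txt_tokens / pad_size) * pad_size  # 512
num_img_tokens_pad = math.ceil(num_img_tokens / pad_size) * pad_size  # 4352
num_tokens_pad = num_txt_tokens_pad + num_img_tokens_pad              # 4864

# Token-axis layout of the padded query/key/value buffers: [text (padded) | image (padded)]
txt_rows = slice(0, num_txt_tokens_pad)
img_rows = slice(num_txt_tokens_pad, num_tokens_pad)

# After attention, only the valid (unpadded) rows are returned:
encoder_out = slice(0, num_txt_tokens)
hidden_out = slice(num_txt_tokens_pad, num_txt_tokens_pad + num_img_tokens)
```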
from typing import Optional, Tuple
import torch
from diffusers.models.attention_dispatch import dispatch_attention_fn
from diffusers.models.transformers.transformer_qwenimage import apply_rotary_emb_qwen
class NunchakuQwenImageNaiveFA2Processor:
def __call__(
self,
attn,
hidden_states: torch.FloatTensor, # Image stream
encoder_hidden_states: torch.FloatTensor = None, # Text stream
encoder_hidden_states_mask: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
# Adapted from https://github.com/huggingface/diffusers/blob/baa9b582f348e52aa2fc245e366611f454e1082b/src/diffusers/models/transformers/transformer_qwenimage.py#L246
if encoder_hidden_states is None:
raise ValueError("NunchakuQwenImageFA2Processor requires encoder_hidden_states (text stream)")
seq_txt = encoder_hidden_states.shape[1]
# TODO: fuse the QKV, norm and RoPE in a single kernel to boost the performance
# Compute QKV for image stream (sample projections)
img_qkv = attn.to_qkv(hidden_states)
img_query, img_key, img_value = img_qkv.chunk(3, dim=-1)
# Compute QKV for text stream (context projections)
txt_qkv = attn.add_qkv_proj(encoder_hidden_states)
txt_query, txt_key, txt_value = txt_qkv.chunk(3, dim=-1)
# Reshape for multi-head attention
img_query = img_query.unflatten(-1, (attn.heads, -1)) # [B, L, H, D]
img_key = img_key.unflatten(-1, (attn.heads, -1))
img_value = img_value.unflatten(-1, (attn.heads, -1))
txt_query = txt_query.unflatten(-1, (attn.heads, -1))
txt_key = txt_key.unflatten(-1, (attn.heads, -1))
txt_value = txt_value.unflatten(-1, (attn.heads, -1))
# Apply QK normalization
assert attn.norm_q is not None
img_query = attn.norm_q(img_query)
assert attn.norm_k is not None
img_key = attn.norm_k(img_key)
assert attn.norm_added_q is not None
txt_query = attn.norm_added_q(txt_query)
assert attn.norm_added_k is not None
txt_key = attn.norm_added_k(txt_key)
# Apply RoPE
if image_rotary_emb is not None:
img_freqs, txt_freqs = image_rotary_emb
img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
# Concatenate for joint attention
# Order: [text, image]
joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
# Compute joint attention
joint_hidden_states = dispatch_attention_fn(
joint_query,
joint_key,
joint_value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=None,
)
# joint_query = joint_query.transpose(1, 2)
# joint_key = joint_key.transpose(1, 2)
# joint_value = joint_value.transpose(1, 2)
# joint_hidden_states = F.scaled_dot_product_attention(
# joint_query, joint_key, joint_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
# )
# joint_hidden_states = joint_hidden_states.transpose(1, 2)
# Reshape back
joint_hidden_states = joint_hidden_states.flatten(2, 3)
joint_hidden_states = joint_hidden_states.to(joint_query.dtype)
# Split attention outputs back
txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part
img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
# Apply output projections
img_attn_output = attn.to_out[0](img_attn_output)
if len(attn.to_out) > 1:
img_attn_output = attn.to_out[1](img_attn_output) # dropout
txt_attn_output = attn.to_add_out(txt_attn_output)
return img_attn_output, txt_attn_output
import diffusers
import torch
from packaging.version import Version
from torch import nn
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
"""
Rotary positional embedding function.
Copied from https://github.com/huggingface/diffusers/blob/c9ff360966327ace3faad3807dc871a4e5447501/src/diffusers/models/transformers/transformer_flux.py#L38
Parameters
----------
pos : torch.Tensor
Position tensor of shape (..., n).
dim : int
Embedding dimension (must be even).
theta : int
Rotary base.
Returns
-------
torch.Tensor
Rotary embedding tensor.
"""
assert dim % 2 == 0, "The dimension must be even."
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
batch_size, seq_length = pos.shape
out = torch.einsum("...n,d->...nd", pos, omega)
USE_SINCOS = True
if USE_SINCOS:
cos_out = torch.cos(out)
sin_out = torch.sin(out)
stacked_out = torch.stack([sin_out, cos_out], dim=-1)
out = stacked_out.view(batch_size, -1, dim // 2, 1, 2)
else:
out = out.view(batch_size, -1, dim // 2, 1, 1)
return out.float()
class NunchakuFluxPosEmbed(nn.Module):
"""
Multi-dimensional rotary embedding module.
Adapted from https://github.com/huggingface/diffusers/blob/c9ff360966327ace3faad3807dc871a4e5447501/src/diffusers/models/transformers/transformer_flux.py#L55
Parameters
----------
dim : int
Embedding dimension.
theta : int
Rotary base.
axes_dim : list[int]
List of axis dimensions for each spatial axis.
"""
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
super(NunchakuFluxPosEmbed, self).__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
"""
Computes rotary embeddings for multi-dimensional positions.
Parameters
----------
ids : torch.Tensor
Position indices tensor of shape (..., n_axes).
Returns
-------
torch.Tensor
Rotary embedding tensor.
"""
if Version(diffusers.__version__) >= Version("0.31.0"):
ids = ids[None, ...]
n_axes = ids.shape[-1]
emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
return emb.unsqueeze(1)
def pack_rotemb(rotemb: torch.Tensor) -> torch.Tensor:
"""
Packs rotary embeddings for efficient computation.
Parameters
----------
rotemb : torch.Tensor
Rotary embedding tensor of shape (B, M, D//2, 1, 2), dtype float32.
Returns
-------
torch.Tensor
Packed rotary embedding tensor of shape (B, M, D).
"""
assert rotemb.dtype == torch.float32
B = rotemb.shape[0]
M = rotemb.shape[1]
D = rotemb.shape[2] * 2
assert rotemb.shape == (B, M, D // 2, 1, 2)
assert M % 16 == 0
assert D % 8 == 0
rotemb = rotemb.reshape(B, M // 16, 16, D // 8, 8)
rotemb = rotemb.permute(0, 1, 3, 2, 4)
# 16*8 pack, FP32 accumulator (C) format
# https://docs.nvidia.com/cuda/parallel-thread-execution/#mma-16816-c
##########################################|--M--|--D--|
##########################################|-3--4--5--6|
########################################## : : : :
rotemb = rotemb.reshape(*rotemb.shape[0:3], 2, 8, 4, 2)
rotemb = rotemb.permute(0, 1, 2, 4, 5, 3, 6)
rotemb = rotemb.contiguous()
rotemb = rotemb.view(B, M, D)
return rotemb
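A quick sketch of the shapes these helpers produce, using made-up positions and an illustrative `dim` (not the FLUX defaults); it assumes `rope` and `pack_rotemb` from the module above are in scope:

```python
import torch

# rope(): per-axis rotary table of shape (B, n, dim//2, 1, 2), float32.
pos = torch.arange(16, dtype=torch.float64).unsqueeze(0)  # (B=1, n=16) positions
table = rope(pos, dim=8, theta=10000)
print(table.shape)  # torch.Size([1, 16, 4, 1, 2])

# pack_rotemb(): reorders a (B, M, D//2, 1, 2) float32 table into the (B, M, D)
# 16x8 MMA tile layout consumed by the CUDA kernels; requires M % 16 == 0 and D % 8 == 0.
packed = pack_rotemb(table)
print(packed.shape)  # torch.Size([1, 16, 8])
```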
import torch
from torch import nn
from ..ops.gemm import svdq_gemm_w4a4_cuda
from ..ops.gemv import awq_gemv_w4a16_cuda
from ..ops.quantize import svdq_quantize_w4a4_act_fuse_lora_cuda
class SVDQW4A4Linear(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
rank: int = 32,
bias: bool = True,
precision: str = "int4",
torch_dtype: torch.dtype = torch.bfloat16,
device: str | torch.device = "cpu",
):
super(SVDQW4A4Linear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.rank = rank
self.precision = precision
self.torch_dtype = torch_dtype
self.group_size = None
if precision == "nvfp4":
self.group_size = 16
elif precision == "int4":
self.group_size = 64
else:
raise ValueError(f"Invalid precision: {precision}")
self.qweight = nn.Parameter(
torch.empty(out_features, in_features // 2, dtype=torch.int8, device=device), requires_grad=False
)
self.bias = (
nn.Parameter(torch.empty(out_features, dtype=torch_dtype, device=device), requires_grad=True)
if bias
else None
)
self.wscales = nn.Parameter(
torch.empty(
in_features // self.group_size,
out_features,
dtype=torch_dtype if precision == "int4" else torch.float8_e4m3fn,
device=device,
),
requires_grad=False,
)
self.smooth_factor = nn.Parameter(
torch.empty(in_features, dtype=torch_dtype, device=device), requires_grad=False
)
self.smooth_factor_orig = nn.Parameter(
torch.empty(in_features, dtype=torch_dtype, device=device), requires_grad=False
)
self.proj_down = nn.Parameter(torch.empty(in_features, rank, dtype=torch_dtype, device=device))
self.proj_up = nn.Parameter(torch.empty(out_features, rank, dtype=torch_dtype, device=device))
self.wtscale = None
self.wcscales = None
if precision == "nvfp4":
self.wtscale = nn.Parameter(torch.ones(1, dtype=torch_dtype, device=device), requires_grad=False)
self.wcscales = nn.Parameter(
torch.ones(out_features, dtype=torch_dtype, device=device), requires_grad=False
)
self.act_unsigned = False
@classmethod
def from_linear(cls, linear: nn.Linear, **kwargs):
in_features = kwargs.pop("in_features", linear.in_features)
return cls(
in_features=in_features,
out_features=linear.out_features,
bias=linear.bias is not None,
torch_dtype=linear.weight.dtype,
device=linear.weight.device,
**kwargs,
)
def forward(self, x: torch.Tensor, output: torch.Tensor | None = None) -> torch.Tensor:
# quantize the input and run the LoRA down projection
batch_size, seq_len, channels = x.shape
x = x.view(batch_size * seq_len, channels)
if output is None:
output = torch.empty(batch_size * seq_len, self.out_features, dtype=x.dtype, device=x.device)
quantized_x, ascales, lora_act_out = self.quantize(x)
output = self.forward_quant(quantized_x, ascales, lora_act_out, output)
output = output.view(batch_size, seq_len, -1)
return output
def quantize(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
quantized_x, ascales, lora_act_out = svdq_quantize_w4a4_act_fuse_lora_cuda(
x, lora_down=self.proj_down, smooth=self.smooth_factor, fp4=self.precision == "nvfp4"
)
return quantized_x, ascales, lora_act_out
def forward_quant(
self,
quantized_x: torch.Tensor,
ascales: torch.Tensor,
lora_act: torch.Tensor,
output: torch.Tensor | None = None,
) -> torch.Tensor:
if output is None:
output = torch.empty(
quantized_x.shape[0], self.out_features, dtype=self.proj_up.dtype, device=quantized_x.device
)
svdq_gemm_w4a4_cuda(
act=quantized_x,
wgt=self.qweight,
out=output,
ascales=ascales,
wscales=self.wscales,
lora_act_in=lora_act,
lora_up=self.proj_up,
bias=self.bias,
fp4=self.precision == "nvfp4",
alpha=self.wtscale,
wcscales=self.wcscales,
act_unsigned=self.act_unsigned,
)
return output
def __repr__(self):
return f"SVDQW4A4Linear(in_features={self.in_features}, out_features={self.out_features}, rank={self.rank}, precision={self.precision}, act_unsigned={self.act_unsigned})"
class AWQW4A16Linear(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
group_size: int = 64,
torch_dtype: torch.dtype = torch.bfloat16,
device: str | torch.device = "cuda",
):
super(AWQW4A16Linear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size
self.qweight = nn.Parameter(
torch.empty(out_features // 4, in_features // 2, dtype=torch.int32, device=device), requires_grad=False
)
self.bias = (
nn.Parameter(torch.empty(out_features, dtype=torch_dtype, device=device), requires_grad=True)
if bias
else None
)
self.wscales = nn.Parameter(
torch.empty(in_features // self.group_size, out_features, dtype=torch_dtype, device=device),
requires_grad=False,
)
self.wzeros = nn.Parameter(
torch.empty(in_features // self.group_size, out_features, dtype=torch_dtype, device=device),
requires_grad=False,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
output = awq_gemv_w4a16_cuda(
in_feats=x,
kernel=self.qweight,
scaling_factors=self.wscales,
zeros=self.wzeros,
m=x.shape[0],
n=self.out_features,
k=self.in_features,
group_size=self.group_size,
)
if self.bias is not None:
view_shape = [1] * (output.ndim - 1) + [-1]
output.add_(self.bias.view(view_shape))
return output
@classmethod
def from_linear(
cls,
linear: nn.Linear,
group_size: int = 64,
torch_dtype: torch.dtype = torch.bfloat16,
device: str = "cpu",
**kwargs,
):
return cls(
in_features=linear.in_features,
out_features=linear.out_features,
bias=linear.bias is not None,
group_size=group_size,
torch_dtype=torch_dtype,
device=device,
)
def __repr__(self):
return f"AWQW4A16Linear(in_features={self.in_features}, out_features={self.out_features}, group_size={self.group_size})"
from typing import Optional, Tuple
import torch
from diffusers.models.normalization import AdaLayerNormZero, AdaLayerNormZeroSingle
from .linear import AWQW4A16Linear
class NunchakuAdaLayerNormZero(AdaLayerNormZero):
def __init__(self, other: AdaLayerNormZero, scale_shift: float = 1.0):
super(AdaLayerNormZero, self).__init__()
self.scale_shift = scale_shift
self.emb = other.emb
self.silu = other.silu
self.linear = AWQW4A16Linear.from_linear(other.linear)
self.norm = other.norm
def forward(
self,
x: torch.Tensor,
timestep: Optional[torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
hidden_dtype: Optional[torch.dtype] = None,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
if self.emb is not None:
emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
emb = self.linear(self.silu(emb))
# The AWQ linear's output interleaves the modulation signals per channel, so reshape and permute instead of chunk to separate them.
emb = emb.view(emb.shape[0], -1, 6).permute(2, 0, 1)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb
norm_x = self.norm(x)
if self.scale_shift != 0:
scale_msa.add_(self.scale_shift)
scale_mlp.add_(self.scale_shift)
norm_x_scaled = norm_x * scale_msa[:, None] + shift_msa[:, None]
return norm_x_scaled, gate_msa, shift_mlp, scale_mlp, gate_mlp
class NunchakuAdaLayerNormZeroSingle(AdaLayerNormZeroSingle):
def __init__(self, other: AdaLayerNormZeroSingle, scale_shift: float = 1.0):
super(AdaLayerNormZeroSingle, self).__init__()
self.scale_shift = scale_shift
self.silu = other.silu
self.linear = AWQW4A16Linear.from_linear(other.linear)
self.norm = other.norm
def forward(
self,
x: torch.Tensor,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
# The AWQ linear's output interleaves the modulation signals per channel, so reshape and permute instead of chunk to separate them.
emb = emb.view(emb.shape[0], -1, 3).permute(2, 0, 1)
shift_msa, scale_msa, gate_msa = emb
if self.scale_shift != 0:
scale_msa.add_(self.scale_shift)
norm_x = self.norm(x)
norm_x_scaled = norm_x * scale_msa[:, None] + shift_msa[:, None]
return norm_x_scaled, gate_msa
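A tiny numeric sketch of the layout handling above, with `dim=3` for readability: the AWQ linear emits the modulation signals interleaved per channel, so the `view` + `permute` recovers what diffusers would otherwise get from a plain `chunk`:

```python
import torch

dim = 3
# Flat output for one sample: channel c stores its (shift, scale, gate) at 3*c .. 3*c+2.
emb = torch.arange(dim * 3).view(1, dim * 3)  # [[0, 1, 2, 3, 4, 5, 6, 7, 8]]

shift_msa, scale_msa, gate_msa = emb.view(emb.shape[0], -1, 3).permute(2, 0, 1)
print(shift_msa)  # tensor([[0, 3, 6]]) -- signal 0 gathered across channels
print(scale_msa)  # tensor([[1, 4, 7]])
print(gate_msa)   # tensor([[2, 5, 8]])

# A plain emb.chunk(3, dim=1) would instead give [0, 1, 2], [3, 4, 5], [6, 7, 8],
# which is the blocked layout diffusers' AdaLayerNormZeroSingle expects.
```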
...@@ -30,7 +30,6 @@ try:
import xformers.ops as xops
except ImportError:
xops = None
print("Please 'pip install xformers'")
class LayerNorm(nn.LayerNorm):
...
import json
import os
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import torch
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_qwenimage import (
QwenEmbedRope,
QwenImageTransformer2DModel,
QwenImageTransformerBlock,
)
from huggingface_hub import utils
from ...utils import get_precision
from ..attention import NunchakuBaseAttention, NunchakuFeedForward
from ..attention_processors.qwenimage import NunchakuQwenImageNaiveFA2Processor
from ..linear import AWQW4A16Linear, SVDQW4A4Linear
from ..utils import fuse_linears
from .utils import NunchakuModelLoaderMixin
class NunchakuQwenAttention(NunchakuBaseAttention):
def __init__(self, other: Attention, processor: str = "flashattn2", **kwargs):
super(NunchakuQwenAttention, self).__init__(processor)
self.inner_dim = other.inner_dim
self.inner_kv_dim = other.inner_kv_dim
self.query_dim = other.query_dim
self.use_bias = other.use_bias
self.is_cross_attention = other.is_cross_attention
self.cross_attention_dim = other.cross_attention_dim
self.upcast_attention = other.upcast_attention
self.upcast_softmax = other.upcast_softmax
self.rescale_output_factor = other.rescale_output_factor
self.residual_connection = other.residual_connection
self.dropout = other.dropout
self.fused_projections = other.fused_projections
self.out_dim = other.out_dim
self.out_context_dim = other.out_context_dim
self.context_pre_only = other.context_pre_only
self.pre_only = other.pre_only
self.is_causal = other.is_causal
self.scale_qk = other.scale_qk
self.scale = other.scale
self.heads = other.heads
self.sliceable_head_dim = other.sliceable_head_dim
self.added_kv_proj_dim = other.added_kv_proj_dim
self.only_cross_attention = other.only_cross_attention
self.group_norm = other.group_norm
self.spatial_norm = other.spatial_norm
self.norm_cross = other.norm_cross
self.norm_q = other.norm_q
self.norm_k = other.norm_k
self.norm_added_q = other.norm_added_q
self.norm_added_k = other.norm_added_k
# fuse the qkv
with torch.device("meta"):
to_qkv = fuse_linears([other.to_q, other.to_k, other.to_v])
self.to_qkv = SVDQW4A4Linear.from_linear(to_qkv, **kwargs)
self.to_out = other.to_out
self.to_out[0] = SVDQW4A4Linear.from_linear(self.to_out[0], **kwargs)
assert self.added_kv_proj_dim is not None
# fuse the add_qkv
with torch.device("meta"):
add_qkv_proj = fuse_linears([other.add_q_proj, other.add_k_proj, other.add_v_proj])
self.add_qkv_proj = SVDQW4A4Linear.from_linear(add_qkv_proj, **kwargs)
self.to_add_out = SVDQW4A4Linear.from_linear(other.to_add_out, **kwargs)
def forward(
self,
hidden_states: torch.FloatTensor, # Image stream
encoder_hidden_states: torch.FloatTensor = None, # Text stream
encoder_hidden_states_mask: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
):
return self.processor(
self,
hidden_states,
encoder_hidden_states,
encoder_hidden_states_mask,
attention_mask,
image_rotary_emb,
**kwargs,
)
def set_processor(self, processor: str):
if processor == "flashattn2":
self.processor = NunchakuQwenImageNaiveFA2Processor()
else:
raise ValueError(f"Processor {processor} is not supported")
class NunchakuQwenImageTransformerBlock(QwenImageTransformerBlock):
def __init__(self, other: QwenImageTransformerBlock, scale_shift: float = 1.0, **kwargs):
super(QwenImageTransformerBlock, self).__init__()
self.dim = other.dim
self.img_mod = other.img_mod
self.img_mod[1] = AWQW4A16Linear.from_linear(other.img_mod[1], **kwargs)
self.img_norm1 = other.img_norm1
self.attn = NunchakuQwenAttention(other.attn, **kwargs)
self.img_norm2 = other.img_norm2
self.img_mlp = NunchakuFeedForward(other.img_mlp, **kwargs)
# Text processing modules
self.txt_mod = other.txt_mod
self.txt_mod[1] = AWQW4A16Linear.from_linear(other.txt_mod[1], **kwargs)
self.txt_norm1 = other.txt_norm1
# Text doesn't need separate attention - it's handled by img_attn joint computation
self.txt_norm2 = other.txt_norm2
self.txt_mlp = NunchakuFeedForward(other.txt_mlp, **kwargs)
self.scale_shift = scale_shift
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply modulation to input tensor"""
shift, scale, gate = mod_params.chunk(3, dim=-1)
if self.scale_shift != 0:
scale.add_(self.scale_shift)
return x * scale.unsqueeze(1) + shift.unsqueeze(1), gate.unsqueeze(1)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Get modulation parameters for both streams
img_mod_params = self.img_mod(temb) # [B, 6*dim]
txt_mod_params = self.txt_mod(temb) # [B, 6*dim]
# nunchaku's mod_params are interleaved per channel ([B, dim, 6] flattened) rather than stored in contiguous blocks ([B, 6, dim] flattened), so rearrange before chunking
img_mod_params = (
img_mod_params.view(img_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(img_mod_params.shape[0], -1)
)
txt_mod_params = (
txt_mod_params.view(txt_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(txt_mod_params.shape[0], -1)
)
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
# Split modulation parameters for norm1 and norm2
# img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
# txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
# Process image stream - norm1 + modulation
img_normed = self.img_norm1(hidden_states)
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
# Process text stream - norm1 + modulation
txt_normed = self.txt_norm1(encoder_hidden_states)
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
# Use QwenAttnProcessor2_0 for joint attention computation
# This directly implements the DoubleStreamLayerMegatron logic:
# 1. Computes QKV for both streams
# 2. Applies QK normalization and RoPE
# 3. Concatenates and runs joint attention
# 4. Splits results back to separate streams
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = self.attn(
hidden_states=img_modulated, # Image stream (will be processed as "sample")
encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context")
encoder_hidden_states_mask=encoder_hidden_states_mask,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
# QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
img_attn_output, txt_attn_output = attn_output
# Apply attention gates and add residual (like in Megatron)
hidden_states = hidden_states + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
# Process image stream - norm2 + MLP
img_normed2 = self.img_norm2(hidden_states)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
img_mlp_output = self.img_mlp(img_modulated2)
hidden_states = hidden_states + img_gate2 * img_mlp_output
# Process text stream - norm2 + MLP
txt_normed2 = self.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
txt_mlp_output = self.txt_mlp(txt_modulated2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
# Clip to prevent overflow for fp16
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return encoder_hidden_states, hidden_states
class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuModelLoaderMixin):
def _patch_model(self, **kwargs):
for i, block in enumerate(self.transformer_blocks):
self.transformer_blocks[i] = NunchakuQwenImageTransformerBlock(block, scale_shift=0, **kwargs)
return self
@classmethod
@utils.validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
device = kwargs.get("device", "cpu")
offload = kwargs.get("offload", False)
if offload:
raise NotImplementedError("Offload is not supported for FluxTransformer2DModelV2")
torch_dtype = kwargs.get("torch_dtype", torch.bfloat16)
if isinstance(pretrained_model_name_or_path, str):
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
assert pretrained_model_name_or_path.is_file() or pretrained_model_name_or_path.name.endswith(
(".safetensors", ".sft")
), "Only safetensors are supported"
transformer, model_state_dict, metadata = cls._build_model(pretrained_model_name_or_path, **kwargs)
quantization_config = json.loads(metadata.get("quantization_config", "{}"))
config = json.loads(metadata.get("config", "{}"))
rank = quantization_config.get("rank", 32)
transformer = transformer.to(torch_dtype)
precision = get_precision()
if precision == "fp4":
precision = "nvfp4"
transformer._patch_model(precision=precision, rank=rank)
transformer = transformer.to_empty(device=device)
# need to re-init the pos_embed as to_empty does not work on it
transformer.pos_embed = QwenEmbedRope(
theta=10000, axes_dim=list(config.get("axes_dims_rope", [16, 56, 56])), scale_rope=True
)
state_dict = transformer.state_dict()
for k in state_dict.keys():
if k not in model_state_dict:
assert ".wtscale" in k or ".wcscales" in k
model_state_dict[k] = torch.ones_like(state_dict[k])
else:
assert state_dict[k].dtype == model_state_dict[k].dtype
transformer.load_state_dict(model_state_dict)
return transformer
from torch import nn
def fuse_linears(linears: list[nn.Linear]) -> nn.Linear:
assert len(linears) > 0
if len(linears) == 1:
return linears[0]
else:
assert all(linear.in_features == linears[0].in_features for linear in linears)
out_features = sum(linear.out_features for linear in linears)
bias = all(linear.bias is not None for linear in linears)
return nn.Linear(
linears[0].in_features,
out_features,
bias=bias,
dtype=linears[0].weight.dtype,
device=linears[0].weight.device,
)
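For instance, fusing three 64→64 projections yields a single 64→192 linear. Note that `fuse_linears` only allocates the fused layer (it does not copy the original weights), which is why the callers above build it under `torch.device("meta")` and load the quantized weights afterward. A minimal sketch, assuming `fuse_linears` from the module above is in scope:

```python
import torch
from torch import nn

to_q = nn.Linear(64, 64, dtype=torch.bfloat16)
to_k = nn.Linear(64, 64, dtype=torch.bfloat16)
to_v = nn.Linear(64, 64, dtype=torch.bfloat16)

to_qkv = fuse_linears([to_q, to_k, to_v])
print(to_qkv.in_features, to_qkv.out_features)  # 64 192
print(to_qkv.weight.shape)                      # torch.Size([192, 64]) -- freshly initialized, not copied
```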
import torch
from torch.nn import RMSNorm
from nunchaku.models.linear import SVDQW4A4Linear
from ..utils import ceil_divide
from .gemm import svdq_gemm_w4a4_cuda
def fused_gelu_mlp(x: torch.Tensor, fc1: SVDQW4A4Linear, fc2: SVDQW4A4Linear, pad_size: int = 256):
# fused MLP: the first GEMM fuses fc1, GELU, and re-quantization for fc2; fc2 then consumes the quantized activations
batch_size, seq_len, channels = x.shape
x = x.view(batch_size * seq_len, channels)
quantized_x, ascales, lora_act = fc1.quantize(x)
batch_size_pad = ceil_divide(batch_size * seq_len, pad_size) * pad_size
qout_act = torch.empty(batch_size_pad, fc1.out_features // 2, dtype=torch.uint8, device=x.device)
if fc2.precision == "nvfp4":
qout_ascales = torch.empty(fc1.out_features // 16, batch_size_pad, dtype=torch.float8_e4m3fn, device=x.device)
else:
qout_ascales = torch.empty(fc1.out_features // 64, batch_size_pad, dtype=x.dtype, device=x.device)
qout_lora_act = torch.empty(batch_size_pad, fc2.proj_down.shape[1], dtype=torch.float32, device=x.device)
# for int4, we shift the activation after gelu to make it all positive to improve quality.
# if we pass the qout to this kernel, it will do the gelu fusion.
svdq_gemm_w4a4_cuda(
act=quantized_x,
wgt=fc1.qweight,
qout=qout_act,
ascales=ascales,
wscales=fc1.wscales,
oscales=qout_ascales,
lora_act_in=lora_act,
lora_up=fc1.proj_up,
lora_down=fc2.proj_down,
lora_act_out=qout_lora_act,
bias=fc1.bias,
smooth_factor=fc2.smooth_factor,
fp4=fc1.precision == "nvfp4",
alpha=fc1.wtscale,
wcscales=fc1.wcscales,
)
output = torch.empty(batch_size * seq_len, fc2.out_features, dtype=x.dtype, device=x.device)
output = fc2.forward_quant(qout_act, qout_ascales, qout_lora_act, output=output)
output = output.view(batch_size, seq_len, -1)
return output
def fused_qkv_norm_rottary(
x: torch.Tensor,
proj: SVDQW4A4Linear,
norm_q: RMSNorm,
norm_k: RMSNorm,
rotary_emb: torch.Tensor,
output: torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
attn_tokens: int = 0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert isinstance(norm_q, RMSNorm)
assert isinstance(norm_k, RMSNorm)
batch_size, seq_len, channels = x.shape
x = x.view(batch_size * seq_len, channels)
quantized_x, ascales, lora_act = proj.quantize(x)
if output is None:
output = torch.empty(quantized_x.shape[0], proj.out_features, dtype=x.dtype, device=x.device)
if isinstance(output, tuple):
assert len(output) == 3
output_q, output_k, output_v = output
svdq_gemm_w4a4_cuda(
act=quantized_x,
wgt=proj.qweight,
ascales=ascales,
wscales=proj.wscales,
lora_act_in=lora_act,
lora_up=proj.proj_up,
bias=proj.bias,
fp4=proj.precision == "nvfp4",
alpha=proj.wtscale,
wcscales=proj.wcscales,
norm_q=norm_q.weight,
norm_k=norm_k.weight,
rotary_emb=rotary_emb,
out_q=output_q,
out_k=output_k,
out_v=output_v,
attn_tokens=attn_tokens,
)
return output_q, output_k, output_v
else:
svdq_gemm_w4a4_cuda(
act=quantized_x,
wgt=proj.qweight,
out=output,
ascales=ascales,
wscales=proj.wscales,
lora_act_in=lora_act,
lora_up=proj.proj_up,
bias=proj.bias,
fp4=proj.precision == "nvfp4",
alpha=proj.wtscale,
wcscales=proj.wcscales,
norm_q=norm_q.weight,
norm_k=norm_k.weight,
rotary_emb=rotary_emb,
)
output = output.view(batch_size, seq_len, -1)
return output