"official/nlp/bert/model_training_utils.py" did not exist on "91c681af3ebc30c2de95da89bbe3e181638cade1"
Unverified commit f86ad470 authored by Muyang Li, committed by GitHub

feat: pythonized model and QwenImage Support (#593)

* start refactoring the codebase

* update

* update

* start to implement ops

* add gemm

* write the docstrings

* define the w4a4 svdq linear

* update

* make the linter happy

* finished the SVDQW4A4Linear

* finished the SVDQW4A4Linear

* update

* update

* add a patcher to the model

* update

* add adanormsinglezero

* update

* update

* finished the naive implementation of nunchaku flux

* add ff

* finished the naive forward

* update

* svdq linear

* start debugging

* fix some issues

* successfully built the model

* update

* successfully load the model

* update

* update

* update

* try to making it runnable

* debugging

* debugging

* debugging

* add bias to awq linear

* run through

* fix the normalization

* update

* update

* update

* fix the attention

* fix the no fuse nvfp models

* update

* finished the fused ff

* make linter happy

* make linter happy

* make linter happy

* debugging the fp16 attn

* nunchaku fp16 is buggy

* finish the fp16 attn

* fp4 done

* fix the lora scales

* add a default value for alpha; need to debug int4

* fix input4

* update

* update

* ff does not work

* specialize the processors

* qwen transformer done. start debugging

* make linter happy

* add schnell v2 for metrics eval

* chore: schnellv2 eval

* update

* ff and attention correct

* need to check what happened to module

* fp4 done

* make linter happy

* update an example script

* reformat

* add an example script

* add the announcement

* remove misleading info

* ready to release
parent 954c7af9
"""
Python wrappers for Nunchaku's quantized GEMM operations.
"""
import math
import torch
from .._C import ops
def svdq_gemm_w4a4_cuda(
act: torch.Tensor,
wgt: torch.Tensor,
out: torch.Tensor | None = None,
qout: torch.Tensor | None = None,
ascales: torch.Tensor | None = None,
wscales: torch.Tensor | None = None,
oscales: torch.Tensor | None = None,
poolout: torch.Tensor | None = None,
lora_act_in: torch.Tensor | None = None,
lora_up: torch.Tensor | None = None,
lora_down: torch.Tensor | None = None,
lora_act_out: torch.Tensor | None = None,
norm_q: torch.Tensor | None = None,
norm_k: torch.Tensor | None = None,
rotary_emb: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
smooth_factor: torch.Tensor | None = None,
out_vk: torch.Tensor | None = None,
out_linearattn: torch.Tensor | None = None,
act_unsigned: bool = False,
lora_scales: list[float] | None = None,
fuse_silu: bool = False,
fp4: bool = False,
alpha: float | None = 1.0,
wcscales: torch.Tensor | None = None,
out_q: torch.Tensor | None = None,
out_k: torch.Tensor | None = None,
out_v: torch.Tensor | None = None,
attn_tokens: int = 0,
):
"""
This function wraps the high-performance CUDA kernel for SVDQuant W4A4 quantized GEMM.
Notation
--------
M : int
Batch size (number of input samples).
K : int
Number of input channels (feature dimension).
N : int
Number of output channels.
G : int
Number of groups. 64 for INT4 and 16 for NVFP4.
Parameters
----------
act : torch.Tensor
Input activation tensor. Packed shape (M, K // 2). Packed datatype: torch.int8
wgt : torch.Tensor
Quantized weight tensor. Packed shape (N, K // 2). Packed datatype: torch.int8
out : torch.Tensor or None
Output tensor for the linear layer. Shape (M, N). Datatype: torch.float16 or torch.bfloat16. If None, we will create a new tensor.
qout : torch.Tensor or None
Quantized output tensor for the next layer. Packed shape (M, N // 2). Packed datatype: torch.int8. If None, we will create a new tensor.
ascales : torch.Tensor
Activation scales tensor. Shape (K // G, M). Datatype: torch.float16 or torch.bfloat16 for INT4 and torch.float8_e4m3fn for NVFP4.
wscales : torch.Tensor
Weight scales tensor. Shape (K // G, N). Datatype: torch.float16 or torch.bfloat16 for INT4 and torch.float8_e4m3fn for NVFP4.
oscales : torch.Tensor or None
Output scales tensor. Shape (N // G, M). Datatype: torch.float16 or torch.bfloat16 for INT4 and torch.float8_e4m3fn for NVFP4.
poolout : torch.Tensor or None
Not used for now. Just leave it as None.
lora_act_in : torch.Tensor
Output tensor of the low-rank down-projection. Packed shape (M, R). Datatype: torch.float32.
lora_up : torch.Tensor
Low-rank up-projection weights. Packed shape (N, R). Packed datatype: torch.float16 or torch.bfloat16.
lora_down : torch.Tensor or None
Low-rank down-projection weights in the next layer. Packed shape (N, R). Packed datatype: torch.float16 or torch.bfloat16.
lora_act_out : torch.Tensor or None
Output tensor for low-rank down-projection in the next layer. Packed shape (M, R). Packed datatype: torch.float32.
norm_q : torch.Tensor or None
Query normalization tensor. Shape (HEAD_DIM,). Datatype: torch.float16 or torch.bfloat16.
norm_k : torch.Tensor or None
Key normalization tensor. Shape (HEAD_DIM,). Datatype: torch.float16 or torch.bfloat16.
rotary_emb : torch.Tensor or None
Rotary embedding tensor. Shape (M, HEAD_DIM // 2, 2, 2). Datatype: torch.float32. TODO: double check this.
bias : torch.Tensor or None
Bias tensor. Shape (N,). Datatype: torch.float16 or torch.bfloat16.
smooth_factor : torch.Tensor or None
Smoothing factor tensor for quantization in the next layer. Shape (N,). Datatype: torch.float16 or torch.bfloat16.
out_vk : torch.Tensor or None
Used only in SANA.
out_linearattn : torch.Tensor or None
Used only in SANA.
act_unsigned : bool, default=False
Whether activations are unsigned.
lora_scales : list of float or None, default=None
Scaling factors for the low-rank branch, one per 16-rank chunk. If None, every scale defaults to 1.0.
fuse_silu : bool, default=False
Whether to fuse SiLU activation.
fp4 : bool, default=False
Whether to use 4-bit floating point quantization (NVFP4).
alpha : float or None, default=1.0
Per-tensor scaling factor for NVFP4. If None, it falls back to 1.0.
wcscales : torch.Tensor or None, default=None
Per-channel scaling factors for NVFP4. Shape (N,). Datatype: torch.float8_e4m3fn.
out_q : torch.Tensor or None, default=None
Output tensor for quantized Q, used for Nunchaku attention. Packed shape (B, H, M, D). Datatype: torch.int8.
out_k : torch.Tensor or None, default=None
Output tensor for quantized K, used for Nunchaku attention. Packed shape (B, H, M, D). Datatype: torch.int8.
out_v : torch.Tensor or None, default=None
Output tensor for quantized V, used for Nunchaku attention. Packed shape (B, H, M, D). Datatype: torch.int8.
attn_tokens : int, default=0
Number of attention tokens.
Returns
-------
None
The results are written in-place to the provided output tensors.
"""
# Default: a unit scale for each 16-rank chunk of the low-rank branch.
if lora_scales is None:
rank = lora_up.shape[1]
lora_scales = [1.0] * math.ceil(rank / 16)
# The kernel expects a concrete per-tensor scale, so fall back to 1.0.
if alpha is None:
alpha = 1.0
ops.gemm_w4a4(
act,
wgt,
out,
qout,
ascales,
wscales,
oscales,
poolout,
lora_act_in,
lora_up,
lora_down,
lora_act_out,
norm_q,
norm_k,
rotary_emb,
bias,
smooth_factor,
out_vk,
out_linearattn,
act_unsigned,
lora_scales,
fuse_silu,
fp4,
alpha,
wcscales,
out_q,
out_k,
out_v,
attn_tokens,
)
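To make the argument layout above concrete, here is a minimal shape-and-dtype sketch of calling the wrapper on the INT4 path (group size 64). The import path is an assumption, and the randomly filled tensors are placeholders for real packed weights and scales from a Nunchaku-quantized checkpoint, so the numerical output is meaningless; the sketch only illustrates the packed shapes and dtypes described in the docstring.

```python
import torch
from nunchaku.ops.gemm import svdq_gemm_w4a4_cuda  # assumed import path

M, K, N, R = 256, 3072, 3072, 32   # batch, input channels, output channels, low-rank rank
G = 64                             # group size for INT4

# Placeholder packed tensors; real ones come from a quantized checkpoint.
act = torch.randint(-128, 127, (M, K // 2), dtype=torch.int8, device="cuda")
wgt = torch.randint(-128, 127, (N, K // 2), dtype=torch.int8, device="cuda")
ascales = torch.rand(K // G, M, dtype=torch.float16, device="cuda")
wscales = torch.rand(K // G, N, dtype=torch.float16, device="cuda")
lora_act_in = torch.zeros(M, R, dtype=torch.float32, device="cuda")
lora_up = torch.zeros(N, R, dtype=torch.float16, device="cuda")

out = torch.empty(M, N, dtype=torch.float16, device="cuda")
svdq_gemm_w4a4_cuda(
    act=act,
    wgt=wgt,
    out=out,
    ascales=ascales,
    wscales=wscales,
    lora_act_in=lora_act_in,
    lora_up=lora_up,
)
# `out` now holds the W4A4 GEMM result plus the low-rank correction, written in place.
```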
"""
Python wrappers for Nunchaku's quantized GEMV operations.
"""
import torch
from .._C import ops
def awq_gemv_w4a16_cuda(
in_feats: torch.Tensor,
kernel: torch.Tensor,
scaling_factors: torch.Tensor,
zeros: torch.Tensor,
m: int,
n: int,
k: int,
group_size: int = 64,
) -> torch.Tensor:
"""
Wrap Nunchaku's AWQ W4A16 GEMV CUDA kernel (4-bit weights, 16-bit activations).
``in_feats`` holds the 16-bit input activations, ``kernel`` the packed 4-bit AWQ weights,
and ``scaling_factors``/``zeros`` the per-group dequantization parameters. ``m``, ``n``,
and ``k`` are the GEMM dimensions (batch size, output channels, input channels), and
``group_size`` is the number of input channels that share one scale/zero pair.
Returns the output tensor of shape (m, n).
"""
return ops.gemv_awq(in_feats, kernel, scaling_factors, zeros, m, n, k, group_size)
"""
Python wrappers for Nunchaku's quantization operations.
"""
import torch
from .._C import ops
from ..utils import ceil_divide
def svdq_quantize_w4a4_act_fuse_lora_cuda(
input: torch.Tensor,
output: torch.Tensor | None = None,
oscales: torch.Tensor | None = None,
lora_down: torch.Tensor | None = None,
lora_act_out: torch.Tensor | None = None,
smooth: torch.Tensor | None = None,
fuse_glu: bool = False,
fp4: bool = False,
pad_size: int = 256,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
This function wraps the high-performance CUDA kernel that quantizes activations for SVDQuant W4A4 and fuses the low-rank down-projection.
Notation
--------
M : int
Batch size (number of input samples).
K : int
Number of input channels (feature dimension).
N : int
Number of output channels.
G : int
Number of groups. 64 for INT4 and 16 for NVFP4.
R : int
Rank of the low-rank branch.
Returns
-------
output : torch.Tensor
Quantized activation tensor. Packed shape (M_pad, K // 2), where M_pad is M rounded up to a multiple of ``pad_size``. Packed datatype: torch.uint8.
oscales : torch.Tensor
Activation scales tensor. Shape (K // G, M_pad). Datatype: torch.float8_e4m3fn for NVFP4, otherwise the input dtype.
lora_act_out : torch.Tensor
Output of the low-rank down-projection. Shape (M_pad, R). Datatype: torch.float32.
"""
batch_size, channels = input.shape
rank = lora_down.shape[1]
batch_size_pad = ceil_divide(batch_size, pad_size) * pad_size
if output is None:
output = torch.empty(batch_size_pad, channels // 2, dtype=torch.uint8, device=input.device)
if oscales is None:
if fp4:
assert channels % 16 == 0
oscales = torch.empty(channels // 16, batch_size_pad, dtype=torch.float8_e4m3fn, device=input.device)
else:
assert channels % 64 == 0
oscales = torch.empty(channels // 64, batch_size_pad, dtype=input.dtype, device=input.device)
if lora_act_out is None:
lora_act_out = torch.empty(batch_size_pad, rank, dtype=torch.float32, device=input.device)
ops.quantize_w4a4_act_fuse_lora(input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4)
return output, oscales, lora_act_out
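As a companion to the GEMM sketch above, a minimal, hedged example of the activation-quantization wrapper; the import path is assumed and the LoRA weights are zero placeholders. It only illustrates the padded output shapes and how they map onto the `act`, `ascales`, and `lora_act_in` arguments of `svdq_gemm_w4a4_cuda`.

```python
import torch
from nunchaku.ops.quantize import svdq_quantize_w4a4_act_fuse_lora_cuda  # assumed import path

M, K, R = 16, 3072, 32
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
lora_down = torch.zeros(K, R, dtype=torch.bfloat16, device="cuda")  # placeholder LoRA down weights

qact, ascales, lora_act = svdq_quantize_w4a4_act_fuse_lora_cuda(x, lora_down=lora_down)

# The batch dimension is padded up to a multiple of pad_size (256 by default).
assert qact.shape == (256, K // 2) and qact.dtype == torch.uint8
assert ascales.shape == (K // 64, 256)            # INT4 path: group size 64
assert lora_act.shape == (256, R) and lora_act.dtype == torch.float32
# These feed the `act`, `ascales`, and `lora_act_in` arguments of `svdq_gemm_w4a4_cuda`.
```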
import gc
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from diffusers.pipelines.qwenimage.pipeline_qwenimage import (
QwenImagePipeline,
QwenImagePipelineOutput,
calculate_shift,
retrieve_timesteps,
)
class NunchakuQwenImagePipeline(QwenImagePipeline):
"""
QwenImage pipeline variant for Nunchaku-quantized transformers.
The overridden ``__call__`` follows the diffusers QwenImagePipeline flow, but moves the
text encoder, transformer, and VAE onto the GPU one at a time and offloads each back to
the CPU when it is no longer needed, reducing peak GPU memory.
"""
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
true_cfg_scale: float = 4.0,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 1.0,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = torch.device("cuda")
self.text_encoder.to(device)
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
prompt=prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
)
if do_true_cfg:
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
prompt=negative_prompt,
prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=negative_prompt_embeds_mask,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# handle guidance
if self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
if self.attention_kwargs is None:
self._attention_kwargs = {}
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
self.text_encoder.to("cpu")
gc.collect()
torch.cuda.empty_cache()
self.transformer.to(device)
# 6. Denoising loop
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
with self.transformer.cache_context("cond"):
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
if do_true_cfg:
with self.transformer.cache_context("uncond"):
neg_noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=negative_txt_seq_lens,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
# True CFG: combine conditional and unconditional predictions, then rescale
# so the guided prediction keeps the norm of the conditional prediction.
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
noise_pred = comb_pred * (cond_norm / noise_norm)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
self.transformer.to("cpu")
gc.collect()
torch.cuda.empty_cache()
self.vae.to(device)
self._current_timestep = None
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = latents.to(self.vae.dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
latents = latents / latents_std + latents_mean
image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
self.vae.to("cpu")
gc.collect()
torch.cuda.empty_cache()
if not return_dict:
return (image,)
return QwenImagePipelineOutput(images=image)
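For orientation, a usage sketch of the pipeline. The transformer class name, module paths, and checkpoint identifiers below are assumptions for illustration, not taken from this commit; substitute the actual Nunchaku QwenImage classes and published quantized weights.

```python
import torch
from nunchaku import NunchakuQwenImageTransformer2DModel                     # assumed class name and location
from nunchaku.pipeline.pipeline_qwenimage import NunchakuQwenImagePipeline   # assumed module path

# Placeholder checkpoint identifiers.
transformer = NunchakuQwenImageTransformer2DModel.from_pretrained("path/to/svdq-int4-qwen-image")
pipe = NunchakuQwenImagePipeline.from_pretrained(
    "Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16
)

# No pipe.to("cuda") here: __call__ moves the text encoder, transformer,
# and VAE onto the GPU one at a time and offloads them afterwards.
image = pipe(
    "a cat holding a sign that says hello",
    num_inference_steps=50,
    true_cfg_scale=4.0,
).images[0]
image.save("qwen_image.png")
```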