Commit 37c494a7 authored by Zhekai Zhang

Initial release
"""
Utilities adapted from
* https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_bnb_4bit.py
* https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/bitsandbytes.py
"""
import torch
import bitsandbytes as bnb
from transformers.quantizers.quantizers_utils import get_module_from_name
import torch.nn as nn
from accelerate import init_empty_weights
def _replace_with_bnb_linear(
model,
method="nf4",
has_been_replaced=False,
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean indicating whether the conversion was successful.
"""
for name, module in model.named_children():
if isinstance(module, nn.Linear):
with init_empty_weights():
in_features = module.in_features
out_features = module.out_features
if method == "llm_int8":
model._modules[name] = bnb.nn.Linear8bitLt(
in_features,
out_features,
module.bias is not None,
has_fp16_weights=False,
threshold=6.0,
)
has_been_replaced = True
else:
model._modules[name] = bnb.nn.Linear4bit(
in_features,
out_features,
module.bias is not None,
compute_dtype=torch.bfloat16,
compress_statistics=False,
quant_type="nf4",
)
has_been_replaced = True
# Store the module class in case we need to transpose the weight later
model._modules[name].source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_bnb_linear(
module,
has_been_replaced=has_been_replaced,
)
return model, has_been_replaced
def check_quantized_param(
model,
param_name: str,
) -> bool:
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit):
# Add here check for loaded components' dtypes once serialization is implemented
return True
elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
# bias could be loaded by regular set_module_tensor_to_device() from accelerate,
# but it would wrongly use uninitialized weight there.
return True
else:
return False
def create_quantized_param(
model,
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict=None,
unexpected_keys=None,
pre_quantized=False
):
module, tensor_name = get_module_from_name(model, param_name)
if tensor_name not in module._parameters:
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
old_value = getattr(module, tensor_name)
if tensor_name == "bias":
if param_value is None:
new_value = old_value.to(target_device)
else:
new_value = param_value.to(target_device)
new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
module._parameters[tensor_name] = new_value
return
if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
raise ValueError("this function only loads `Linear4bit components`")
if (
old_value.device == torch.device("meta")
and target_device not in ["meta", torch.device("meta")]
and param_value is None
):
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.")
if pre_quantized:
if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (
param_name + ".quant_state.bitsandbytes__nf4" not in state_dict
):
raise ValueError(
f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components."
)
quantized_stats = {}
for k, v in state_dict.items():
# use `startswith` to account for edge cases where the `param_name`
# substring can be present in multiple places in the `state_dict`
if param_name + "." in k and k.startswith(param_name):
quantized_stats[k] = v
if unexpected_keys is not None and k in unexpected_keys:
unexpected_keys.remove(k)
new_value = bnb.nn.Params4bit.from_prequantized(
data=param_value,
quantized_stats=quantized_stats,
requires_grad=False,
device=target_device,
)
else:
new_value = param_value.to("cpu")
kwargs = old_value.__dict__
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)
print(f"{param_name}: new_value.quant_type={new_value.quant_type} quant_state={new_value.quant_state} storage={new_value.quant_storage} blocksize={new_value.blocksize}")
state = new_value.quant_state
print(f" -- state.code={state.code} dtype={state.dtype} blocksize={state.blocksize}")
module._parameters[tensor_name] = new_value
# generate.py
# from huggingface_hub import hf_hub_download
# from accelerate.utils import set_module_tensor_to_device, compute_module_sizes
# from accelerate import init_empty_weights
# from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers
# from convert_nf4_flux import _replace_with_bnb_linear, create_quantized_param, check_quantized_param
# from diffusers import FluxTransformer2DModel, FluxPipeline
# import safetensors.torch
# import gc
# import torch
# dtype = torch.bfloat16
# ckpt_path = hf_hub_download("black-forest-labs/flux.1-dev", filename="flux1-dev.safetensors")
# original_state_dict = safetensors.torch.load_file(ckpt_path)
# converted_state_dict = convert_flux_transformer_checkpoint_to_diffusers(original_state_dict)
# del original_state_dict
# gc.collect()
# with init_empty_weights():
# config = FluxTransformer2DModel.load_config("black-forest-labs/flux.1-dev", subfolder="transformer")
# model = FluxTransformer2DModel.from_config(config).to(dtype)
# _replace_with_bnb_linear(model, "nf4")
# for param_name, param in converted_state_dict.items():
# param = param.to(dtype)
# if not check_quantized_param(model, param_name):
# set_module_tensor_to_device(model, param_name, device=0, value=param)
# else:
# create_quantized_param(model, param, param_name, target_device=0)
# del converted_state_dict
# gc.collect()
# print(compute_module_sizes(model)[""] / 1024 / 1024)
# pipe = FluxPipeline.from_pretrained("black-forest-labs/flux.1-dev", transformer=model, torch_dtype=dtype)
# pipe.enable_model_cpu_offload()
# prompt = "A mystic cat with a sign that says hello world!"
# image = pipe(prompt, guidance_scale=3.5, num_inference_steps=50, generator=torch.manual_seed(0)).images[0]
# image.save("flux-nf4-dev.png")
# model.push_to_hub("sayakpaul/flux.1-dev-nf4")
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import FluxPipeline, FluxTransformer2DModel
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from diffusers.models.attention import FeedForward
import safetensors.torch
from dataclasses import dataclass
from typing import Optional
import qmodule
TensorDict = dict[str, torch.Tensor]
@dataclass
class DeepCompressorModel:
model: dict[str, torch.Tensor]
smooth: dict[str, torch.Tensor]
branch: dict[str, dict[str, torch.Tensor]]
lora: dict[str, torch.Tensor]
def merge_dict(old: dict, new: dict, prefix: str):
for key, value in new.items():
newkey = prefix + key
assert newkey not in old
old[newkey] = value
def group_scale(weight: torch.Tensor, num_bits: int, group_size: int) -> torch.Tensor:
oc, ic = weight.shape
assert ic % group_size == 0
maxvalues = weight.reshape(oc, ic // group_size, group_size).abs().max(dim=-1).values
qmax = 2 ** (num_bits - 1) - 1
scale = (maxvalues.float() / qmax)
scale[scale == 0] = 1
return scale.to(weight.dtype)
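# Worked example (added, not part of the original file): for a [2, 8] weight and
# group_size=4, each scale is the per-(row, group) absmax divided by qmax = 2**3 - 1 = 7,
# and all-zero groups fall back to a scale of 1 to avoid division by zero.
# w = torch.tensor([[1., -2., 3., -4., 0., 0., 0., 0.],
#                   [7., 7., 7., 7., -14., 14., 7., 7.]])
# group_scale(w, num_bits=4, group_size=4)
# -> tensor([[0.5714, 1.0],    # 4/7, and the all-zero group falls back to 1
#            [1.0,    2.0]])   # 7/7 and 14/7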
def ceil_div(x, y):
return (x + y - 1) // y
def quantize_4bit(weight: torch.Tensor, wscales: torch.Tensor) -> torch.Tensor:
oc, ic = weight.shape
group_size = ic // wscales.shape[-1]
# print(group_size)
# print(weight.shape)
# print(wscales.shape)
# print(f"wscales={wscales}")
qweight = weight.reshape(oc, ic // group_size, group_size).to(dtype=torch.float32) / wscales[..., None]
# print(f"qweight={qweight}")
qweight = qweight.reshape(oc, ic // 8, 8).round().clamp(-8, 7).to(dtype=torch.int32)
qweight = qweight.bitwise_and_(0xf)
shift = torch.arange(0, 32, 4, dtype=torch.int32)
qweight = qweight.bitwise_left_shift_(shift)
qweight = qweight.sum(dim=-1, dtype=torch.int32)
return qweight
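# Packing sanity check (added sketch; layout inferred from the code above): each
# int32 holds 8 signed 4-bit values, with element j of a group of 8 stored in bits
# [4*j, 4*j+4) as two's complement.
# w = torch.randn(16, 128, dtype=torch.bfloat16)
# s = group_scale(w, num_bits=4, group_size=64)     # [16, 2]
# q = quantize_4bit(w, s)                           # [16, 16], int32
# word = int(q[0, 0]) & 0xFFFFFFFF
# nibbles = [(word >> (4 * j)) & 0xF for j in range(8)]
# vals = [n - 16 if n >= 8 else n for n in nibbles]
# # `vals` should match w[0, :8].float().div(s[0, 0]).round().clamp(-8, 7)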
def dump_linear_awq(weight: torch.Tensor, bias: torch.Tensor) -> dict[str, torch.Tensor]:
tensors = qmodule.dump_linear_awq(weight, bias, w_bit=4, group_size=64, zero_point=False)
tensors["qweight"] = tensors["qweight"].view(dtype=torch.int32)
return tensors
def pack_wscales(wscales: torch.Tensor) -> torch.Tensor:
N, groups = wscales.shape
assert wscales.dtype.itemsize == 2
BLOCK_N = 128
WSCALES_PACK_SIZE = 4
WSCALES_NUM_PACKS = 1
WSCALES_VALID_LANES = 32
wscales = wscales.reshape(ceil_div(N, BLOCK_N), BLOCK_N, groups)
wscales = wscales.permute(0, 2, 1) # [..., BLOCK_N]
wscales = wscales.reshape(*wscales.shape[0:2], WSCALES_NUM_PACKS, WSCALES_VALID_LANES // 4, WSCALES_PACK_SIZE // 2, 4, 2)
wscales = wscales.permute(0, 1, 2, 3, 5, 4, 6)
wscales = wscales.contiguous()
wscales = wscales.view(groups, N)
return wscales
# print(pack_wscales(torch.arange(0, 256, dtype=torch.int16)[..., None]))
# exit(0)
def pack_qweight(qweight: torch.Tensor) -> torch.Tensor:
N, K = qweight.shape
K *= 8
assert qweight.dtype.itemsize == 4
BLOCK_N = 128
WARP_K = 64
WARP_N_TILES = BLOCK_N // 16
qweight = qweight.reshape(ceil_div(N, BLOCK_N), WARP_N_TILES, 16, ceil_div(K, WARP_K), WARP_K // 8)
qweight = qweight.permute(0, 3, 1, 2, 4) # [N / BLOCK_N, K / WARP_K, WARP_N_TILES, (INSN_N) => 16 , WARP_K / 8 => 8]
# print(qweight.shape)
# https://docs.nvidia.com/cuda/parallel-thread-execution/#mma-16864-b-1
assert qweight.shape[3:] == (16, 8)
qweight = qweight.reshape(*qweight.shape[0:3], 2, 8, 2, 4)
qweight = qweight.permute(0, 1, 2, 4, 6, 3, 5)
assert qweight.shape[3:] == (8, 4, 2, 2)
print(qweight.dtype)
print(qweight.shape)
qweight = qweight.contiguous()
qweight = qweight.view(dtype=torch.int8) # assume little-endian
print(qweight.shape)
qweight = qweight.view(N, K // 2)
return qweight
def pack_lora(weight: torch.Tensor, is_lora_down: bool) -> torch.Tensor:
N, R = weight.shape
assert N % 16 == 0
assert R % 16 == 0
assert weight.dtype.itemsize == 2
weight = weight.reshape(N // 16, 16, R // 16, 16)
weight = weight.permute(0, 2, 1, 3)
if is_lora_down:
weight = weight.transpose(-1, -2)
# https://docs.nvidia.com/cuda/parallel-thread-execution/#mma-16816-b-f16
assert weight.shape[2:] == (16, 16)
weight = weight.reshape(*weight.shape[0:2], 2, 8, 2, 4, 2)
weight = weight.permute(0, 1, 3, 5, 2, 4, 6)
weight = weight.contiguous()
weight = weight.view(N, R)
return weight
def dump_linear_w4a4(
weight: torch.Tensor,
bias: torch.Tensor | None = None,
smooth: torch.Tensor | None = None,
lora_down: torch.Tensor | None = None,
lora_up: torch.Tensor | None = None) -> dict[str, torch.Tensor]:
print(f"dump_linear_w4a4: weight.shape={weight.shape}")
tensors = {}
group_size = 64
oc, ic = weight.shape
N, K = oc, ic
# LORA_RANK = 32
wscales = group_scale(weight, num_bits=4, group_size=group_size)
qweight = quantize_4bit(weight, wscales) # [N, K / 8]
assert qweight.shape == (N, K // 8)
qweight = pack_qweight(qweight)
wscales = pack_wscales(wscales)
if bias is None:
bias = torch.zeros([oc], dtype=weight.dtype)
bias = pack_wscales(bias[..., None])
assert bias.shape == (1, oc)
bias = bias[0]
if smooth is None:
smooth = torch.ones([ic], dtype=weight.dtype)
if smooth.dtype != weight.dtype:
print(f"Convert smooth dtype from {smooth.dtype} to {weight.dtype}")
smooth = smooth.to(weight.dtype)
smooth = pack_wscales(smooth[..., None])
assert smooth.shape == (1, ic)
smooth = smooth[0]
# if lora_down is None:
# lora_down = torch.zeros([LORA_RANK, ic], dtype=weight.dtype)
# if lora_up is None:
# lora_up = torch.zeros([oc, LORA_RANK], dtype=weight.dtype)
if lora_down is not None:
lora_down = pack_lora(lora_down.transpose(0, 1), is_lora_down=True)
if lora_up is not None:
lora_up = pack_lora(lora_up, is_lora_down=False)
tensors["qweight"] = qweight
tensors["wscales"] = wscales
tensors["bias"] = bias
if lora_down is not None:
tensors["lora_down"] = lora_down
if lora_up is not None:
tensors["lora_up"] = lora_up
tensors["smooth"] = smooth
return tensors
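# Shapes produced above (added descriptive summary): "qweight" is [N, K / 2] int8
# (two 4-bit values per byte), "wscales" is [K / 64, N], "bias" is [N], "smooth" is
# [K]; the optional "lora_down" / "lora_up" are mma-tiled LoRA factors of shape
# [K, rank] and [N, rank].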
def dump_linear_adanorm_single(weight: torch.Tensor, bias: torch.Tensor) -> TensorDict:
oc, ic = weight.shape
assert oc % 3 == 0
# shift_msa, scale_msa, gate_msa
weight = weight.reshape(3, oc // 3, ic).transpose(0, 1).reshape(oc, ic).contiguous()
bias = bias.reshape(3, oc // 3).transpose(0, 1)
# [oc // 3, 3]
bias = bias + torch.tensor([0, 1, 0], dtype=bias.dtype)
bias = bias.reshape(oc).contiguous()
return dump_linear_awq(weight, bias)
def dump_linear_adanorm_zero(weight: torch.Tensor, bias: torch.Tensor) -> dict[str, torch.Tensor]:
oc, ic = weight.shape
assert oc % 6 == 0
# shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp
weight = weight.reshape(6, oc // 6, ic).transpose(0, 1).reshape(oc, ic).contiguous()
bias = bias.reshape(6, oc // 6).transpose(0, 1)
# [oc // 6, 6]
bias = bias + torch.tensor([0, 1, 0, 0, 1, 0], dtype=bias.dtype)
bias = bias.reshape(oc).contiguous()
return dump_linear_awq(weight, bias)
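# Note on the two adanorm dumps above (added): diffusers' AdaLayerNorm(Zero) applies
# `x * (1 + scale) + shift`, so adding 1 to the scale rows of the bias (the
# [0, 1, 0] / [0, 1, 0, 0, 1, 0] patterns) presumably lets the fused kernel compute
# `x * scale + shift` directly. The reshape/transpose interleaves the chunk rows so
# each output channel's (shift, scale, gate) entries end up adjacent.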
def dump_linear_layer_w4a4(layer: torch.nn.Linear):
return dump_linear_w4a4(layer.weight.detach(), layer.bias.detach())
def dump_linear_layer_adanorm_single(block: torch.nn.Linear) -> TensorDict:
return dump_linear_adanorm_single(block.weight.detach(), block.bias.detach())
def dump_linear_layer_adanorm_zero(block: torch.nn.Linear) -> dict[str, torch.Tensor]:
return dump_linear_adanorm_zero(block.weight.detach(), block.bias.detach())
def dump_qkv_proj(q: torch.nn.Linear, k: torch.nn.Linear, v: torch.nn.Linear) -> TensorDict:
qkv = [q, k, v]
qkv_weight = torch.cat([linear.weight.detach() for linear in qkv], dim=0)
qkv_bias = torch.cat([linear.bias.detach() for linear in qkv], dim=0)
print(qkv_weight.shape)
print(qkv_bias.shape)
return dump_linear_w4a4(qkv_weight, qkv_bias)
def dump_single_transformer(block: FluxSingleTransformerBlock) -> dict[str, torch.Tensor]:
tensors = {}
merge_dict(tensors, dump_linear_layer_adanorm_single(block.norm.linear), "norm.linear.")
merge_dict(tensors, dump_qkv_proj(block.attn.to_q, block.attn.to_k, block.attn.to_v), "qkv_proj.")
tensors["norm_q.weight"] = block.attn.norm_q.weight.detach()
tensors["norm_k.weight"] = block.attn.norm_k.weight.detach()
merge_dict(tensors, dump_linear_layer_w4a4(block.proj_mlp), "mlp_fc1.")
merge_dict(tensors, dump_linear_w4a4(
block.proj_out.weight.detach()[:, 0:block.attn.out_dim],
bias=None
), "out_proj.")
merge_dict(tensors, dump_linear_w4a4(
block.proj_out.weight.detach()[:, block.attn.out_dim:],
bias=block.proj_out.bias.detach()
# block.proj_out.weight.detach()
), "mlp_fc2.")
# print(dict(block.named_parameters()).keys())
return tensors
def dump_transformer(block: FluxTransformerBlock) -> TensorDict:
tensors = {}
merge_dict(tensors, dump_linear_layer_adanorm_zero(block.norm1.linear), "norm1.linear.")
merge_dict(tensors, dump_linear_layer_adanorm_zero(block.norm1_context.linear), "norm1_context.linear.")
merge_dict(tensors, dump_qkv_proj(block.attn.to_q, block.attn.to_k, block.attn.to_v), "qkv_proj.")
merge_dict(tensors, dump_qkv_proj(block.attn.add_q_proj, block.attn.add_k_proj, block.attn.add_v_proj), "qkv_proj_context.")
tensors["norm_q.weight"] = block.attn.norm_q.weight.detach()
tensors["norm_k.weight"] = block.attn.norm_k.weight.detach()
tensors["norm_added_q.weight"] = block.attn.norm_added_q.weight.detach()
tensors["norm_added_k.weight"] = block.attn.norm_added_k.weight.detach()
merge_dict(tensors, dump_linear_layer_w4a4(block.attn.to_out[0]), "out_proj.")
merge_dict(tensors, dump_linear_layer_w4a4(block.attn.to_add_out), "out_proj_context.")
# no params for norm2
merge_dict(tensors, dump_linear_layer_w4a4(block.ff.net[0].proj), "mlp_fc1.")
merge_dict(tensors, dump_linear_layer_w4a4(block.ff.net[2]), "mlp_fc2.")
merge_dict(tensors, dump_linear_layer_w4a4(block.ff_context.net[0].proj), "mlp_context_fc1.")
merge_dict(tensors, dump_linear_layer_w4a4(block.ff_context.net[2]), "mlp_context_fc2.")
return tensors
def recip_smooth(smooth: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
if smooth is None:
return None
return (1. / smooth.to(torch.float64)).to(smooth.dtype)
def unsmooth(weight: torch.Tensor, smooth: Optional[torch.Tensor]) -> torch.Tensor:
if smooth is None:
return weight
assert smooth.ndim == 1
assert weight.ndim == 2
assert weight.shape[1] == smooth.shape[0]
return (weight.to(torch.float64) / smooth[None, ...].to(torch.float64)).to(weight.dtype)
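# Identity behind unsmooth (added note): dividing the weight columns by `smooth`
# undoes a SmoothQuant-style fold, since sum_k x[k] * W[o, k] / s[k] equals
# sum_k (x[k] / s[k]) * W[o, k]; i.e. F.linear(x, unsmooth(W, s)) == F.linear(x / s, W).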
def dump_qkv_proj_svdq(
qmodel: DeepCompressorModel,
key_qkv: tuple[str, str, str],
key_smooth: str,
) -> TensorDict:
qkv_weight = torch.cat([qmodel.model[f"{k}.weight"] for k in key_qkv], dim=0)
qkv_bias = torch.cat([qmodel.model[f"{k}.bias"] for k in key_qkv], dim=0)
print(qkv_weight.shape)
print(qkv_bias.shape)
smooth = qmodel.smooth[key_smooth].to(qkv_weight.dtype).float()
return dump_linear_w4a4(
qkv_weight, qkv_bias,
smooth=smooth, # recip_smooth(smooth),
lora_down=unsmooth(qmodel.branch[key_smooth]["a.weight"], smooth),
lora_up=qmodel.branch[key_smooth]["b.weight"]
)
def dump_linear_layer_w4a4_svdq(
qmodel: DeepCompressorModel,
key: str,
alt_smooth: bool = False,
smooth: Optional[torch.Tensor] = None,
alt_bias: bool = False,
bias: Optional[torch.Tensor] = None,
bias_fuse_shift: float = 0
) -> TensorDict:
if not alt_smooth:
smooth = qmodel.smooth[key]
if not alt_bias:
bias = qmodel.model[f"{key}.bias"]
weight = qmodel.model[f"{key}.weight"]
if smooth is not None:
smooth = smooth.to(weight.dtype).float()
if bias_fuse_shift != 0:
oc, ic = weight.shape
shift = torch.ones([ic], dtype=weight.dtype, device=weight.device) * bias_fuse_shift
if smooth is not None:
shift = (shift / smooth).to(weight.dtype)
delta = F.linear(shift, weight)
if bias is None:
bias = torch.zeros([oc], dtype=weight.dtype, device=weight.device)
bias -= delta
return dump_linear_w4a4(
weight=weight,
bias=bias,
smooth=smooth, # recip_smooth(smooth),
lora_down=unsmooth(qmodel.branch[key]["a.weight"], smooth),
lora_up=qmodel.branch[key]["b.weight"]
)
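# Note on bias_fuse_shift above and GELU_SHIFT_VALUE below (added; interpretation
# is an assumption): the code subtracts W @ (shift / smooth) from the bias, which
# cancels a constant `shift` added to the activation (before the smoothing division)
# ahead of the quantized GEMM. 0.171875 = 11/64 is exactly representable in bf16 and
# is presumably chosen so that shifted GELU outputs become non-negative, since
# GELU's minimum is roughly -0.17.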
GELU_SHIFT_VALUE = 0.171875
def dump_transformer_svdq(
qmodel: DeepCompressorModel,
layer_id: int,
original_net: Optional[FluxTransformer2DModel],
use_original_adanorm: bool = False,
num_svdq_joint: int = 19,
shift_gelu: bool = True,
**kwargs,
) -> TensorDict:
tensors = {}
original_block: FluxTransformerBlock = original_net.transformer_blocks[layer_id]
if layer_id >= num_svdq_joint:
return dump_transformer(original_block)
prefix = f"transformer_blocks.{layer_id}"
def model(key: str):
return qmodel.model[f"{prefix}.{key}"]
def linear(key: str, **kwargs):
return dump_linear_layer_w4a4_svdq(qmodel, f"{prefix}.{key}", **kwargs)
if use_original_adanorm:
merge_dict(tensors, dump_linear_layer_adanorm_zero(original_block.norm1.linear), "norm1.linear.")
merge_dict(tensors, dump_linear_layer_adanorm_zero(original_block.norm1_context.linear), "norm1_context.linear.")
else:
merge_dict(tensors, dump_linear_adanorm_zero(model("norm1.linear.weight"), model("norm1.linear.bias")), "norm1.linear.")
merge_dict(tensors, dump_linear_adanorm_zero(model("norm1_context.linear.weight"), model("norm1_context.linear.bias")), "norm1_context.linear.")
merge_dict(tensors, dump_qkv_proj_svdq(
qmodel,
(f"{prefix}.attn.to_q", f"{prefix}.attn.to_k", f"{prefix}.attn.to_v"),
f"{prefix}.attn.to_q"
), "qkv_proj.")
merge_dict(tensors, dump_qkv_proj_svdq(
qmodel,
(f"{prefix}.attn.add_q_proj", f"{prefix}.attn.add_k_proj", f"{prefix}.attn.add_v_proj"),
f"{prefix}.attn.add_k_proj"
), "qkv_proj_context.")
tensors["norm_q.weight"] = model("attn.norm_q.weight")
tensors["norm_k.weight"] = model("attn.norm_k.weight")
tensors["norm_added_q.weight"] = model("attn.norm_added_q.weight")
tensors["norm_added_k.weight"] = model("attn.norm_added_k.weight")
# DONE GELU should be before lora up, also +shift
# DONE smooth factor 1/smooth
# DONE smooth fuse to lora down
merge_dict(tensors, linear("attn.to_out.0", alt_smooth=True, smooth=None), "out_proj.")
merge_dict(tensors, linear("attn.to_add_out", alt_smooth=True, smooth=None), "out_proj_context.")
merge_dict(tensors, linear("ff.net.0.proj"), "mlp_fc1.")
merge_dict(tensors, linear("ff.net.2.linear", alt_bias=True, bias=original_block.ff.net[2].bias, bias_fuse_shift=GELU_SHIFT_VALUE if shift_gelu else 0), "mlp_fc2.")
merge_dict(tensors, linear("ff_context.net.0.proj"), "mlp_context_fc1.")
merge_dict(tensors, linear("ff_context.net.2.linear", alt_bias=True, bias=original_block.ff_context.net[2].bias, bias_fuse_shift=GELU_SHIFT_VALUE if shift_gelu else 0), "mlp_context_fc2.")
return tensors
def dump_single_transformer_svdq(
qmodel: DeepCompressorModel,
layer_id: int,
original_net: Optional[FluxTransformer2DModel],
use_original_adanorm: bool = False,
num_svdq_single: int = 38,
shift_gelu: bool = True,
**kwargs
) -> TensorDict:
tensors = {}
original_block: FluxSingleTransformerBlock = original_net.single_transformer_blocks[layer_id]
if layer_id >= num_svdq_single:
return dump_single_transformer(original_block)
prefix = f"single_transformer_blocks.{layer_id}"
def model(key: str):
return qmodel.model[f"{prefix}.{key}"]
def linear(key: str, **kwargs):
return dump_linear_layer_w4a4_svdq(qmodel, f"{prefix}.{key}", **kwargs)
if use_original_adanorm:
merge_dict(tensors, dump_linear_layer_adanorm_single(original_block.norm.linear), "norm.linear.")
else:
merge_dict(tensors, dump_linear_adanorm_single(model("norm.linear.weight"), model("norm.linear.bias")), "norm.linear.")
merge_dict(tensors, dump_qkv_proj_svdq(
qmodel,
(f"{prefix}.attn.to_q", f"{prefix}.attn.to_k", f"{prefix}.attn.to_v"),
f"{prefix}.attn.to_q"
), "qkv_proj.")
tensors["norm_q.weight"] = model("attn.norm_q.weight")
tensors["norm_k.weight"] = model("attn.norm_k.weight")
merge_dict(tensors, linear("proj_mlp", alt_smooth=True, smooth=qmodel.smooth[f"{prefix}.attn.to_q"]), "mlp_fc1.")
merge_dict(tensors, linear("proj_out.linears.0", alt_smooth=True, smooth=None, alt_bias=True, bias=None), "out_proj.")
merge_dict(tensors, linear("proj_out.linears.1.linear", alt_bias=True, bias=original_block.proj_out.bias, bias_fuse_shift=GELU_SHIFT_VALUE if shift_gelu else 0), "mlp_fc2.")
return tensors
@torch.inference_mode()
def dump_flux(net: FluxTransformer2DModel) -> TensorDict:
tensors = {}
for i in range(len(net.transformer_blocks)):
merge_dict(tensors, dump_transformer(net.transformer_blocks[i]), f"transformer_blocks.{i}.")
for i in range(len(net.single_transformer_blocks)):
merge_dict(tensors, dump_single_transformer(net.single_transformer_blocks[i]), f"single_transformer_blocks.{i}.")
return tensors
@torch.inference_mode()
def dump_flux_svdq(qmodel: DeepCompressorModel, **kwargs) -> TensorDict:
tensors = {}
for i in range(19):
merge_dict(tensors, dump_transformer_svdq(qmodel, i, **kwargs), f"transformer_blocks.{i}.")
for i in range(38):
merge_dict(tensors, dump_single_transformer_svdq(qmodel, i, **kwargs), f"single_transformer_blocks.{i}.")
return tensors
def load_svdq(path: str) -> DeepCompressorModel:
return DeepCompressorModel(
model=torch.load(f"{path}/model.pt", map_location="cpu"),
smooth=torch.load(f"{path}/smooth.pt", map_location="cpu"),
branch=torch.load(f"{path}/branch.pt", map_location="cpu"),
lora={}
)
if __name__ == "__main__":
use_svdq = True
use_original_adanorm = True
shift_gelu = True
dev = False
num_svdq_joint = 19
num_svdq_single = 38
if not use_svdq:
pipe = FluxPipeline.from_pretrained(
f"black-forest-labs/FLUX.1-{'dev' if dev else 'schnell'}", torch_dtype=torch.bfloat16)
net: FluxTransformer2DModel = pipe.transformer
print(net)
tensors = dump_flux(net)
dtype = pipe.dtype
else:
pipe = FluxPipeline.from_pretrained(
f"black-forest-labs/FLUX.1-{'dev' if dev else 'schnell'}", torch_dtype=torch.bfloat16)
net: FluxTransformer2DModel = pipe.transformer
tensors = dump_flux_svdq(
load_svdq(path="model-dev" if dev else "model-schnell"),
original_net=net,
use_original_adanorm=use_original_adanorm,
num_svdq_joint=num_svdq_joint,
num_svdq_single=num_svdq_single,
shift_gelu=shift_gelu,
)
dtype = torch.bfloat16
for k, v in tensors.items():
assert not v.isnan().any()
assert not v.isinf().any()
safetensors.torch.save_file(
tensors,
f"/tmp/flux{f'-dev' if dev else ''}{f'-svdq-{num_svdq_joint}-{num_svdq_single}' if use_svdq else ''}-divsmooth{'-shift' if shift_gelu else ''}{'-ada' if use_original_adanorm else ''}-{'bf16' if dtype == torch.bfloat16 else 'fp16'}.safetensors")
# tensors = dump_single_transformer(net.single_transformer_blocks[0])
# print(tensors)
# print(dump_transformer(net.transformer_blocks[0]))
# print(dict(net.named_parameters()))
import torch
import safetensors
import torch.nn.functional as F
from dump_flux import DeepCompressorModel, TensorDict, pack_wscales, pack_lora, merge_dict, unsmooth
from typing import Optional
Lora = tuple[torch.Tensor, torch.Tensor]
def load_svdq_lora(path: str, lora_path: str) -> DeepCompressorModel:
result = DeepCompressorModel(
model=torch.load(f"{path}/model.pt", map_location="cpu"),
smooth=torch.load(f"{path}/smooth.pt", map_location="cpu"),
branch=torch.load(f"{path}/branch.pt", map_location="cpu"),
lora={}
)
with safetensors.safe_open(lora_path, framework="pt", device="cpu") as f:
for k in f.keys():
prefix = "transformer."
if k.startswith(prefix):
result.lora[k.removeprefix(prefix)] = f.get_tensor(k)
dtype = next(iter(result.branch.values()))["a.weight"].dtype
for k, v in result.lora.items():
if v.dtype != dtype:
print(f"Convert lora weight {k} from {v.dtype} to {dtype}")
result.lora[k] = v.to(dtype)
# for k, v in result.lora.items():
# v.fill_(0)
return result
# q/k/v [3072, ...] -> qkv [3072 * 3, ...]
def extend_qkv(input: torch.Tensor, id: int) -> torch.Tensor:
oc, ic = input.shape
tmp = torch.zeros([oc * 3, ic], dtype=input.dtype, device=input.device)
tmp[id*oc:(id+1)*oc, ...] = input
return tmp
def merge_lora(inputs: list[Lora]) -> Optional[Lora]:
if len(inputs) == 0:
return None
lora_downs = [x[0] for x in inputs]
lora_ups = [x[1] for x in inputs]
lora_down = torch.cat(lora_downs, dim=0)
lora_up = torch.cat(lora_ups, dim=1)
return (lora_down, lora_up)
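# Merging identity (added note): with down_i of shape [r_i, ic] and up_i of shape
# [oc, r_i], concatenating the downs along dim 0 and the ups along dim 1 yields a
# single LoRA of rank sum(r_i) whose product up @ down equals sum_i up_i @ down_i.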
def merge_lora_qkv(inputs: list[Lora]) -> list[Lora]:
if len(inputs) == 0:
return []
for x in inputs:
if not x[0].equal(inputs[0][0]):
return inputs
lora_down = inputs[0][0]
lora_ups = [x[1] for x in inputs]
lora_up = torch.sum(torch.stack(lora_ups), dim=0).to(lora_down.dtype)
return [(lora_down, lora_up)]
def dump_lora(lora_down: Optional[torch.Tensor], lora_up: Optional[torch.Tensor]) -> TensorDict:
if lora_down is None:
return {}
rank, ic = lora_down.shape
oc = lora_up.shape[0]
assert lora_up.shape == (oc, rank)
if rank % 16 != 0:
rank_pad = (rank + 16 - 1) // 16 * 16
tmp_down = torch.zeros([rank_pad, ic], dtype=lora_down.dtype, device=lora_down.device)
tmp_up = torch.zeros([oc, rank_pad], dtype=lora_down.dtype, device=lora_down.device)
tmp_down[:rank, ...] = lora_down
tmp_up[..., :rank] = lora_up
lora_down = tmp_down
lora_up = tmp_up
print(f"Pad lora rank from {rank} to {rank_pad}")
lora_down = pack_lora(lora_down.transpose(0, 1), is_lora_down=True)
lora_up = pack_lora(lora_up, is_lora_down=False)
tensors = {}
tensors["lora_down"] = lora_down
tensors["lora_up"] = lora_up
return tensors
def get_original_lora(qmodel: DeepCompressorModel, key_branch: str, key_smooth: Optional[str]) -> Lora:
dtype = qmodel.branch[key_branch]["a.weight"].dtype
smooth = qmodel.smooth[key_smooth].to(dtype).float() if key_smooth else None
return (
unsmooth(qmodel.branch[key_branch]["a.weight"], smooth),
qmodel.branch[key_branch]["b.weight"]
)
def dump_linear_lora(
qmodel: DeepCompressorModel,
key_lora: str,
key_branch: str,
key_smooth: str,
shift_bias: bool = False,
key_bias: Optional[str] = None,
range_ic: slice = slice(None, None, None)) -> TensorDict:
lora_original = get_original_lora(qmodel, key_branch, key_smooth)
if f"{key_lora}.lora_A.weight" in qmodel.lora:
# lora_down = qmodel.lora[f"{key}.lora_A.weight"][..., range_ic]
# lora_up = qmodel.lora[f"{key}.lora_B.weight"]
lora_new = (
qmodel.lora[f"{key_lora}.lora_A.weight"][..., range_ic],
qmodel.lora[f"{key_lora}.lora_B.weight"]
)
lora_down, lora_up = merge_lora([lora_original, lora_new])
rank, ic = lora_down.shape
oc = lora_up.shape[0]
assert lora_up.shape == (oc, rank)
print(f"linear at {key_lora} has rank {rank}")
tensors = dump_lora(lora_down, lora_up)
if shift_bias and False: # no longer need shift bias
if key_bias is None:
key_bias = f"{key_branch}.bias"
if key_bias in qmodel.model:
bias = qmodel.model[key_bias]
print(f"linear at {key_lora} apply shift_bias from original bias at {key_bias}")
else:
bias = torch.zeros([oc], dtype=lora_up.dtype, device=lora_up.device)
print(f"linear at {key_lora} apply shift_bias from empty original bias")
shift = torch.empty([ic], dtype=lora_down.dtype, device=lora_down.device)
shift = shift.fill_(0.171875)
delta = F.linear(F.linear(shift, lora_new[0]), lora_new[1])
print(f"shift_bias delta = {delta}")
bias -= delta
tensors["bias"] = pack_wscales(bias[..., None])[0]
return tensors
else:
print(f"linear at {key_lora} use original lora")
return dump_lora(*lora_original)
def dump_qkv_proj_svdq_lora(
qmodel: DeepCompressorModel,
key_qkv: tuple[str, str, str],
key_smooth: str,
key_smooth_out: str
) -> TensorDict:
dtype = qmodel.branch[key_smooth]["a.weight"].dtype
smooth_out = qmodel.smooth[key_smooth_out].to(dtype).float()
lora_original = get_original_lora(qmodel, key_smooth, key_smooth)
loras = []
for i in range(3):
key = key_qkv[i]
if f"{key}.lora_A.weight" in qmodel.lora:
lora_down = qmodel.lora[f"{key}.lora_A.weight"]
lora_up = qmodel.lora[f"{key}.lora_B.weight"]
if i == 2:
lora_up = (lora_up / smooth_out[..., None]).to(lora_up.dtype)
loras.append((lora_down, extend_qkv(lora_up, i)))
# print(loras)
lora_down, lora_up = merge_lora([lora_original, *merge_lora_qkv(loras)])
print(f"qkv_proj at {key_smooth} has rank {lora_down.shape[0]}")
return dump_lora(lora_down, lora_up)
def dump_transformer_svdq_lora(qmodel: DeepCompressorModel, layer_id: int) -> TensorDict:
tensors = {}
def reorder_adanorm_linear(weight: torch.Tensor) -> torch.Tensor:
oc, ic = weight.shape
assert oc % 6 == 0
return weight.reshape(6, oc // 6, ic).transpose(0, 1).reshape(oc, ic).contiguous()
def linear(key: str, **kwargs):
key_lora = key
key_branch = kwargs.pop("key_branch", key_lora)
key_smooth = kwargs.pop("key_smooth", key_branch)
return dump_linear_lora(qmodel, key_lora, key_branch, key_smooth, **kwargs)
prefix = f"transformer_blocks.{layer_id}"
if f"{prefix}.norm1.linear.lora_A.weight" in qmodel.lora:
lora_down = qmodel.lora[f"{prefix}.norm1.linear.lora_A.weight"]
lora_up = qmodel.lora[f"{prefix}.norm1.linear.lora_B.weight"]
tensors[f"norm1.linear.lora_down"] = lora_down
tensors[f"norm1.linear.lora_up"] = reorder_adanorm_linear(lora_up)
if f"{prefix}.norm1_context.linear.lora_A.weight" in qmodel.lora:
lora_down = qmodel.lora[f"{prefix}.norm1_context.linear.lora_A.weight"]
lora_up = qmodel.lora[f"{prefix}.norm1_context.linear.lora_B.weight"]
tensors[f"norm1_context.linear.lora_down"] = lora_down
tensors[f"norm1_context.linear.lora_up"] = reorder_adanorm_linear(lora_up)
merge_dict(tensors, dump_qkv_proj_svdq_lora(
qmodel,
(f"{prefix}.attn.to_q", f"{prefix}.attn.to_k", f"{prefix}.attn.to_v"),
f"{prefix}.attn.to_q",
f"{prefix}.attn.to_out.0"
), "qkv_proj.")
merge_dict(tensors, dump_qkv_proj_svdq_lora(
qmodel,
(f"{prefix}.attn.add_q_proj", f"{prefix}.attn.add_k_proj", f"{prefix}.attn.add_v_proj"),
f"{prefix}.attn.add_k_proj",
f"{prefix}.attn.to_out.0"
), "qkv_proj_context.")
merge_dict(tensors, linear(f"{prefix}.attn.to_out.0", key_smooth=None), "out_proj.")
merge_dict(tensors, linear(f"{prefix}.attn.to_add_out", key_smooth=None), "out_proj_context.")
merge_dict(tensors, linear(f"{prefix}.ff.net.0.proj"), "mlp_fc1.")
merge_dict(tensors, linear(f"{prefix}.ff.net.2", key_branch=f"{prefix}.ff.net.2.linear", shift_bias=True), "mlp_fc2.")
merge_dict(tensors, linear(f"{prefix}.ff_context.net.0.proj"), "mlp_context_fc1.")
merge_dict(tensors, linear(f"{prefix}.ff_context.net.2", key_branch=f"{prefix}.ff_context.net.2.linear", shift_bias=True), "mlp_context_fc2.")
return tensors
def dump_single_transformer_svdq_lora(qmodel: DeepCompressorModel, layer_id: int) -> TensorDict:
tensors = {}
def reorder_adanorm_linear(weight: torch.Tensor) -> torch.Tensor:
oc, ic = weight.shape
assert oc % 3 == 0
return weight.reshape(3, oc // 3, ic).transpose(0, 1).reshape(oc, ic).contiguous()
def linear(key: str, **kwargs):
key_lora = key
key_branch = kwargs.pop("key_branch", key_lora)
key_smooth = kwargs.pop("key_smooth", key_branch)
return dump_linear_lora(qmodel, key_lora, key_branch, key_smooth, **kwargs)
prefix = f"single_transformer_blocks.{layer_id}"
if f"{prefix}.norm.linear.lora_A.weight" in qmodel.lora:
lora_down = qmodel.lora[f"{prefix}.norm.linear.lora_A.weight"]
lora_up = qmodel.lora[f"{prefix}.norm.linear.lora_B.weight"]
tensors[f"norm.linear.lora_down"] = lora_down
tensors[f"norm.linear.lora_up"] = reorder_adanorm_linear(lora_up)
merge_dict(tensors, dump_qkv_proj_svdq_lora(
qmodel,
(f"{prefix}.attn.to_q", f"{prefix}.attn.to_k", f"{prefix}.attn.to_v"),
f"{prefix}.attn.to_q",
f"{prefix}.proj_out.linears.0"
), "qkv_proj.")
merge_dict(tensors, linear(f"{prefix}.proj_mlp", key_smooth=f"{prefix}.attn.to_q"), "mlp_fc1.")
# TODO
out_dim = 3072
merge_dict(tensors, linear(f"{prefix}.proj_out",
key_branch=f"{prefix}.proj_out.linears.0",
key_smooth=None,
range_ic=slice(0, out_dim)), "out_proj.")
merge_dict(tensors, linear(f"{prefix}.proj_out",
key_branch=f"{prefix}.proj_out.linears.1.linear",
shift_bias=True,
range_ic=slice(out_dim, None)), "mlp_fc2.")
return tensors
@torch.inference_mode()
def dump_flux_svdq_lora(qmodel: DeepCompressorModel, **kwargs) -> TensorDict:
tensors = {}
for i in range(19):
merge_dict(tensors, dump_transformer_svdq_lora(qmodel, i, **kwargs), f"transformer_blocks.{i}.")
for i in range(38):
merge_dict(tensors, dump_single_transformer_svdq_lora(qmodel, i, **kwargs), f"single_transformer_blocks.{i}.")
return tensors
if __name__ == "__main__":
lora_name = "realism"
if lora_name == "sketch":
qmodel = load_svdq_lora("model-dev", "../third_party/FLUX.1-dev-LoRA-Collections/sketch.safetensors")
elif lora_name == "realism":
qmodel = load_svdq_lora("model-dev", "../third_party/FLUX.1-dev-LoRA-Collections/realism.safetensors")
elif lora_name == "anime":
qmodel = load_svdq_lora("model-dev", "../third_party/sonny-anime-fixed/araminta_k_sonnyanime_fluxd_fixed.safetensors")
elif lora_name == "ghibsky":
qmodel = load_svdq_lora("model-dev", "../third_party/flux-ghibsky-illustration/lora.safetensors")
elif lora_name == "yarn":
qmodel = load_svdq_lora("model-dev", "../third_party/yarn_art_Flux_LoRA/pytorch_lora_weights.safetensors")
elif lora_name == "sketch2image":
qmodel = load_svdq_lora("model-dev", "sketch2image.safetensors")
else:
raise NotImplementedError
tensors = dump_flux_svdq_lora(qmodel)
for k, v in tensors.items():
assert not v.isnan().any()
assert not v.isinf().any()
safetensors.torch.save_file(
tensors,
f"/tmp/flux-lora-{lora_name}-bf16.safetensors")
#!/bin/bash
rundir=$(date +"run-$(hostname -s)-%Y%m%d-%H%M%S")
mkdir -p $rundir
function run() {
echo config=$config
echo args=$@
python3 run_flux.py --steps 4 "$@" > >(tee $rundir/stdout-s4-$config.log) 2> >(tee $rundir/stderr-s4-$config.log)
python3 run_flux.py --steps 25 "$@" > >(tee $rundir/stdout-s25-$config.log) 2> >(tee $rundir/stderr-s25-$config.log)
python3 run_flux.py --steps 50 "$@" > >(tee $rundir/stdout-s50-$config.log) 2> >(tee $rundir/stderr-s50-$config.log)
if [ $? -eq 0 ]; then
nsys profile --cuda-memory-usage true -o $rundir/report-$config.nsys-rep python3 run_flux.py --steps 4 "$@"
fi
}
config=bf16-compile
run --config bf16 --compile
config=bf16-t5-compile
run --config bf16-t5 --compile
config=int8dq-compile
run --config bf16 --torchao --compile
config=int8dq-t5-compile
run --config bf16-t5 --torchao --compile
config=int8dq-nocompile
run --config bf16 --torchao
config=int8dq-t5-nocompile
run --config bf16-t5 --torchao
for cfg in svdq svdq-t5 w4a4 w4a4-t5 bf16 bf16-t5 nf4 nf4-t5; do
config=$cfg
run --config $cfg
config=$cfg-ol1
run --config $cfg --offload 1
config=$cfg-ol2
run --config $cfg --offload 2
done
import torch
from torch.nn import functional as F
from dump_flux import group_scale
def compare(
ref: torch.Tensor,
v: torch.Tensor,
refname: str,
vname: str,
list_diff: bool = False):
print(f"== COMPARE v={vname} vs ref={refname}")
diff = v - ref
print(f" - diff = {diff}")
if list_diff:
print(f" - diffs at {diff.nonzero()}")
mse = diff.square().mean()
print(f" - mse = {mse}")
nmse = mse / ref.square().mean()
print(f" - nmse = {nmse}")
print(f" - mean(v/ref)={v.mean()}/{ref.mean()}")
print(f" - var(v/ref)={v.var()}/{ref.var()}")
print(f"== ")
print()
def print_debug_results(debug_results: dict[str, torch.Tensor], is_ref: bool = False):
ref = 'REF' if is_ref else ''
for k, v in debug_results.items():
has_nan = v.isnan().any()
has_inf = v.isinf().any()
if v.dtype.is_floating_point:
print(f" {ref} {k}: {v.shape} ({v.dtype}) has_nan={has_nan} has_inf={has_inf} max={v.max()} min={v.min()} mean={v.mean()} var={v.var()}")
else:
print(f" {ref} {k}: {v.shape} ({v.dtype})")
if has_nan:
cnt = v.isnan().count_nonzero()
print(f" {ref} -- {cnt} ({cnt / v.numel() * 100}%) nans at {v.isnan().nonzero()[0:10]}")
if has_inf:
cnt = v.isinf().count_nonzero()
print(f" {ref} -- {cnt} ({cnt / v.numel() * 100}%) infs at {v.isinf().nonzero()[0:10]}")
print(f" {ref} -- {v}")
print()
def fakequant(
act: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
group_size: int = 64,
force_cuda: bool = False,
):
oc, ic = weight.shape
batch_size = act.shape[0]
assert act.shape[1] == ic
# [oc, ic // group_size]
wscales = group_scale(weight, num_bits=4, group_size=group_size)
qweight = weight.reshape(oc, ic // group_size, group_size).to(dtype=torch.float32) / wscales[..., None]
qweight = qweight.round().clamp(-8, 7)
qweight_i = qweight.int()
qweight = qweight * wscales[..., None]
qweight = qweight.to(weight.dtype)
qweight = qweight.reshape(oc, ic)
# print(f"qweight = {qweight}")
print_debug_results({"qweight": qweight})
# [batch_size, ic // group_size]
ascales = group_scale(act, num_bits=4, group_size=group_size).to(dtype=weight.dtype)
qact = act.reshape(batch_size, ic // group_size, group_size).to(dtype=torch.float32) / ascales[..., None]
qact = qact.round().clamp(-8, 7)
qact_i = qact.int()
print_debug_results({"qact_i": qact_i})
qact = qact * ascales[..., None]
qact = qact.to(act.dtype)
qact = qact.reshape(batch_size, ic)
# print(f"qact = {qact}")
print_debug_results({"qact": qact})
outref_q = F.linear(qact.to(qweight.dtype), qweight, bias)
# print(f"outref_q = {outref_q}")
print_debug_results({"outref_q": outref_q})
###
if force_cuda:
qweight_i = qweight_i.to("cuda")
qact_i = qact_i.to("cuda")
wscales = wscales.to("cuda")
ascales = ascales.to("cuda")
bias = bias.to("cuda")
qweight = qweight_i
qact = qact_i
qweight = qweight.reshape(oc, ic // group_size, group_size).transpose(0, 1).transpose(1, 2)
qact = qact.reshape(batch_size, ic // group_size, group_size).transpose(0, 1)
# [ic // group_size, batch_size, oc]
psum = torch.bmm(qact.float(), qweight.float())
print(f"psum_i ({psum.shape}) = {psum} ")
# print(psum[:, 0, 23])
# print(f"ascales = {ascales}")
print_debug_results({"ascales": ascales})
print(f"ascales[0:16] = {ascales[0:16, 0]}")
ws1 = wscales.transpose(0, 1).reshape(ic // group_size, 1, oc).repeat(1, batch_size, 1)
as1 = ascales.transpose(0, 1).reshape(ic // group_size, batch_size, 1).repeat(1, 1, oc)
scales = ws1 * as1
print(f"scales = {scales}")
# print(scales[:, 0, 23])
psum = psum.to(dtype=act.dtype).float()
psum = psum * scales
print(f"psum ({psum.shape}) = {psum} ")
# print(psum[:, 0, 23])
# outref_q2 = psum.sum(dim=0) # .to(layer.weight.dtype)
outref_q2 = torch.zeros_like(psum[0])
for i in range(psum.shape[0]):
outref_q2 = (outref_q2 + psum[i]).to(act.dtype)
outref_q2 += bias[None, ...]
# print(f"outref_q2 = {outref_q2}")
print_debug_results({"outref_q2": outref_q2})
# print(outref_q2[0, 23])
if force_cuda:
outref_q2 = outref_q2.to(act.device)
return outref_q, outref_q2
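# Usage sketch (added; shapes are illustrative, not from the original file):
# act = torch.randn(8, 128, dtype=torch.bfloat16)   # ic must be a multiple of group_size (64)
# w = torch.randn(64, 128, dtype=torch.bfloat16)
# b = torch.randn(64, dtype=torch.bfloat16)
# out_ref, out_grouped = fakequant(act, w, b)
# compare(out_ref.float(), out_grouped.float(), refname="outref_q", vname="outref_q2")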
from safetensors.torch import safe_open, save_file
def main():
input_path1 = "app/i2i/pretrained/converted/sketch.safetensors"
input_path2 = "app/i2i/pretrained/original/flux-lora-sketch2image-bf16.safetensors"
sd1 = {}
with safe_open(input_path1, framework="pt") as f:
for k in f.keys():
sd1[k] = f.get_tensor(k)
sd2 = {}
with safe_open(input_path2, framework="pt") as f:
for k in f.keys():
sd2[k] = f.get_tensor(k)
for k in sd1.keys():
if "lora" not in k:
print(k)
sd2[k.replace("transformer.", "")] = sd1[k]
save_file(sd2, "svdq-flux.1-pix2pix-turbo-sketch2image.safetensors")
if __name__ == "__main__":
main()
import math
import torch
import torch.nn as nn
def make_divisible(c, divisor):
return (c + divisor - 1) // divisor
def calculate_zeros_width(in_features, group_size=128, pack_num=8):
if group_size >= 128:
size_multiplier = 1
elif group_size == 64:
size_multiplier = 2
elif group_size == 32:
size_multiplier = 4
else:
raise NotImplementedError
base_width = make_divisible(in_features // group_size, pack_num)
base_width = make_divisible(base_width, size_multiplier) * size_multiplier
return base_width
def pack_intweight(unpacked_qweight, interleave, kstride):
# unpacked_qweight: [N, K]
N = unpacked_qweight.shape[0]
K = unpacked_qweight.shape[1]
Packed_Kernel = unpacked_qweight.cpu().numpy().reshape(N, K // 32, 32)
# np.arange(32).reshape(4, 4, 2).transpose(1, 0, 2) => [0, 1, 8, 9, 16, 17, 24, 25, ...]
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 3, 2, 4)
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 32)
# reorder each 8 weights for fast dequantization
# [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7]
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 8)
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 2, 4, 3)
Packed_Kernel = Packed_Kernel.reshape(N, K)
# interleaving every four rows
Packed_Kernel = Packed_Kernel.reshape(
N // interleave, interleave, K // kstride, kstride
)
# N // 4, K // 64, 4, 64
Packed_Kernel = Packed_Kernel.transpose(0, 2, 1, 3)
Packed_Kernel = Packed_Kernel.reshape(
N // interleave, K // kstride, kstride, interleave
)
# Packing -> (N // 4, K // 64, 64)
Packed_Kernel = (
Packed_Kernel[..., 0]
| (Packed_Kernel[..., 1] << 4)
| (Packed_Kernel[..., 2] << 8)
| (Packed_Kernel[..., 3] << 12)
)
# reshape to (N // 4, K), FP16 format
Packed_Kernel = Packed_Kernel.reshape(N // interleave, K)
qweight = (
torch.tensor(Packed_Kernel.astype("int16"))
.to(unpacked_qweight.device)
.contiguous()
)
return qweight
def pseudo_quantize_tensor(
w, n_bit=8, zero_point=True, q_group_size=-1,
) -> tuple[torch.Tensor, torch.Tensor]:
org_w_shape = w.shape
if q_group_size > 0:
assert org_w_shape[-1] % q_group_size == 0
w = w.reshape(-1, q_group_size)
assert w.dim() == 2
if zero_point:
max_val = w.amax(dim=1, keepdim=True)
min_val = w.amin(dim=1, keepdim=True)
max_int = 2**n_bit - 1
min_int = 0
scales = (max_val - min_val).clamp(min=1e-5) / max_int
zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
else: # we actually never used this
# assert min_val is None
max_val = w.abs().amax(dim=1, keepdim=True)
max_val = max_val.clamp(min=1e-5)
max_int = 2 ** (n_bit - 1) - 1
min_int = -max_int
scales = max_val / max_int
zeros = torch.full_like(scales, fill_value=-min_int)
assert torch.isnan(scales).sum() == 0
assert torch.isnan(w).sum() == 0
w = w.reshape(org_w_shape)
return scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
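# Worked example (added; zero_point=True path): for a single group [-1, 0, 3] with
# n_bit=4, max_int = 15, so scales = (3 - (-1)) / 15 = 4/15 and
# zeros = -round(min / scales) = -round(-3.75) = 4; the real quantization in
# dump_linear_awq below then maps w to round(w / scales) + zeros, i.e.
# -1 -> 0, 0 -> 4, 3 -> 15.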
def dump_linear_awq(
weight: torch.Tensor,
bias: torch.Tensor,
w_bit: int,
group_size: int,
zero_point: bool = True
) -> dict[str, torch.Tensor]:
scales, zeros = pseudo_quantize_tensor(weight, w_bit, zero_point, group_size)
print(scales.shape)
print(zeros.shape)
tensors = {}
dtype = weight.dtype
oc, ic = weight.shape
# need scales and zeros info for real quantization
assert scales is not None and zeros is not None
scale_zeros = zeros * scales
pack_num = 32 // w_bit
qscales = torch.zeros(
(
scales.shape[0],
calculate_zeros_width(ic, group_size) * pack_num,
),
dtype=dtype,
device=scales.device,
)
qscales[:, : scales.shape[1]] = scales
# awq_linear.scales = scales.clone().half()
tensors["wscales"] = qscales.transpose(1, 0).contiguous()
if bias is not None:
tensors["bias"] = bias.clone()
if False:
intweight = []
for idx in range(ic):
intweight.append(
torch.round(
(weight.data[:, idx] + scale_zeros[:, idx // group_size])
/ qscales[:, idx // group_size]
).clamp(0, 15 if zero_point else 14).to(torch.int)[:, None]
)
print(intweight[0].shape)
intweight = torch.cat(intweight, dim=1)
print(intweight.shape)
intweight_ref = intweight
# intweight = intweight.t().contiguous()
assert ic % group_size == 0
intweight = weight.reshape(oc, ic // group_size, group_size)
# print(f"{intweight.shape} {scale_zeros[..., None].shape} {qscales[..., None].shape}")
intweight = (intweight + scale_zeros[..., None]) / qscales[..., None]
intweight = intweight.round_()
intweight = intweight.clamp_(0, 15 if zero_point else 14)
intweight = intweight.to(dtype=torch.int32)
intweight = intweight.reshape(oc, ic)
if False:
print(intweight_ref - intweight)
assert not (intweight_ref - intweight != 0).any()
tensors["qweight"] = pack_intweight(
intweight.contiguous(), interleave=4, kstride=64
)
zeros = zeros.to(dtype=torch.int32)
scaled_zeros = torch.zeros_like(qscales)
# scaled_zeros[:, :scales.shape[1]] = -(qscales[:, :scales.shape[1]] * (zeros.to(torch.float32) - 8.0)).to(torch.float16)
scaled_zeros[:, : scales.shape[1]] = -(
qscales[:, : scales.shape[1]] * (zeros.to(torch.float32))
).to(dtype)
tensors["wzeros"] = scaled_zeros.transpose(1, 0).contiguous()
return tensors
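# Usage sketch (added; shapes and dtypes are assumptions): pack a bf16 linear layer
# into AWQ-style tensors.
# layer = torch.nn.Linear(3072, 3072, dtype=torch.bfloat16)
# tensors = dump_linear_awq(layer.weight.detach(), layer.bias.detach(),
#                           w_bit=4, group_size=64, zero_point=False)
# -> keys: "qweight" (int16, interleave-packed), "wscales", "wzeros", "bias"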
import time
import argparse
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
import nunchaku.pipelines.flux
def get_pipe(config: str, dev: bool) -> FluxPipeline:
version = "dev" if dev else "schnell"
dtype = torch.bfloat16
qencoder_path = "/NFS/raid0/user/zhangzk/models/flux-t5-tinychat-v2.pt"
if config.startswith("svdq"):
pipe = nunchaku.pipelines.flux.from_pretrained(
f"black-forest-labs/FLUX.1-{version}",
torch_dtype=dtype,
qmodel_path=f"/NFS/raid0/user/zhangzk/models/flux{'-dev' if dev else ''}-svdq-19-38-divsmooth-shift-ada-bf16.safetensors",
qencoder_path=qencoder_path if config == "svdq-t5" else None
)
elif config.startswith("w4a4"):
pipe = nunchaku.pipelines.flux.from_pretrained(
f"black-forest-labs/FLUX.1-{version}",
torch_dtype=dtype,
qmodel_path=f"/NFS/raid0/user/zhangzk/models/flux{'-dev' if dev else ''}-divsmooth-shift-ada-bf16.safetensors",
qencoder_path=qencoder_path if config == "w4a4-t5" else None
)
elif config.startswith("bf16"):
pipe = FluxPipeline.from_pretrained(
f"black-forest-labs/FLUX.1-{version}",
torch_dtype=dtype,
)
if config == "bf16-t5":
nunchaku.pipelines.flux.quantize_t5(pipe, qencoder_path)
elif config.startswith("nf4"):
from accelerate.utils import set_module_tensor_to_device, compute_module_sizes
from accelerate import init_empty_weights
from convert_nf4_flux import _replace_with_bnb_linear, create_quantized_param, check_quantized_param
converted_state_dict = torch.load(f"/NFS/raid0/user/zhangzk/models/flux1-{version}-nf4.pt")
with init_empty_weights():
model_config = FluxTransformer2DModel.load_config(f"black-forest-labs/flux.1-{version}", subfolder="transformer")
model = FluxTransformer2DModel.from_config(model_config).to(dtype)
_replace_with_bnb_linear(model, "nf4")
for param_name, param in converted_state_dict.items():
param = param.to(dtype)
print(f"{param_name}: {param.shape} check_quantized_param={check_quantized_param(model, param_name)}")
if not check_quantized_param(model, param_name):
set_module_tensor_to_device(model, param_name, device=0, value=param)
else:
create_quantized_param(model, param, param_name, target_device=0)
pipe = FluxPipeline.from_pretrained(f"black-forest-labs/flux.1-{version}", transformer=model, torch_dtype=dtype)
if config == "nf4-t5":
nunchaku.pipelines.flux.quantize_t5(pipe, qencoder_path)
else:
raise NotImplementedError
return pipe
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="svdq", choices=["svdq", "svdq-t5", "w4a4", "w4a4-t5", "bf16", "bf16-t5", "nf4", "nf4-t5"])
parser.add_argument("--offload", type=int, default=0)
parser.add_argument("--steps", type=int, default=50)
parser.add_argument("--dev", action="store_true")
parser.add_argument("--torchao", action="store_true")
parser.add_argument("--compile", action="store_true")
args = parser.parse_args()
print(f"Use config {args.config}")
if args.offload > 0:
print(f"Use offloading level {args.offload}")
pipe = get_pipe(args.config, args.dev)
print(pipe)
if args.torchao:
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
# pipe.transformer = autoquant(pipe.transformer, error_on_unseen=False)
quantize_(pipe.transformer, int8_dynamic_activation_int8_weight())
if args.offload == 2:
pipe.enable_sequential_cpu_offload()
elif args.offload == 1:
pipe.enable_model_cpu_offload()
elif args.offload == 0:
pipe.to("cuda:0")
else:
raise NotImplementedError
# assert isinstance(pipe, FluxPipeline)
if args.compile:
pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(
pipe.transformer, mode="max-autotune", fullgraph=True
)
prompt = "A cat holding a sign that says hello world"
print(f"Using prompt '{prompt}'")
print(f"Run {args.steps} steps")
latencies = []
for i in range(5):
start_time = time.time()
out = pipe(
prompt=prompt,
guidance_scale=0,
num_inference_steps=args.steps,
generator=torch.Generator(device="cpu").manual_seed(233),
).images[0]
end_time = time.time()
latencies.append(end_time - start_time)
torch.cuda.empty_cache()
latencies = sorted(latencies)
latencies = latencies[1:-1]
out.save("output.png")
print(f"Elapsed: {sum(latencies) / len(latencies)} seconds")
print(f"Torch max_memory_allocated={torch.cuda.max_memory_allocated()}")
import time
import torch
import diffusers
from diffusers import FluxPipeline
import nunchaku.pipelines.flux
if __name__ == "__main__":
QUANT = False
SEED = 1
DEV = True
LORA_NAME = "anime"
pipe = nunchaku.pipelines.flux.from_pretrained(
f"black-forest-labs/FLUX.1-{'dev' if DEV else 'schnell'}",
torch_dtype=torch.bfloat16,
qmodel_path=f"/NFS/raid0/user/zhangzk/models/flux{'-dev' if DEV else ''}-svdq-19-38-divsmooth-shift-ada-bf16.safetensors",
qencoder_path="/NFS/raid0/user/zhangzk/models/flux-t5-tinychat-v2.pt" if QUANT else None,
)
if LORA_NAME:
pipe.transformer.nunchaku_update_params(f"/tmp/flux-lora-{LORA_NAME}-bf16.safetensors")
pipe.transformer.nunchaku_set_lora_scale(0.4)
print("Moving model to CUDA")
pipe.to("cuda:0")
print("Done")
# prompt = "A cat holding a sign that says hello world"
# prompt = "A cyberpunk cat holding a huge neon sign that says \"SVDQuant is lite and fast\""
prompt = "girl, neck tuft, white hair ,sheep horns, blue eyes, nm22 style"
# prompt = "GHIBSKY style, the most beautiful place in the universe"
# prompt = "the joker, yarn art style"
print(f"Using prompt '{prompt}'")
latencies = []
diffusers.training_utils.set_seed(SEED)
start_time = time.time()
out = pipe(
prompt=prompt,
guidance_scale=3.5 if DEV else 0,
num_inference_steps=50 if DEV else 4,
generator=torch.Generator(device="cpu").manual_seed(SEED),
).images[0]
end_time = time.time()
latencies.append(end_time - start_time)
out.save(f"output{'-dev' if DEV else ''}-{SEED}-{'quant' if QUANT else 'noquant'}.png")
print(f"Elapsed: {sum(latencies) / len(latencies)} seconds")
import torch
from nunchaku.pipelines import flux as nunchaku_flux
pipeline = nunchaku_flux.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16,
qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-schnell.safetensors", # download from Huggingface
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0).images[0]
image.save("example.png")
__version__ = "0.0.0beta0"
#pragma once
#include "interop/torch.h"
#include "FluxModel.h"
#include "Serialization.h"
#include "debug.h"
#include "Linear.h"
class QuantizedFluxModel { // : public torch::CustomClassHolder {
public:
void init(bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedFluxModel");
net = std::make_unique<FluxModel>(bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
void reset() {
debugContext.reset();
net.reset();
Tensor::synchronizeDevice();
trimMemory();
Tensor::synchronizeDevice();
}
void load(std::string path, bool partial = false) {
checkModel();
spdlog::info("{} weights from {}", partial ? "Loading partial" : "Loading", path);
std::shared_ptr<SafeTensors> provider = std::make_shared<SafeTensors>(path);
net->loadParams(*provider, partial);
Tensor::synchronizeDevice();
spdlog::info("Done.");
}
torch::Tensor forward(
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor temb,
torch::Tensor rotary_emb_img,
torch::Tensor rotary_emb_context,
torch::Tensor rotary_emb_single)
{
checkModel();
spdlog::debug("QuantizedFluxModel forward");
hidden_states = hidden_states.contiguous();
encoder_hidden_states = encoder_hidden_states.contiguous();
temb = temb.contiguous();
rotary_emb_img = rotary_emb_img.contiguous();
rotary_emb_context = rotary_emb_context.contiguous();
rotary_emb_single = rotary_emb_single.contiguous();
Tensor result = net->forward(
from_torch(hidden_states),
from_torch(encoder_hidden_states),
from_torch(temb),
from_torch(rotary_emb_img),
from_torch(rotary_emb_context),
from_torch(rotary_emb_single)
);
torch::Tensor output = to_torch(result);
Tensor::synchronizeDevice();
return output;
}
std::tuple<torch::Tensor, torch::Tensor> forward_layer(
int64_t idx,
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor temb,
torch::Tensor rotary_emb_img,
torch::Tensor rotary_emb_context)
{
spdlog::debug("QuantizedFluxModel forward_layer {}", idx);
hidden_states = hidden_states.contiguous();
encoder_hidden_states = encoder_hidden_states.contiguous();
temb = temb.contiguous();
rotary_emb_img = rotary_emb_img.contiguous();
rotary_emb_context = rotary_emb_context.contiguous();
auto &&[result_img, result_txt] = net->transformer_blocks.at(idx)->forward(
from_torch(hidden_states),
from_torch(encoder_hidden_states),
from_torch(temb),
from_torch(rotary_emb_img),
from_torch(rotary_emb_context),
0.0f
);
hidden_states = to_torch(result_img);
encoder_hidden_states = to_torch(result_txt);
Tensor::synchronizeDevice();
return { hidden_states, encoder_hidden_states };
}
torch::Tensor forward_single_layer(
int64_t idx,
torch::Tensor hidden_states,
torch::Tensor temb,
torch::Tensor rotary_emb_single)
{
spdlog::debug("QuantizedFluxModel forward_single_layer {}", idx);
hidden_states = hidden_states.contiguous();
temb = temb.contiguous();
rotary_emb_single = rotary_emb_single.contiguous();
Tensor result = net->single_transformer_blocks.at(idx)->forward(
from_torch(hidden_states),
from_torch(temb),
from_torch(rotary_emb_single)
);
hidden_states = to_torch(result);
Tensor::synchronizeDevice();
return hidden_states;
}
void disableMemoryAutoRelease() {
int device;
checkCUDA(cudaGetDevice(&device));
cudaMemPool_t mempool;
checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
uint64_t threshold = UINT64_MAX;
checkCUDA(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &threshold));
}
void trimMemory() {
int device;
checkCUDA(cudaGetDevice(&device));
cudaMemPool_t mempool;
checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
size_t bytesToKeep = 0;
checkCUDA(cudaMemPoolTrimTo(mempool, bytesToKeep));
}
void startDebug() {
debugContext = std::make_unique<DebugContext>();
}
void stopDebug() {
debugContext.reset();
}
auto getDebugResults() {
// c10::Dict<std::string, torch::Tensor> result;
std::map<std::string, torch::Tensor> result;
if (debugContext) {
for (auto &&[key, value] : debugContext->tensors) {
// result.insert(key, to_torch(value));
result[key] = to_torch(value);
}
}
return result;
}
// must be called after loading lora
// skip specific ranks in W4A4 layers
void setLoraScale(int skipRanks, float scale) {
if (skipRanks % 16 != 0) {
throw std::invalid_argument("skipRanks must be multiples of 16");
}
spdlog::info("Set lora scale to {} (skip {} ranks)", scale, skipRanks);
net->traverse([&](Module *module) {
if (auto *m = dynamic_cast<GEMV_AWQ *>(module)) {
m->lora_scale = scale;
} else if (auto *m = dynamic_cast<GEMM_W4A4 *>(module)) {
for (int i = 0; i < skipRanks / 16; i++) {
m->lora_scales[i] = 1.0f;
}
for (int i = skipRanks / 16; i < m->lora_scales.size(); i++) {
m->lora_scales[i] = scale;
}
}
});
}
private:
void checkModel() {
if (!net) {
throw std::runtime_error("Model not initialized");
}
}
private:
std::unique_ptr<FluxModel> net;
std::unique_ptr<DebugContext> debugContext;
};
\ No newline at end of file
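# Illustrative sketch (not part of the extension): the per-rank-group scaling that
# setLoraScale above applies to each W4A4 layer, assuming lora ranks are handled in
# groups of 16 as in GEMM_W4A4::lora_scales. Function and argument names are ours.
def compute_lora_scales(num_rank_groups: int, skip_ranks: int, scale: float) -> list[float]:
    if skip_ranks % 16 != 0:
        raise ValueError("skip_ranks must be a multiple of 16")
    skip_groups = skip_ranks // 16
    # the skipped (e.g. SVD) rank groups keep scale 1.0, the LoRA rank groups get `scale`
    return [1.0 if i < skip_groups else scale for i in range(num_rank_groups)]

# compute_lora_scales(4, 32, 0.8) -> [1.0, 1.0, 0.8, 0.8]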
#pragma once
#include "interop/torch.h"
#include "Serialization.h"
#include "Linear.h"
#include "debug.h"
#include "kernels/gemm_w4a4.h"
#include "kernels/awq/gemv_awq.h"
class QuantizedGEMM { // : public torch::CustomClassHolder {
public:
void init(int64_t in_features, int64_t out_features, bool bias, bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedGEMM");
size_t val = 0;
checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
spdlog::debug("Stack={}", val);
net = std::make_unique<GEMM_W4A4>((int)in_features, (int)out_features, bias, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
void reset() {
debugContext.reset();
net.reset();
Tensor::synchronizeDevice();
}
void load(std::string path) {
checkModel();
spdlog::info("Loading weights from {}", path);
std::shared_ptr<SafeTensors> provider = std::make_shared<SafeTensors>(path);
net->loadParams(*provider);
Tensor::synchronizeDevice();
}
torch::Tensor forward(torch::Tensor x) {
checkModel();
std::cerr << "QuantizedGEMM forward" << std::endl;
x = x.contiguous();
Tensor result = std::get<Tensor>(net->forward(
from_torch(x)
));
torch::Tensor output = to_torch(result);
Tensor::synchronizeDevice();
return output;
}
std::string dumpTensorBF16(Tensor x) {
std::stringstream ss;
for (int i = 0; i < 256; i++) {
ss << spdlog::fmt_lib::format("{:.3f} ", (float)(x.data_ptr<__nv_bfloat16>()[i]));
}
ss << std::endl;
return ss.str();
}
std::string dumpTensorINT4(Tensor x) {
using spdlog::fmt_lib::format;
const int M = x.shape[0];
const int K = x.shape[1] * 2;
assert(x.dtype() == Tensor::INT8);
// activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t (uint4)
constexpr int BLOCK_M = 256;
constexpr int WARP_K = 64;
constexpr int NUM_WARPS = 8;
constexpr int WARP_M_TILES = 2;
constexpr int WARP_SIZE = 32;
std::stringstream ss;
for (int bm = 0; bm < M / BLOCK_M; bm++) {
for (int bn = 0; bn < K / WARP_K; bn++) {
for (int warpId = 0; warpId < NUM_WARPS; warpId++) {
ss << format("[bm={},bn={},warp={}] ", bm, bn, warpId);
const int offset = ((bm * (K / WARP_K) + bn) * NUM_WARPS + warpId) * WARP_M_TILES * WARP_SIZE * 4;
for (int i = 0; i < 16; i++) {
assert(offset + i < x.numel() / 4);
uint32_t val = x.data_ptr<uint32_t>()[offset + i];
ss << "{";
for (int j = 0; j < 8; j++) {
int i4val = (val >> (j * 4)) & 0xf;
if (i4val & 0x8) {
i4val = -((~i4val & 0x7) + 1);
}
ss << format("{} ", i4val);
}
ss << format("}} {:x} ", val);
}
ss << std::endl;
}
}
}
ss << std::endl;
return ss.str();
}
void quantize(torch::Tensor x) {
checkModel();
spdlog::debug("QuantizedGEMM quantize");
x = x.contiguous();
auto qout = net->quantize(
from_torch(x)
);
Tensor act = qout.act.copy(Device::cpu());
Tensor ascales = qout.ascales.copy(Device::cpu());
Tensor lora_act = qout.lora_act.copy(Device::cpu());
Tensor::synchronizeDevice();
spdlog::debug("act = {}", dumpTensorINT4(act));
spdlog::debug("ascales = {}", dumpTensorBF16(ascales));
}
void gemm(
c10::optional<torch::Tensor> act, // packed act [M, K / 2]
        c10::optional<torch::Tensor> wgt, // packed wgt [N, K / 2]
c10::optional<torch::Tensor> out, // linear [M, N]
c10::optional<torch::Tensor> qout, // packed act [M, N / 2]
c10::optional<torch::Tensor> ascales, // packed as [K / 64, M]
c10::optional<torch::Tensor> wscales, // packed ws [K / 64, N]
c10::optional<torch::Tensor> oscales, // packed as [N / 64, M]
c10::optional<torch::Tensor> poolout, // linear [M / PoolSize, N]
c10::optional<torch::Tensor> lora_act_in, // packed lora_act [M, R]
c10::optional<torch::Tensor> lora_up, // packed lora_wgt [N, R]
c10::optional<torch::Tensor> lora_down, // packed lora_wgt [N, R]
c10::optional<torch::Tensor> lora_act_out, // packed lora_act [M, R]
c10::optional<torch::Tensor> norm_q, // linear [HEAD_DIM]
c10::optional<torch::Tensor> norm_k, // linear [HEAD_DIM]
c10::optional<torch::Tensor> rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
c10::optional<torch::Tensor> bias, // packed ws [N]
c10::optional<torch::Tensor> smooth_factor, // packed ws [N], for quantization of the next layer
bool act_unsigned,
std::vector<float> lora_scales
) {
std::cerr << "running gemm_w4a4: " << std::endl;
auto getTensor = [](c10::optional<torch::Tensor> &t) {
Tensor ret = t.has_value() ? from_torch(t.value()) : Tensor{};
if (ret.valid()) {
std::cerr << " " << ret.shape.str() << std::endl;
} else {
std::cerr << " <invalid>" << std::endl;
}
return ret;
};
gemm_w4a4(
getTensor(act ),
getTensor(wgt ),
getTensor(out ),
getTensor(qout ),
getTensor(ascales ),
getTensor(wscales ),
getTensor(oscales ),
getTensor(poolout ),
getTensor(lora_act_in ),
getTensor(lora_up ),
getTensor(lora_down ),
getTensor(lora_act_out ),
getTensor(norm_q ),
getTensor(norm_k ),
getTensor(rotary_emb ),
getTensor(bias ),
getTensor(smooth_factor),
act_unsigned,
lora_scales
);
Tensor::synchronizeDevice();
}
torch::Tensor gemv_awq(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int64_t m,
int64_t n,
int64_t k,
int64_t group_size)
{
Tensor result = ::gemv_awq(
from_torch(_in_feats.contiguous()),
from_torch(_kernel.contiguous()),
from_torch(_scaling_factors.contiguous()),
from_torch(_zeros.contiguous()),
(int)m,
(int)n,
(int)k,
(int)group_size
);
torch::Tensor output = to_torch(result);
Tensor::synchronizeDevice();
return output;
}
void startDebug() {
debugContext = std::make_unique<DebugContext>();
}
void stopDebug() {
debugContext.reset();
}
auto getDebugResults() {
// c10::Dict<std::string, torch::Tensor> result;
std::map<std::string, torch::Tensor> result;
if (debugContext) {
for (auto &&[key, value] : debugContext->tensors) {
// result.insert(key, to_torch(value));
result[key] = to_torch(value);
}
}
return result;
}
private:
void checkModel() {
if (!net) {
throw std::runtime_error("Model not initialized");
}
}
private:
std::unique_ptr<GEMM_W4A4> net;
std::unique_ptr<DebugContext> debugContext;
};
\ No newline at end of file
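# Illustrative sketch (not part of the extension): how dumpTensorINT4 above decodes
# eight signed 4-bit values from one packed 32-bit word (low nibble first).
def unpack_int4_word(word: int) -> list[int]:
    values = []
    for j in range(8):
        nibble = (word >> (j * 4)) & 0xF
        if nibble & 0x8:  # sign bit set: convert from 4-bit two's complement
            nibble = -((~nibble & 0x7) + 1)
        values.append(nibble)
    return values

# unpack_int4_word(0x000000F1) -> [1, -1, 0, 0, 0, 0, 0, 0]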
#include "gemm.h"
#include "flux.h"
#include <pybind11/pybind11.h>
// TORCH_LIBRARY(diffuxer, m) {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel")
// .def(torch::init<>())
.def(py::init<>())
.def("init", &QuantizedFluxModel::init,
py::arg("bf16"),
py::arg("deviceId")
)
.def("reset", &QuantizedFluxModel::reset)
.def("load", &QuantizedFluxModel::load,
py::arg("path"),
py::arg("partial") = false
)
.def("forward", &QuantizedFluxModel::forward)
.def("forward_layer", &QuantizedFluxModel::forward_layer)
.def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
.def("disableMemoryAutoRelease", &QuantizedFluxModel::disableMemoryAutoRelease)
.def("trimMemory", &QuantizedFluxModel::trimMemory)
.def("startDebug", &QuantizedFluxModel::startDebug)
.def("stopDebug", &QuantizedFluxModel::stopDebug)
.def("getDebugResults", &QuantizedFluxModel::getDebugResults)
.def("setLoraScale", &QuantizedFluxModel::setLoraScale)
;
py::class_<QuantizedGEMM>(m, "QuantizedGEMM")
// .def(torch::init<>())
.def(py::init<>())
.def("init", &QuantizedGEMM::init)
.def("reset", &QuantizedGEMM::reset)
.def("load", &QuantizedGEMM::load)
.def("forward", &QuantizedGEMM::forward)
.def("quantize", &QuantizedGEMM::quantize)
.def("gemm", &QuantizedGEMM::gemm)
.def("gemv_awq", &QuantizedGEMM::gemv_awq)
.def("startDebug", &QuantizedGEMM::startDebug)
.def("stopDebug", &QuantizedGEMM::stopDebug)
.def("getDebugResults", &QuantizedGEMM::getDebugResults)
;
}
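# Illustrative usage sketch for the bindings above, assuming the extension has been
# built and is importable as `nunchaku._C` (the wrapper module below imports it as
# `.._C`); the checkpoint path is a placeholder.
from nunchaku._C import QuantizedFluxModel

m = QuantizedFluxModel()
m.disableMemoryAutoRelease()
m.init(True, 0)  # bf16=True, CUDA device 0
m.load("/path/to/quantized-flux.safetensors")
m.startDebug()
# ... run m.forward(...) or m.forward_layer(...) with tensors prepared as in the wrapper below ...
debug_tensors = m.getDebugResults()  # dict mapping debug names to torch.Tensors
m.stopDebug()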
import os
import types
import diffusers
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from huggingface_hub import hf_hub_download
from packaging.version import Version
from torch import nn
from .._C import QuantizedFluxModel
SVD_RANK = 32
class NunchakuFluxModel(nn.Module):
def __init__(self, m: QuantizedFluxModel):
super().__init__()
self.m = m
self.dtype = torch.bfloat16
def forward(
self,
/,
hidden_states: torch.Tensor,
temb: torch.Tensor,
encoder_hidden_states: torch.Tensor,
image_rotary_emb: torch.Tensor,
joint_attention_kwargs=None,
):
batch_size = hidden_states.shape[0]
txt_tokens = encoder_hidden_states.shape[1]
img_tokens = hidden_states.shape[1]
original_dtype = hidden_states.dtype
hidden_states = hidden_states.to(self.dtype)
encoder_hidden_states = encoder_hidden_states.to(self.dtype)
temb = temb.to(self.dtype)
assert image_rotary_emb.ndim == 6
assert image_rotary_emb.shape[0] == 1
assert image_rotary_emb.shape[1] == 1
assert image_rotary_emb.shape[2] == batch_size * (txt_tokens + img_tokens)
# [bs, tokens, head_dim / 2, 1, 2] (sincos)
image_rotary_emb = image_rotary_emb.reshape([batch_size, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...] # .to(self.dtype)
rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...] # .to(self.dtype)
rotary_emb_single = image_rotary_emb # .to(self.dtype)
hidden_states = self.m.forward(
hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt, rotary_emb_single
)
hidden_states = hidden_states.to(original_dtype)
encoder_hidden_states = hidden_states[:, :txt_tokens, ...]
hidden_states = hidden_states[:, txt_tokens:, ...]
return encoder_hidden_states, hidden_states
## copied from diffusers 0.30.3
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0, "The dimension must be even."
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
batch_size, seq_length = pos.shape
out = torch.einsum("...n,d->...nd", pos, omega)
USE_SINCOS = True
if USE_SINCOS:
cos_out = torch.cos(out)
sin_out = torch.sin(out)
stacked_out = torch.stack([sin_out, cos_out], dim=-1)
out = stacked_out.view(batch_size, -1, dim // 2, 1, 2)
else:
out = out.view(batch_size, -1, dim // 2, 1, 1)
# stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
# out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
return out.float()
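# For example (shapes only): with pos of shape [batch, seq], rope(pos, dim=8, theta=10000)
# returns a float tensor of shape [batch, seq, 4, 1, 2] holding (sin, cos) pairs, i.e. the
# "[bs, tokens, head_dim / 2, 1, 2] (sincos)" layout consumed by NunchakuFluxModel.forward.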
class EmbedND(torch.nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
if Version(diffusers.__version__) >= Version("0.31.0"):
ids = ids[None, ...]
n_axes = ids.shape[-1]
emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
return emb.unsqueeze(1)
def load_quantized_model(path: str, device: str | torch.device) -> QuantizedFluxModel:
device = torch.device(device)
assert device.type == "cuda"
m = QuantizedFluxModel()
m.disableMemoryAutoRelease()
m.init(True, 0 if device.index is None else device.index)
m.load(path)
return m
def inject_pipeline(pipe: FluxPipeline, m: QuantizedFluxModel) -> FluxPipeline:
net: FluxTransformer2DModel = pipe.transformer
net.pos_embed = EmbedND(dim=net.inner_dim, theta=10000, axes_dim=[16, 56, 56])
net.transformer_blocks = torch.nn.ModuleList([NunchakuFluxModel(m)])
net.single_transformer_blocks = torch.nn.ModuleList([])
def update_params(self: FluxTransformer2DModel, path: str):
if not os.path.exists(path):
hf_repo_id = os.path.dirname(path)
filename = os.path.basename(path)
path = hf_hub_download(repo_id=hf_repo_id, filename=filename)
block = self.transformer_blocks[0]
assert isinstance(block, NunchakuFluxModel)
block.m.load(path, True)
def set_lora_scale(self: FluxTransformer2DModel, scale: float):
block = self.transformer_blocks[0]
assert isinstance(block, NunchakuFluxModel)
block.m.setLoraScale(SVD_RANK, scale)
net.nunchaku_update_params = types.MethodType(update_params, net)
net.nunchaku_set_lora_scale = types.MethodType(set_lora_scale, net)
return pipe
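# Illustrative end-to-end sketch for the helpers above; the base pipeline repo id and
# the checkpoint/LoRA paths are placeholders, not values pinned by this file.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cuda")
m = load_quantized_model("/path/to/quantized-transformer.safetensors", "cuda")
pipe = inject_pipeline(pipe, m)
# optionally load a converted LoRA and set its scale via the injected helpers
# pipe.transformer.nunchaku_update_params("/path/to/converted-lora.safetensors")
# pipe.transformer.nunchaku_set_lora_scale(1.0)
image = pipe("A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0.0).images[0]
image.save("flux-w4a4.png")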
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
safety_check_template = """You are a policy expert trying to help determine whether a user
prompt is in violation of the defined safety policies.
<start_of_turn>
Human Question: {user_prompt}
<end_of_turn>
Our safety principle is defined in the below:
The prompt must not seek to generate harmful, abusive, hateful, sexually explicit, or otherwise inappropriate content targeting individuals or protected groups.
Does the human question violate the above principle? Your answer must start
with 'Yes' or 'No'. And then walk through step by step to be sure we answer
correctly.
"""
class SafetyChecker:
def __init__(self, device: str | torch.device, disabled: bool = False):
if not disabled:
self.device = device
self.tokenizer = AutoTokenizer.from_pretrained("google/shieldgemma-2b")
self.llm = AutoModelForCausalLM.from_pretrained("google/shieldgemma-2b", torch_dtype=torch.bfloat16).to(
device
)
self.disabled = disabled
def __call__(self, user_prompt: str, threshold: float = 0.2) -> bool:
if self.disabled:
return True
device = self.device
inputs = self.tokenizer(safety_check_template.format(user_prompt=user_prompt), return_tensors="pt").to(device)
with torch.no_grad():
logits = self.llm(**inputs).logits
# Extract the logits for the Yes and No tokens
vocab = self.tokenizer.get_vocab()
selected_logits = logits[0, -1, [vocab["Yes"], vocab["No"]]]
# Convert these logits to a probability with softmax
probabilities = torch.softmax(selected_logits, dim=0)
# Return probability of 'Yes'
score = probabilities[0].item()
return score < threshold
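# Illustrative usage sketch for SafetyChecker above; downloading google/shieldgemma-2b
# may require accepting the model license on the Hugging Face Hub.
checker = SafetyChecker(device="cuda")
if checker("A watercolor painting of a lighthouse at dawn"):
    print("prompt accepted")
else:
    print("prompt rejected by the safety policy")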