Commit e3597f7e authored by Muyang Li

remove dev scripts

parent 8431762a
"""
Utilities adapted from
* https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_bnb_4bit.py
* https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/bitsandbytes.py
"""
import torch
import bitsandbytes as bnb
from transformers.quantizers.quantizers_utils import get_module_from_name
import torch.nn as nn
from accelerate import init_empty_weights
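# Flow: _replace_with_bnb_linear() swaps every nn.Linear for a bitsandbytes
# Linear8bitLt / Linear4bit shell (built under init_empty_weights, so no real
# storage is allocated yet), then check_quantized_param() / create_quantized_param()
# are used while loading the state dict to materialize the quantized weights on
# the target device.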
def _replace_with_bnb_linear(
model,
method="nf4",
has_been_replaced=False,
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successful or not.
"""
for name, module in model.named_children():
if isinstance(module, nn.Linear):
with init_empty_weights():
in_features = module.in_features
out_features = module.out_features
if method == "llm_int8":
model._modules[name] = bnb.nn.Linear8bitLt(
in_features,
out_features,
module.bias is not None,
has_fp16_weights=False,
threshold=6.0,
)
has_been_replaced = True
else:
model._modules[name] = bnb.nn.Linear4bit(
in_features,
out_features,
module.bias is not None,
compute_dtype=torch.bfloat16,
compress_statistics=False,
quant_type="nf4",
)
has_been_replaced = True
# Store the module class in case we need to transpose the weight later
model._modules[name].source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_bnb_linear(
module,
has_been_replaced=has_been_replaced,
)
return model, has_been_replaced
def check_quantized_param(
model,
param_name: str,
) -> bool:
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit):
# Add here check for loaded components' dtypes once serialization is implemented
return True
elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
# bias could be loaded by regular set_module_tensor_to_device() from accelerate,
# but it would wrongly use uninitialized weight there.
return True
else:
return False
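# create_quantized_param() materializes a single parameter:
# * "bias" tensors are simply moved to the target device as regular nn.Parameters;
# * pre-quantized weights are rebuilt with Params4bit.from_prequantized() using the
#   `quant_state.bitsandbytes__nf4/fp4` entries stored next to them in the state dict;
# * otherwise the floating-point weight is wrapped in Params4bit and quantized when
#   moved to the target (CUDA) device.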
def create_quantized_param(
model,
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict=None,
unexpected_keys=None,
pre_quantized=False
):
module, tensor_name = get_module_from_name(model, param_name)
if tensor_name not in module._parameters:
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
old_value = getattr(module, tensor_name)
if tensor_name == "bias":
if param_value is None:
new_value = old_value.to(target_device)
else:
new_value = param_value.to(target_device)
new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
module._parameters[tensor_name] = new_value
return
if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
raise ValueError("this function only loads `Linear4bit components`")
if (
old_value.device == torch.device("meta")
and target_device not in ["meta", torch.device("meta")]
and param_value is None
):
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.")
if pre_quantized:
if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (
param_name + ".quant_state.bitsandbytes__nf4" not in state_dict
):
raise ValueError(
f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components."
)
quantized_stats = {}
for k, v in state_dict.items():
# `startswith` to account for edge cases where the `param_name`
# substring can be present in multiple places in the `state_dict`
if param_name + "." in k and k.startswith(param_name):
quantized_stats[k] = v
if unexpected_keys is not None and k in unexpected_keys:
unexpected_keys.remove(k)
new_value = bnb.nn.Params4bit.from_prequantized(
data=param_value,
quantized_stats=quantized_stats,
requires_grad=False,
device=target_device,
)
else:
new_value = param_value.to("cpu")
kwargs = old_value.__dict__
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)
print(f"{param_name}: new_value.quant_type={new_value.quant_type} quant_state={new_value.quant_state} storage={new_value.quant_storage} blocksize={new_value.blocksize}")
state = new_value.quant_state
print(f" -- state.code={state.code} dtype={state.dtype} blocksize={state.blocksize}")
module._parameters[tensor_name] = new_value
# generate.py
# from huggingface_hub import hf_hub_download
# from accelerate.utils import set_module_tensor_to_device, compute_module_sizes
# from accelerate import init_empty_weights
# from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers
# from convert_nf4_flux import _replace_with_bnb_linear, create_quantized_param, check_quantized_param
# from diffusers import FluxTransformer2DModel, FluxPipeline
# import safetensors.torch
# import gc
# import torch
# dtype = torch.bfloat16
# ckpt_path = hf_hub_download("black-forest-labs/flux.1-dev", filename="flux1-dev.safetensors")
# original_state_dict = safetensors.torch.load_file(ckpt_path)
# converted_state_dict = convert_flux_transformer_checkpoint_to_diffusers(original_state_dict)
# del original_state_dict
# gc.collect()
# with init_empty_weights():
# config = FluxTransformer2DModel.load_config("black-forest-labs/flux.1-dev", subfolder="transformer")
# model = FluxTransformer2DModel.from_config(config).to(dtype)
# _replace_with_bnb_linear(model, "nf4")
# for param_name, param in converted_state_dict.items():
# param = param.to(dtype)
# if not check_quantized_param(model, param_name):
# set_module_tensor_to_device(model, param_name, device=0, value=param)
# else:
# create_quantized_param(model, param, param_name, target_device=0)
# del converted_state_dict
# gc.collect()
# print(compute_module_sizes(model)[""] / 1024 / 1024)
# pipe = FluxPipeline.from_pretrained("black-forest-labs/flux.1-dev", transformer=model, torch_dtype=dtype)
# pipe.enable_model_cpu_offload()
# prompt = "A mystic cat with a sign that says hello world!"
# image = pipe(prompt, guidance_scale=3.5, num_inference_steps=50, generator=torch.manual_seed(0)).images[0]
# image.save("flux-nf4-dev.png")
# model.push_to_hub("sayakpaul/flux.1-dev-nf4")
import torch
import safetensors
import torch.nn.functional as F
from dump_flux import DeepCompressorModel, TensorDict, pack_wscales, pack_lora, merge_dict, unsmooth
from typing import Optional
Lora = tuple[torch.Tensor, torch.Tensor]
def load_svdq_lora(path: str, lora_path: str) -> DeepCompressorModel:
result = DeepCompressorModel(
model=torch.load(f"{path}/model.pt", map_location="cpu"),
smooth=torch.load(f"{path}/smooth.pt", map_location="cpu"),
branch=torch.load(f"{path}/branch.pt", map_location="cpu"),
lora={}
)
with safetensors.safe_open(lora_path, framework="pt", device="cpu") as f:
for k in f.keys():
prefix = "transformer."
if k.startswith(prefix):
result.lora[k.removeprefix(prefix)] = f.get_tensor(k)
dtype = next(iter(result.branch.values()))["a.weight"].dtype
for k, v in result.lora.items():
if v.dtype != dtype:
print(f"Convert lora weight {k} from {v.dtype} to {dtype}")
result.lora[k] = v.to(dtype)
# for k, v in result.lora.items():
# v.fill_(0)
return result
# q/k/v [3072, ...] -> qkv [3072 * 3, ...]
def extend_qkv(input: torch.Tensor, id: int) -> torch.Tensor:
oc, ic = input.shape
tmp = torch.zeros([oc * 3, ic], dtype=input.dtype, device=input.device)
tmp[id*oc:(id+1)*oc, ...] = input
return tmp
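# e.g. extend_qkv(W_k, 1) places W_k into rows [oc, 2*oc) of a zero [3*oc, ic]
# tensor, matching the fused qkv layout (q rows first, then k, then v).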
def merge_lora(inputs: list[Lora]) -> Optional[Lora]:
if len(inputs) == 0:
return None
lora_downs = [x[0] for x in inputs]
lora_ups = [x[1] for x in inputs]
lora_down = torch.cat(lora_downs, dim=0)
lora_up = torch.cat(lora_ups, dim=1)
return (lora_down, lora_up)
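# Concatenating the down projections along the rank dim and the up projections along
# their rank (input) dim yields a single LoRA whose product up @ down equals the sum
# of the individual up_i @ down_i terms.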
def merge_lora_qkv(inputs: list[Lora]) -> list[Lora]:
if len(inputs) == 0:
return []
for x in inputs:
if not x[0].equal(inputs[0][0]):
return inputs
lora_down = inputs[0][0]
lora_ups = [x[1] for x in inputs]
lora_up = torch.sum(torch.stack(lora_ups), dim=0).to(lora_down.dtype)
return [(lora_down, lora_up)]
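# If the q/k/v LoRAs share the same down projection, their up projections (already
# padded to the fused qkv layout by extend_qkv) can be summed into one LoRA instead
# of tripling the rank; otherwise the list is returned unchanged.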
def dump_lora(lora_down: Optional[torch.Tensor], lora_up: Optional[torch.Tensor]) -> TensorDict:
if lora_down is None:
return {}
rank, ic = lora_down.shape
oc = lora_up.shape[0]
assert lora_up.shape == (oc, rank)
if rank % 16 != 0:
rank_pad = (rank + 16 - 1) // 16 * 16
tmp_down = torch.zeros([rank_pad, ic], dtype=lora_down.dtype, device=lora_down.device)
tmp_up = torch.zeros([oc, rank_pad], dtype=lora_down.dtype, device=lora_down.device)
tmp_down[:rank, ...] = lora_down
tmp_up[..., :rank] = lora_up
lora_down = tmp_down
lora_up = tmp_up
print(f"Pad lora rank from {rank} to {rank_pad}")
lora_down = pack_lora(lora_down.transpose(0, 1), is_lora_down=True)
lora_up = pack_lora(lora_up, is_lora_down=False)
tensors = {}
tensors["lora_down"] = lora_down
tensors["lora_up"] = lora_up
return tensors
def get_original_lora(qmodel: DeepCompressorModel, key_branch: str, key_smooth: Optional[str]) -> Lora:
dtype = qmodel.branch[key_branch]["a.weight"].dtype
smooth = qmodel.smooth[key_smooth].to(dtype).float() if key_smooth else None
return (
unsmooth(qmodel.branch[key_branch]["a.weight"], smooth),
qmodel.branch[key_branch]["b.weight"]
)
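# The low-rank branch was calibrated against smooth-scaled inputs, so (presumably)
# unsmooth() folds the smoothing factors back into the down projection so the exported
# LoRA applies directly to the unsmoothed activations.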
def dump_linear_lora(
qmodel: DeepCompressorModel,
key_lora: str,
key_branch: str,
key_smooth: str,
shift_bias: bool = False,
key_bias: Optional[str] = None,
range_ic: slice = slice(None, None, None)) -> TensorDict:
lora_original = get_original_lora(qmodel, key_branch, key_smooth)
if f"{key_lora}.lora_A.weight" in qmodel.lora:
# lora_down = qmodel.lora[f"{key}.lora_A.weight"][..., range_ic]
# lora_up = qmodel.lora[f"{key}.lora_B.weight"]
lora_new = (
qmodel.lora[f"{key_lora}.lora_A.weight"][..., range_ic],
qmodel.lora[f"{key_lora}.lora_B.weight"]
)
lora_down, lora_up = merge_lora([lora_original, lora_new])
rank, ic = lora_down.shape
oc = lora_up.shape[0]
assert lora_up.shape == (oc, rank)
print(f"linear at {key_lora} has rank {rank}")
tensors = dump_lora(lora_down, lora_up)
if shift_bias and False: # no longer need shift bias
if key_bias is None:
key_bias = f"{key_branch}.bias"
if key_bias in qmodel.model:
bias = qmodel.model[key_bias]
print(f"linear at {key_lora} apply shift_bias from original bias at {key_bias}")
else:
bias = torch.zeros([oc], dtype=lora_up.dtype, device=lora_up.device)
print(f"linear at {key_lora} apply shift_bias from empty original bias")
shift = torch.empty([ic], dtype=lora_down.dtype, device=lora_down.device)
shift = shift.fill_(0.171875)
delta = F.linear(F.linear(shift, lora_new[0]), lora_new[1])
print(f"shift_bias delta = {delta}")
bias -= delta
tensors["bias"] = pack_wscales(bias[..., None])[0]
return tensors
else:
print(f"linear at {key_lora} use original lora")
return dump_lora(*lora_original)
def dump_qkv_proj_svdq_lora(
qmodel: DeepCompressorModel,
key_qkv: tuple[str, str, str],
key_smooth: str,
key_smooth_out: str
) -> TensorDict:
dtype = qmodel.branch[key_smooth]["a.weight"].dtype
smooth_out = qmodel.smooth[key_smooth_out].to(dtype).float()
lora_original = get_original_lora(qmodel, key_smooth, key_smooth)
loras = []
for i in range(3):
key = key_qkv[i]
if f"{key}.lora_A.weight" in qmodel.lora:
lora_down = qmodel.lora[f"{key}.lora_A.weight"]
lora_up = qmodel.lora[f"{key}.lora_B.weight"]
if i == 2:
lora_up = (lora_up / smooth_out[..., None]).to(lora_up.dtype)
loras.append((lora_down, extend_qkv(lora_up, i)))
# print(loras)
lora_down, lora_up = merge_lora([lora_original, *merge_lora_qkv(loras)])
print(f"qkv_proj at {key_smooth} has rank {lora_down.shape[0]}")
return dump_lora(lora_down, lora_up)
def dump_transformer_svdq_lora(qmodel: DeepCompressorModel, layer_id: int) -> TensorDict:
tensors = {}
def reorder_adanorm_linear(weight: torch.Tensor) -> torch.Tensor:
oc, ic = weight.shape
assert oc % 6 == 0
return weight.reshape(6, oc // 6, ic).transpose(0, 1).reshape(oc, ic).contiguous()
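# reorder_adanorm_linear: the AdaLayerNorm linear emits 6 stacked chunks
# (shift/scale/gate for attention and MLP); this interleaves them per output channel,
# presumably to match the layout expected by the fused kernel. The single-block
# variant below does the same with 3 chunks.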
def linear(key: str, **kwargs):
key_lora = key
key_branch = kwargs.pop("key_branch", key_lora)
key_smooth = kwargs.pop("key_smooth", key_branch)
return dump_linear_lora(qmodel, key_lora, key_branch, key_smooth, **kwargs)
prefix = f"transformer_blocks.{layer_id}"
if f"{prefix}.norm1.linear.lora_A.weight" in qmodel.lora:
lora_down = qmodel.lora[f"{prefix}.norm1.linear.lora_A.weight"]
lora_up = qmodel.lora[f"{prefix}.norm1.linear.lora_B.weight"]
tensors[f"norm1.linear.lora_down"] = lora_down
tensors[f"norm1.linear.lora_up"] = reorder_adanorm_linear(lora_up)
if f"{prefix}.norm1_context.linear.lora_A.weight" in qmodel.lora:
lora_down = qmodel.lora[f"{prefix}.norm1_context.linear.lora_A.weight"]
lora_up = qmodel.lora[f"{prefix}.norm1_context.linear.lora_B.weight"]
tensors[f"norm1_context.linear.lora_down"] = lora_down
tensors[f"norm1_context.linear.lora_up"] = reorder_adanorm_linear(lora_up)
merge_dict(tensors, dump_qkv_proj_svdq_lora(
qmodel,
(f"{prefix}.attn.to_q", f"{prefix}.attn.to_k", f"{prefix}.attn.to_v"),
f"{prefix}.attn.to_q",
f"{prefix}.attn.to_out.0"
), "qkv_proj.")
merge_dict(tensors, dump_qkv_proj_svdq_lora(
qmodel,
(f"{prefix}.attn.add_q_proj", f"{prefix}.attn.add_k_proj", f"{prefix}.attn.add_v_proj"),
f"{prefix}.attn.add_k_proj",
f"{prefix}.attn.to_out.0"
), "qkv_proj_context.")
merge_dict(tensors, linear(f"{prefix}.attn.to_out.0", key_smooth=None), "out_proj.")
merge_dict(tensors, linear(f"{prefix}.attn.to_add_out", key_smooth=None), "out_proj_context.")
merge_dict(tensors, linear(f"{prefix}.ff.net.0.proj"), "mlp_fc1.")
merge_dict(tensors, linear(f"{prefix}.ff.net.2", key_branch=f"{prefix}.ff.net.2.linear", shift_bias=True), "mlp_fc2.")
merge_dict(tensors, linear(f"{prefix}.ff_context.net.0.proj"), "mlp_context_fc1.")
merge_dict(tensors, linear(f"{prefix}.ff_context.net.2", key_branch=f"{prefix}.ff_context.net.2.linear", shift_bias=True), "mlp_context_fc2.")
return tensors
def dump_single_transformer_svdq_lora(qmodel: DeepCompressorModel, layer_id: int) -> TensorDict:
tensors = {}
def reorder_adanorm_linear(weight: torch.Tensor) -> torch.Tensor:
oc, ic = weight.shape
assert oc % 3 == 0
return weight.reshape(3, oc // 3, ic).transpose(0, 1).reshape(oc, ic).contiguous()
def linear(key: str, **kwargs):
key_lora = key
key_branch = kwargs.pop("key_branch", key_lora)
key_smooth = kwargs.pop("key_smooth", key_branch)
return dump_linear_lora(qmodel, key_lora, key_branch, key_smooth, **kwargs)
prefix = f"single_transformer_blocks.{layer_id}"
if f"{prefix}.norm.linear.lora_A.weight" in qmodel.lora:
lora_down = qmodel.lora[f"{prefix}.norm.linear.lora_A.weight"]
lora_up = qmodel.lora[f"{prefix}.norm.linear.lora_B.weight"]
tensors[f"norm.linear.lora_down"] = lora_down
tensors[f"norm.linear.lora_up"] = reorder_adanorm_linear(lora_up)
merge_dict(tensors, dump_qkv_proj_svdq_lora(
qmodel,
(f"{prefix}.attn.to_q", f"{prefix}.attn.to_k", f"{prefix}.attn.to_v"),
f"{prefix}.attn.to_q",
f"{prefix}.proj_out.linears.0"
), "qkv_proj.")
merge_dict(tensors, linear(f"{prefix}.proj_mlp", key_smooth=f"{prefix}.attn.to_q"), "mlp_fc1.")
# TODO
out_dim = 3072
merge_dict(tensors, linear(f"{prefix}.proj_out",
key_branch=f"{prefix}.proj_out.linears.0",
key_smooth=None,
range_ic=slice(0, out_dim)), "out_proj.")
merge_dict(tensors, linear(f"{prefix}.proj_out",
key_branch=f"{prefix}.proj_out.linears.1.linear",
shift_bias=True,
range_ic=slice(out_dim, None)), "mlp_fc2.")
return tensors
@torch.inference_mode()
def dump_flux_svdq_lora(qmodel: DeepCompressorModel, **kwargs) -> TensorDict:
tensors = {}
for i in range(19):
merge_dict(tensors, dump_transformer_svdq_lora(qmodel, i, **kwargs), f"transformer_blocks.{i}.")
for i in range(38):
merge_dict(tensors, dump_single_transformer_svdq_lora(qmodel, i, **kwargs), f"single_transformer_blocks.{i}.")
return tensors
if __name__ == "__main__":
lora_name = "realism"
if lora_name == "sketch":
qmodel = load_svdq_lora("model-dev", "../third_party/FLUX.1-dev-LoRA-Collections/sketch.safetensors")
elif lora_name == "realism":
qmodel = load_svdq_lora("model-dev", "../third_party/FLUX.1-dev-LoRA-Collections/realism.safetensors")
elif lora_name == "anime":
qmodel = load_svdq_lora("model-dev", "../third_party/sonny-anime-fixed/araminta_k_sonnyanime_fluxd_fixed.safetensors")
elif lora_name == "ghibsky":
qmodel = load_svdq_lora("model-dev", "../third_party/flux-ghibsky-illustration/lora.safetensors")
elif lora_name == "yarn":
qmodel = load_svdq_lora("model-dev", "../third_party/yarn_art_Flux_LoRA/pytorch_lora_weights.safetensors")
elif lora_name == "sketch2image":
qmodel = load_svdq_lora("model-dev", "sketch2image.safetensors")
else:
raise NotImplementedError
tensors = dump_flux_svdq_lora(qmodel)
for k, v in tensors.items():
assert not v.isnan().any()
assert not v.isinf().any()
safetensors.torch.save_file(
tensors,
f"/tmp/flux-lora-{lora_name}-bf16.safetensors")
#!/bin/bash
rundir=$(date +"run-$(hostname -s)-%Y%m%d-%H%M%S")
mkdir -p $rundir
function run() {
echo config=$config
echo args=$@
python3 run_flux.py --steps 4 "$@" > >(tee $rundir/stdout-s4-$config.log) 2> >(tee $rundir/stderr-s4-$config.log)
python3 run_flux.py --steps 25 "$@" > >(tee $rundir/stdout-s25-$config.log) 2> >(tee $rundir/stderr-s25-$config.log)
python3 run_flux.py --steps 50 "$@" > >(tee $rundir/stdout-s50-$config.log) 2> >(tee $rundir/stderr-s50-$config.log)
if [ $? -eq 0 ]; then
nsys profile --cuda-memory-usage true -o $rundir/report-$config.nsys-rep python3 run_flux.py --steps 4 "$@"
fi
}
config=bf16-compile
run --config bf16 --compile
config=bf16-t5-compile
run --config bf16-t5 --compile
config=int8dq-compile
run --config bf16 --torchao --compile
config=int8dq-t5-compile
run --config bf16-t5 --torchao --compile
config=int8dq-nocompile
run --config bf16 --torchao
config=int8dq-t5-nocompile
run --config bf16-t5 --torchao
for cfg in svdq svdq-t5 w4a4 w4a4-t5 bf16 bf16-t5 nf4 nf4-t5; do
config=$cfg
run --config $cfg
config=$cfg-ol1
run --config $cfg --offload 1
config=$cfg-ol2
run --config $cfg --offload 2
done
import torch
from torch.nn import functional as F
from dump_flux import group_scale
def compare(
ref: torch.Tensor,
v: torch.Tensor,
refname: str,
vname: str,
list_diff: bool = False):
print(f"== COMPARE v={vname} vs ref={refname}")
diff = v - ref
print(f" - diff = {diff}")
if list_diff:
print(f" - diffs at {diff.nonzero()}")
mse = diff.square().mean()
print(f" - mse = {mse}")
nmse = mse / ref.square().mean()
print(f" - nmse = {nmse}")
print(f" - mean(v/ref)={v.mean()}/{ref.mean()}")
print(f" - var(v/ref)={v.var()}/{ref.var()}")
print(f"== ")
print()
def print_debug_results(debug_results: dict[str, torch.Tensor], is_ref: bool = False):
ref = 'REF' if is_ref else ''
for k, v in debug_results.items():
has_nan = v.isnan().any()
has_inf = v.isinf().any()
if v.dtype.is_floating_point:
print(f" {ref} {k}: {v.shape} ({v.dtype}) has_nan={has_nan} has_inf={has_inf} max={v.max()} min={v.min()} mean={v.mean()} var={v.var()}")
else:
print(f" {ref} {k}: {v.shape} ({v.dtype})")
if has_nan:
cnt = v.isnan().count_nonzero()
print(f" {ref} -- {cnt} ({cnt / v.numel() * 100}%) nans at {v.isnan().nonzero()[0:10]}")
if has_inf:
cnt = v.isinf().count_nonzero()
print(f" {ref} -- {cnt} ({cnt / v.numel() * 100}%) infs at {v.isinf().nonzero()[0:10]}")
print(f" {ref} -- {v}")
print()
def fakequant(
act: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
group_size: int = 64,
force_cuda: bool = False,
):
oc, ic = weight.shape
batch_size = act.shape[0]
assert act.shape[1] == ic
# [oc, ic // group_size]
wscales = group_scale(weight, num_bits=4, group_size=group_size)
qweight = weight.reshape(oc, ic // group_size, group_size).to(dtype=torch.float32) / wscales[..., None]
qweight = qweight.round().clamp(-8, 7)
qweight_i = qweight.int()
qweight = qweight * wscales[..., None]
qweight = qweight.to(weight.dtype)
qweight = qweight.reshape(oc, ic)
# print(f"qweight = {qweight}")
print_debug_results({"qweight": qweight})
# [batch_size, ic // group_size]
ascales = group_scale(act, num_bits=4, group_size=group_size).to(dtype=weight.dtype)
qact = act.reshape(batch_size, ic // group_size, group_size).to(dtype=torch.float32) / ascales[..., None]
qact = qact.round().clamp(-8, 7)
qact_i = qact.int()
print_debug_results({"qact_i": qact_i})
qact = qact * ascales[..., None]
qact = qact.to(act.dtype)
qact = qact.reshape(batch_size, ic)
# print(f"qact = {qact}")
print_debug_results({"qact": qact})
outref_q = F.linear(qact.to(qweight.dtype), qweight, bias)
# print(f"outref_q = {outref_q}")
print_debug_results({"outref_q": outref_q})
###
if force_cuda:
qweight_i = qweight_i.to("cuda")
qact_i = qact_i.to("cuda")
wscales = wscales.to("cuda")
ascales = ascales.to("cuda")
bias = bias.to("cuda")
qweight = qweight_i
qact = qact_i
qweight = qweight.reshape(oc, ic // group_size, group_size).transpose(0, 1).transpose(1, 2)
qact = qact.reshape(batch_size, ic // group_size, group_size).transpose(0, 1)
# [ic // group_size, batch_size, oc]
psum = torch.bmm(qact.float(), qweight.float())
print(f"psum_i ({psum.shape}) = {psum} ")
# print(psum[:, 0, 23])
# print(f"ascales = {ascales}")
print_debug_results({"ascales": ascales})
print(f"ascales[0:16] = {ascales[0:16, 0]}")
ws1 = wscales.transpose(0, 1).reshape(ic // group_size, 1, oc).repeat(1, batch_size, 1)
as1 = ascales.transpose(0, 1).reshape(ic // group_size, batch_size, 1).repeat(1, 1, oc)
scales = ws1 * as1
print(f"scales = {scales}")
# print(scales[:, 0, 23])
psum = psum.to(dtype=act.dtype).float()
psum = psum * scales
print(f"psum ({psum.shape}) = {psum} ")
# print(psum[:, 0, 23])
# outref_q2 = psum.sum(dim=0) # .to(layer.weight.dtype)
outref_q2 = torch.zeros_like(psum[0])
for i in range(psum.shape[0]):
outref_q2 = (outref_q2 + psum[i]).to(act.dtype)
outref_q2 += bias[None, ...]
# print(f"outref_q2 = {outref_q2}")
print_debug_results({"outref_q2": outref_q2})
# print(outref_q2[0, 23])
if force_cuda:
outref_q2 = outref_q2.to(act.device)
return outref_q, outref_q2
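# fakequant() returns two references: outref_q (plain dequantize-then-matmul in
# floating point) and outref_q2 (per-group integer partial sums, scaled and
# accumulated group by group), presumably to mirror the accumulation order of the
# W4A4 kernel more closely.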
from safetensors.torch import safe_open, save_file
def main():
input_path1 = "app/i2i/pretrained/converted/sketch.safetensors"
input_path2 = "app/i2i/pretrained/original/flux-lora-sketch2image-bf16.safetensors"
sd1 = {}
with safe_open(input_path1, framework="pt") as f:
for k in f.keys():
sd1[k] = f.get_tensor(k)
sd2 = {}
with safe_open(input_path2, framework="pt") as f:
for k in f.keys():
sd2[k] = f.get_tensor(k)
for k in sd1.keys():
if "lora" not in k:
print(k)
sd2[k.replace("transformer.", "")] = sd1[k]
save_file(sd2, "svdq-flux.1-pix2pix-turbo-sketch2image.safetensors")
if __name__ == "__main__":
main()
import math
import torch
import torch.nn as nn
def make_divisible(c, divisor):
return (c + divisor - 1) // divisor
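# Note: despite its name, make_divisible() is ceiling division: ceil(c / divisor).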
def calculate_zeros_width(in_features, group_size=128, pack_num=8):
if group_size >= 128:
size_multiplier = 1
elif group_size == 64:
size_multiplier = 2
elif group_size == 32:
size_multiplier = 4
else:
raise NotImplementedError
base_width = make_divisible(in_features // group_size, pack_num)
base_width = make_divisible(base_width, size_multiplier) * size_multiplier
return base_width
def pack_intweight(unpacked_qweight, interleave, kstride):
# unpacked_qweight: [N, K]
N = unpacked_qweight.shape[0]
K = unpacked_qweight.shape[1]
Packed_Kernel = unpacked_qweight.cpu().numpy().reshape(N, K // 32, 32)
# np.arange(32).reshape(4, 4, 2).transpose(1, 0, 2) => [0, 1, 8, 9, 16, 17, 24, 25, ...]
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 3, 2, 4)
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 32)
# reorder each 8 weights for fast dequantization
# [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7]
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 8)
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 2, 4, 3)
Packed_Kernel = Packed_Kernel.reshape(N, K)
# interleaving every four rows
Packed_Kernel = Packed_Kernel.reshape(
N // interleave, interleave, K // kstride, kstride
)
# N // 4, K // 64, 4, 64
Packed_Kernel = Packed_Kernel.transpose(0, 2, 1, 3)
Packed_Kernel = Packed_Kernel.reshape(
N // interleave, K // kstride, kstride, interleave
)
# Packing -> (N // 4, K // 64, 64)
Packed_Kernel = (
Packed_Kernel[..., 0]
| (Packed_Kernel[..., 1] << 4)
| (Packed_Kernel[..., 2] << 8)
| (Packed_Kernel[..., 3] << 12)
)
# reshape to (N // 4, K), FP16 format
Packed_Kernel = Packed_Kernel.reshape(N // interleave, K)
qweight = (
torch.tensor(Packed_Kernel.astype("int16"))
.to(unpacked_qweight.device)
.contiguous()
)
return qweight
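# pack_intweight() reorders the 4-bit values for fast dequantization in the kernel and
# packs four of them into each int16, so the packed qweight has shape
# [N // interleave, K] with interleave=4 output rows fused per packed row.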
def pseudo_quantize_tensor(
w, n_bit=8, zero_point=True, q_group_size=-1,
) -> tuple[torch.Tensor, torch.Tensor]:
org_w_shape = w.shape
if q_group_size > 0:
assert org_w_shape[-1] % q_group_size == 0
w = w.reshape(-1, q_group_size)
assert w.dim() == 2
if zero_point:
max_val = w.amax(dim=1, keepdim=True)
min_val = w.amin(dim=1, keepdim=True)
max_int = 2**n_bit - 1
min_int = 0
scales = (max_val - min_val).clamp(min=1e-5) / max_int
zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
else: # we actually never used this
# assert min_val is None
max_val = w.abs().amax(dim=1, keepdim=True)
max_val = max_val.clamp(min=1e-5)
max_int = 2 ** (n_bit - 1) - 1
min_int = -max_int
scales = max_val / max_int
zeros = torch.full_like(scales, fill_value=-min_int)
assert torch.isnan(scales).sum() == 0
assert torch.isnan(w).sum() == 0
w = w.reshape(org_w_shape)
return scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
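# pseudo_quantize_tensor() only computes the per-group quantization parameters; with
# zero_point=True the implied mapping is q = clamp(round(w / scale) + zero, 0, 2**n_bit - 1)
# and dequant = (q - zero) * scale, which dump_linear_awq() applies below via
# (w + zero * scale) / scale.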
def dump_linear_awq(
weight: torch.Tensor,
bias: torch.Tensor,
w_bit: int,
group_size: int,
zero_point: bool = True
) -> dict[str, torch.Tensor]:
scales, zeros = pseudo_quantize_tensor(weight, w_bit, zero_point, group_size)
print(scales.shape)
print(zeros.shape)
tensors = {}
dtype = weight.dtype
oc, ic = weight.shape
# need scales and zeros info for real quantization
assert scales is not None and zeros is not None
scale_zeros = zeros * scales
pack_num = 32 // w_bit
qscales = torch.zeros(
(
scales.shape[0],
calculate_zeros_width(ic, group_size) * pack_num,
),
dtype=dtype,
device=scales.device,
)
qscales[:, : scales.shape[1]] = scales
# awq_linear.scales = scales.clone().half()
tensors["wscales"] = qscales.transpose(1, 0).contiguous()
if bias is not None:
tensors["bias"] = bias.clone()
if False:
intweight = []
for idx in range(ic):
intweight.append(
torch.round(
(weight.data[:, idx] + scale_zeros[:, idx // group_size])
/ qscales[:, idx // group_size]
).clamp(0, 15 if zero_point else 14).to(torch.int)[:, None]
)
print(intweight[0].shape)
intweight = torch.cat(intweight, dim=1)
print(intweight.shape)
intweight_ref = intweight
# intweight = intweight.t().contiguous()
assert ic % group_size == 0
intweight = weight.reshape(oc, ic // group_size, group_size)
# print(f"{intweight.shape} {scale_zeros[..., None].shape} {qscales[..., None].shape}")
intweight = (intweight + scale_zeros[..., None]) / qscales[..., None]
intweight = intweight.round_()
intweight = intweight.clamp_(0, 15 if zero_point else 14)
intweight = intweight.to(dtype=torch.int32)
intweight = intweight.reshape(oc, ic)
if False:
print(intweight_ref - intweight)
assert not (intweight_ref - intweight != 0).any()
tensors["qweight"] = pack_intweight(
intweight.contiguous(), interleave=4, kstride=64
)
zeros = zeros.to(dtype=torch.int32)
scaled_zeros = torch.zeros_like(qscales)
# scaled_zeros[:, :scales.shape[1]] = -(qscales[:, :scales.shape[1]] * (zeros.to(torch.float32) - 8.0)).to(torch.float16)
scaled_zeros[:, : scales.shape[1]] = -(
qscales[:, : scales.shape[1]] * (zeros.to(torch.float32))
).to(dtype)
tensors["wzeros"] = scaled_zeros.transpose(1, 0).contiguous()
return tensors
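# Exported tensors: "qweight" (packed int4 weights), "wscales" and "wzeros"
# (per-group scales and pre-scaled negative zero points, transposed to
# [padded_groups, oc]), and optionally "bias".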
import time
import argparse
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
import nunchaku.pipelines.flux
def get_pipe(config: str, dev: bool) -> FluxPipeline:
version = "dev" if dev else "schnell"
dtype = torch.bfloat16
qencoder_path = "/NFS/raid0/user/zhangzk/models/flux-t5-tinychat-v2.pt"
if config.startswith("svdq"):
pipe = nunchaku.pipelines.flux.from_pretrained(
f"black-forest-labs/FLUX.1-{version}",
torch_dtype=dtype,
qmodel_path=f"/NFS/raid0/user/zhangzk/models/flux{'-dev' if dev else ''}-svdq-19-38-divsmooth-shift-ada-bf16.safetensors",
qencoder_path=qencoder_path if config == "svdq-t5" else None
)
elif config.startswith("w4a4"):
pipe = nunchaku.pipelines.flux.from_pretrained(
f"black-forest-labs/FLUX.1-{version}",
torch_dtype=dtype,
qmodel_path=f"/NFS/raid0/user/zhangzk/models/flux{'-dev' if dev else ''}-divsmooth-shift-ada-bf16.safetensors",
qencoder_path=qencoder_path if config == "w4a4-t5" else None
)
elif config.startswith("bf16"):
pipe = FluxPipeline.from_pretrained(
f"black-forest-labs/FLUX.1-{version}",
torch_dtype=dtype,
)
if config == "bf16-t5":
nunchaku.pipelines.flux.quantize_t5(pipe, qencoder_path)
elif config.startswith("nf4"):
from accelerate.utils import set_module_tensor_to_device, compute_module_sizes
from accelerate import init_empty_weights
from convert_nf4_flux import _replace_with_bnb_linear, create_quantized_param, check_quantized_param
converted_state_dict = torch.load(f"/NFS/raid0/user/zhangzk/models/flux1-{version}-nf4.pt")
with init_empty_weights():
model_config = FluxTransformer2DModel.load_config(f"black-forest-labs/flux.1-{version}", subfolder="transformer")
model = FluxTransformer2DModel.from_config(model_config).to(dtype)  # renamed from `config` to avoid shadowing the `config` argument, which is checked again below
_replace_with_bnb_linear(model, "nf4")
for param_name, param in converted_state_dict.items():
param = param.to(dtype)
print(f"{param_name}: {param.shape} check_quantized_param={check_quantized_param(model, param_name)}")
if not check_quantized_param(model, param_name):
set_module_tensor_to_device(model, param_name, device=0, value=param)
else:
create_quantized_param(model, param, param_name, target_device=0)
pipe = FluxPipeline.from_pretrained(f"black-forest-labs/flux.1-{version}", transformer=model, torch_dtype=dtype)
if config == "nf4-t5":
nunchaku.pipelines.flux.quantize_t5(pipe, qencoder_path)
else:
raise NotImplementedError
return pipe
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="svdq", choices=["svdq", "svdq-t5", "w4a4", "w4a4-t5", "bf16", "bf16-t5", "nf4", "nf4-t5"])
parser.add_argument("--offload", type=int, default=0)
parser.add_argument("--steps", type=int, default=50)
parser.add_argument("--dev", action="store_true")
parser.add_argument("--torchao", action="store_true")
parser.add_argument("--compile", action="store_true")
args = parser.parse_args()
print(f"Use config {args.config}")
if args.offload > 0:
print(f"Use offloading level {args.offload}")
pipe = get_pipe(args.config, args.dev)
print(pipe)
if args.torchao:
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
# pipe.transformer = autoquant(pipe.transformer, error_on_unseen=False)
quantize_(pipe.transformer, int8_dynamic_activation_int8_weight())
if args.offload == 2:
pipe.enable_sequential_cpu_offload()
elif args.offload == 1:
pipe.enable_model_cpu_offload()
elif args.offload == 0:
pipe.to("cuda:0")
else:
raise NotImplementedError
# assert isinstance(pipe, FluxPipeline)
if args.compile:
pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(
pipe.transformer, mode="max-autotune", fullgraph=True
)
prompt = "A cat holding a sign that says hello world"
print(f"Using prompt '{prompt}'")
print(f"Run {args.steps} steps")
latencies = []
for i in range(5):
start_time = time.time()
out = pipe(
prompt=prompt,
guidance_scale=0,
num_inference_steps=args.steps,
generator=torch.Generator(device="cpu").manual_seed(233),
).images[0]
end_time = time.time()
latencies.append(end_time - start_time)
torch.cuda.empty_cache()
latencies = sorted(latencies)
latencies = latencies[1:-1]
out.save("output.png")
print(f"Elapsed: {sum(latencies) / len(latencies)} seconds")
print(f"Torch max_memory_allocated={torch.cuda.max_memory_allocated()}")
import time
import torch
import diffusers
from diffusers import FluxPipeline
import nunchaku.pipelines.flux
if __name__ == "__main__":
QUANT = False
SEED = 1
DEV = True
LORA_NAME = "anime"
pipe = nunchaku.pipelines.flux.from_pretrained(
f"black-forest-labs/FLUX.1-{'dev' if DEV else 'schnell'}",
torch_dtype=torch.bfloat16,
qmodel_path=f"/NFS/raid0/user/zhangzk/models/flux{'-dev' if DEV else ''}-svdq-19-38-divsmooth-shift-ada-bf16.safetensors",
qencoder_path="/NFS/raid0/user/zhangzk/models/flux-t5-tinychat-v2.pt" if QUANT else None,
)
if LORA_NAME:
pipe.transformer.nunchaku_update_params(f"/tmp/flux-lora-{LORA_NAME}-bf16.safetensors")
pipe.transformer.nunchaku_set_lora_scale(0.4)
print("Moving model to CUDA")
pipe.to("cuda:0")
print("Done")
# prompt = "A cat holding a sign that says hello world"
# prompt = "A cyberpunk cat holding a huge neon sign that says \"SVDQuant is lite and fast\""
prompt = "girl, neck tuft, white hair ,sheep horns, blue eyes, nm22 style"
# prompt = "GHIBSKY style, the most beautiful place in the universe"
# prompt = "the joker, yarn art style"
print(f"Using prompt '{prompt}'")
latencies = []
diffusers.training_utils.set_seed(SEED)
start_time = time.time()
out = pipe(
prompt=prompt,
guidance_scale=3.5 if DEV else 0,
num_inference_steps=50 if DEV else 4,
generator=torch.Generator(device="cpu").manual_seed(SEED),
).images[0]
end_time = time.time()
latencies.append(end_time - start_time)
out.save(f"output{'-dev' if DEV else ''}-{SEED}-{'quant' if QUANT else 'noquant'}.png")
print(f"Elapsed: {sum(latencies) / len(latencies)} seconds")