Commit 37c494a7 authored by Zhekai Zhang

Initial release

"""
Utilities adapted from
* https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_bnb_4bit.py
* https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/bitsandbytes.py
"""
import torch
import bitsandbytes as bnb
from transformers.quantizers.quantizers_utils import get_module_from_name
import torch.nn as nn
from accelerate import init_empty_weights
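# Overall flow (see the commented-out generate.py below): build the Flux transformer on
# the meta device, swap every nn.Linear for a bitsandbytes Linear4bit/Linear8bitLt shell
# via _replace_with_bnb_linear, then materialize each parameter either with
# create_quantized_param (quantizing on the fly) or with accelerate's
# set_module_tensor_to_device for parameters that check_quantized_param rejects.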
def _replace_with_bnb_linear(
model,
method="nf4",
has_been_replaced=False,
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successful or not.
"""
for name, module in model.named_children():
if isinstance(module, nn.Linear):
with init_empty_weights():
in_features = module.in_features
out_features = module.out_features
if method == "llm_int8":
model._modules[name] = bnb.nn.Linear8bitLt(
in_features,
out_features,
module.bias is not None,
has_fp16_weights=False,
threshold=6.0,
)
has_been_replaced = True
else:
model._modules[name] = bnb.nn.Linear4bit(
in_features,
out_features,
module.bias is not None,
compute_dtype=torch.bfloat16,
compress_statistics=False,
quant_type="nf4",
)
has_been_replaced = True
# Store the module class in case we need to transpose the weight later
model._modules[name].source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_bnb_linear(
module,
has_been_replaced=has_been_replaced,
)
return model, has_been_replaced
def check_quantized_param(
model,
param_name: str,
) -> bool:
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit):
# Add here check for loaded components' dtypes once serialization is implemented
return True
elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
# bias could be loaded by regular set_module_tensor_to_device() from accelerate,
# but it would wrongly use uninitialized weight there.
return True
else:
return False
def create_quantized_param(
model,
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict=None,
unexpected_keys=None,
pre_quantized=False
):
module, tensor_name = get_module_from_name(model, param_name)
if tensor_name not in module._parameters:
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
old_value = getattr(module, tensor_name)
if tensor_name == "bias":
if param_value is None:
new_value = old_value.to(target_device)
else:
new_value = param_value.to(target_device)
new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
module._parameters[tensor_name] = new_value
return
if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
raise ValueError("this function only loads `Linear4bit components`")
if (
old_value.device == torch.device("meta")
and target_device not in ["meta", torch.device("meta")]
and param_value is None
):
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.")
if pre_quantized:
if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (
param_name + ".quant_state.bitsandbytes__nf4" not in state_dict
):
raise ValueError(
f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components."
)
quantized_stats = {}
for k, v in state_dict.items():
# `startswith` to account for edge cases where the `param_name`
# substring can be present in multiple places in the `state_dict`
if param_name + "." in k and k.startswith(param_name):
quantized_stats[k] = v
if unexpected_keys is not None and k in unexpected_keys:
unexpected_keys.remove(k)
new_value = bnb.nn.Params4bit.from_prequantized(
data=param_value,
quantized_stats=quantized_stats,
requires_grad=False,
device=target_device,
)
else:
new_value = param_value.to("cpu")
kwargs = old_value.__dict__
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)
print(f"{param_name}: new_value.quant_type={new_value.quant_type} quant_state={new_value.quant_state} storage={new_value.quant_storage} blocksize={new_value.blocksize}")
state = new_value.quant_state
print(f" -- state.code={state.code} dtype={state.dtype} blocksize={state.blocksize}")
module._parameters[tensor_name] = new_value
# generate.py
# from huggingface_hub import hf_hub_download
# from accelerate.utils import set_module_tensor_to_device, compute_module_sizes
# from accelerate import init_empty_weights
# from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers
# from convert_nf4_flux import _replace_with_bnb_linear, create_quantized_param, check_quantized_param
# from diffusers import FluxTransformer2DModel, FluxPipeline
# import safetensors.torch
# import gc
# import torch
# dtype = torch.bfloat16
# ckpt_path = hf_hub_download("black-forest-labs/flux.1-dev", filename="flux1-dev.safetensors")
# original_state_dict = safetensors.torch.load_file(ckpt_path)
# converted_state_dict = convert_flux_transformer_checkpoint_to_diffusers(original_state_dict)
# del original_state_dict
# gc.collect()
# with init_empty_weights():
# config = FluxTransformer2DModel.load_config("black-forest-labs/flux.1-dev", subfolder="transformer")
# model = FluxTransformer2DModel.from_config(config).to(dtype)
# _replace_with_bnb_linear(model, "nf4")
# for param_name, param in converted_state_dict.items():
# param = param.to(dtype)
# if not check_quantized_param(model, param_name):
# set_module_tensor_to_device(model, param_name, device=0, value=param)
# else:
# create_quantized_param(model, param, param_name, target_device=0)
# del converted_state_dict
# gc.collect()
# print(compute_module_sizes(model)[""] / 1024 / 1024)
# pipe = FluxPipeline.from_pretrained("black-forest-labs/flux.1-dev", transformer=model, torch_dtype=dtype)
# pipe.enable_model_cpu_offload()
# prompt = "A mystic cat with a sign that says hello world!"
# image = pipe(prompt, guidance_scale=3.5, num_inference_steps=50, generator=torch.manual_seed(0)).images[0]
# image.save("flux-nf4-dev.png")
# model.push_to_hub("sayakpaul/flux.1-dev-nf4")
import torch
import safetensors
import safetensors.torch
import torch.nn.functional as F
from dump_flux import DeepCompressorModel, TensorDict, pack_wscales, pack_lora, merge_dict, unsmooth
from typing import Optional
Lora = tuple[torch.Tensor, torch.Tensor]
def load_svdq_lora(path: str, lora_path: str) -> DeepCompressorModel:
result = DeepCompressorModel(
model=torch.load(f"{path}/model.pt", map_location="cpu"),
smooth=torch.load(f"{path}/smooth.pt", map_location="cpu"),
branch=torch.load(f"{path}/branch.pt", map_location="cpu"),
lora={}
)
with safetensors.safe_open(lora_path, framework="pt", device="cpu") as f:
for k in f.keys():
prefix = "transformer."
if k.startswith(prefix):
result.lora[k.removeprefix(prefix)] = f.get_tensor(k)
dtype = next(iter(result.branch.values()))["a.weight"].dtype
for k, v in result.lora.items():
if v.dtype != dtype:
print(f"Convert lora weight {k} from {v.dtype} to {dtype}")
result.lora[k] = v.to(dtype)
# for k, v in result.lora.items():
# v.fill_(0)
return result
# q/k/v [3072, ...] -> qkv [3072 * 3, ...]
def extend_qkv(input: torch.Tensor, id: int) -> torch.Tensor:
oc, ic = input.shape
tmp = torch.zeros([oc * 3, ic], dtype=input.dtype, device=input.device)
tmp[id*oc:(id+1)*oc, ...] = input
return tmp
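# Concatenating the down projections along the rank dim and the up projections along
# their rank (column) dim yields one LoRA branch whose output equals the sum of the
# individual branches; the ranks simply add up.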
def merge_lora(inputs: list[Lora]) -> Optional[Lora]:
if len(inputs) == 0:
return None
lora_downs = [x[0] for x in inputs]
lora_ups = [x[1] for x in inputs]
lora_down = torch.cat(lora_downs, dim=0)
lora_up = torch.cat(lora_ups, dim=1)
return (lora_down, lora_up)
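# Special case for the fused qkv projection: if the q/k/v LoRAs share an identical
# down projection (as some adapters do), their zero-extended up projections can be
# summed into a single branch instead of tripling the rank.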
def merge_lora_qkv(inputs: list[Lora]) -> list[Lora]:
if len(inputs) == 0:
return []
for x in inputs:
if not x[0].equal(inputs[0][0]):
return inputs
lora_down = inputs[0][0]
lora_ups = [x[1] for x in inputs]
lora_up = torch.sum(torch.stack(lora_ups), dim=0).to(lora_down.dtype)
return [(lora_down, lora_up)]
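# Pack a merged LoRA pair into the runtime tensor layout. The rank is zero-padded to a
# multiple of 16, which matches the per-16-rank lora_scales granularity used by
# setLoraScale on the C++ side.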
def dump_lora(lora_down: Optional[torch.Tensor], lora_up: Optional[torch.Tensor]) -> TensorDict:
if lora_down is None:
return {}
rank, ic = lora_down.shape
oc = lora_up.shape[0]
assert lora_up.shape == (oc, rank)
if rank % 16 != 0:
rank_pad = (rank + 16 - 1) // 16 * 16
tmp_down = torch.zeros([rank_pad, ic], dtype=lora_down.dtype, device=lora_down.device)
tmp_up = torch.zeros([oc, rank_pad], dtype=lora_down.dtype, device=lora_down.device)
tmp_down[:rank, ...] = lora_down
tmp_up[..., :rank] = lora_up
lora_down = tmp_down
lora_up = tmp_up
print(f"Pad lora rank from {rank} to {rank_pad}")
lora_down = pack_lora(lora_down.transpose(0, 1), is_lora_down=True)
lora_up = pack_lora(lora_up, is_lora_down=False)
tensors = {}
tensors["lora_down"] = lora_down
tensors["lora_up"] = lora_up
return tensors
def get_original_lora(qmodel: DeepCompressorModel, key_branch: str, key_smooth: Optional[str]) -> Lora:
dtype = qmodel.branch[key_branch]["a.weight"].dtype
smooth = qmodel.smooth[key_smooth].to(dtype).float() if key_smooth else None
return (
unsmooth(qmodel.branch[key_branch]["a.weight"], smooth),
qmodel.branch[key_branch]["b.weight"]
)
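# Merge the original low-rank (SVD) branch recovered from the quantized model with an
# optional user LoRA for a single linear layer. range_ic slices the user LoRA's input
# channels when one LoRA weight covers a fused layer (e.g. proj_out in the single
# transformer blocks).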
def dump_linear_lora(
qmodel: DeepCompressorModel,
key_lora: str,
key_branch: str,
key_smooth: str,
shift_bias: bool = False,
key_bias: Optional[str] = None,
range_ic: slice = slice(None, None, None)) -> TensorDict:
lora_original = get_original_lora(qmodel, key_branch, key_smooth)
if f"{key_lora}.lora_A.weight" in qmodel.lora:
# lora_down = qmodel.lora[f"{key}.lora_A.weight"][..., range_ic]
# lora_up = qmodel.lora[f"{key}.lora_B.weight"]
lora_new = (
qmodel.lora[f"{key_lora}.lora_A.weight"][..., range_ic],
qmodel.lora[f"{key_lora}.lora_B.weight"]
)
lora_down, lora_up = merge_lora([lora_original, lora_new])
rank, ic = lora_down.shape
oc = lora_up.shape[0]
assert lora_up.shape == (oc, rank)
print(f"linear at {key_lora} has rank {rank}")
tensors = dump_lora(lora_down, lora_up)
if shift_bias and False: # no longer need shift bias
if key_bias is None:
key_bias = f"{key_branch}.bias"
if key_bias in qmodel.model:
bias = qmodel.model[key_bias]
print(f"linear at {key_lora} apply shift_bias from original bias at {key_bias}")
else:
bias = torch.zeros([oc], dtype=lora_up.dtype, device=lora_up.device)
print(f"linear at {key_lora} apply shift_bias from empty original bias")
shift = torch.empty([ic], dtype=lora_down.dtype, device=lora_down.device)
shift = shift.fill_(0.171875)
delta = F.linear(F.linear(shift, lora_new[0]), lora_new[1])
print(f"shift_bias delta = {delta}")
bias -= delta
tensors["bias"] = pack_wscales(bias[..., None])[0]
return tensors
else:
print(f"linear at {key_lora} use original lora")
return dump_lora(*lora_original)
def dump_qkv_proj_svdq_lora(
qmodel: DeepCompressorModel,
key_qkv: tuple[str, str, str],
key_smooth: str,
key_smooth_out: str
) -> TensorDict:
dtype = qmodel.branch[key_smooth]["a.weight"].dtype
smooth_out = qmodel.smooth[key_smooth_out].to(dtype).float()
lora_original = get_original_lora(qmodel, key_smooth, key_smooth)
loras = []
for i in range(3):
key = key_qkv[i]
if f"{key}.lora_A.weight" in qmodel.lora:
lora_down = qmodel.lora[f"{key}.lora_A.weight"]
lora_up = qmodel.lora[f"{key}.lora_B.weight"]
if i == 2:
lora_up = (lora_up / smooth_out[..., None]).to(lora_up.dtype)
loras.append((lora_down, extend_qkv(lora_up, i)))
# print(loras)
lora_down, lora_up = merge_lora([lora_original, *merge_lora_qkv(loras)])
print(f"qkv_proj at {key_smooth} has rank {lora_down.shape[0]}")
return dump_lora(lora_down, lora_up)
def dump_transformer_svdq_lora(qmodel: DeepCompressorModel, layer_id: int) -> TensorDict:
tensors = {}
def reorder_adanorm_linear(weight: torch.Tensor) -> torch.Tensor:
oc, ic = weight.shape
assert oc % 6 == 0
return weight.reshape(6, oc // 6, ic).transpose(0, 1).reshape(oc, ic).contiguous()
def linear(key: str, **kwargs):
key_lora = key
key_branch = kwargs.pop("key_branch", key_lora)
key_smooth = kwargs.pop("key_smooth", key_branch)
return dump_linear_lora(qmodel, key_lora, key_branch, key_smooth, **kwargs)
prefix = f"transformer_blocks.{layer_id}"
if f"{prefix}.norm1.linear.lora_A.weight" in qmodel.lora:
lora_down = qmodel.lora[f"{prefix}.norm1.linear.lora_A.weight"]
lora_up = qmodel.lora[f"{prefix}.norm1.linear.lora_B.weight"]
tensors[f"norm1.linear.lora_down"] = lora_down
tensors[f"norm1.linear.lora_up"] = reorder_adanorm_linear(lora_up)
if f"{prefix}.norm1_context.linear.lora_A.weight" in qmodel.lora:
lora_down = qmodel.lora[f"{prefix}.norm1_context.linear.lora_A.weight"]
lora_up = qmodel.lora[f"{prefix}.norm1_context.linear.lora_B.weight"]
tensors[f"norm1_context.linear.lora_down"] = lora_down
tensors[f"norm1_context.linear.lora_up"] = reorder_adanorm_linear(lora_up)
merge_dict(tensors, dump_qkv_proj_svdq_lora(
qmodel,
(f"{prefix}.attn.to_q", f"{prefix}.attn.to_k", f"{prefix}.attn.to_v"),
f"{prefix}.attn.to_q",
f"{prefix}.attn.to_out.0"
), "qkv_proj.")
merge_dict(tensors, dump_qkv_proj_svdq_lora(
qmodel,
(f"{prefix}.attn.add_q_proj", f"{prefix}.attn.add_k_proj", f"{prefix}.attn.add_v_proj"),
f"{prefix}.attn.add_k_proj",
f"{prefix}.attn.to_out.0"
), "qkv_proj_context.")
merge_dict(tensors, linear(f"{prefix}.attn.to_out.0", key_smooth=None), "out_proj.")
merge_dict(tensors, linear(f"{prefix}.attn.to_add_out", key_smooth=None), "out_proj_context.")
merge_dict(tensors, linear(f"{prefix}.ff.net.0.proj"), "mlp_fc1.")
merge_dict(tensors, linear(f"{prefix}.ff.net.2", key_branch=f"{prefix}.ff.net.2.linear", shift_bias=True), "mlp_fc2.")
merge_dict(tensors, linear(f"{prefix}.ff_context.net.0.proj"), "mlp_context_fc1.")
merge_dict(tensors, linear(f"{prefix}.ff_context.net.2", key_branch=f"{prefix}.ff_context.net.2.linear", shift_bias=True), "mlp_context_fc2.")
return tensors
def dump_single_transformer_svdq_lora(qmodel: DeepCompressorModel, layer_id: int) -> TensorDict:
tensors = {}
def reorder_adanorm_linear(weight: torch.Tensor) -> torch.Tensor:
oc, ic = weight.shape
assert oc % 3 == 0
return weight.reshape(3, oc // 3, ic).transpose(0, 1).reshape(oc, ic).contiguous()
def linear(key: str, **kwargs):
key_lora = key
key_branch = kwargs.pop("key_branch", key_lora)
key_smooth = kwargs.pop("key_smooth", key_branch)
return dump_linear_lora(qmodel, key_lora, key_branch, key_smooth, **kwargs)
prefix = f"single_transformer_blocks.{layer_id}"
if f"{prefix}.norm.linear.lora_A.weight" in qmodel.lora:
lora_down = qmodel.lora[f"{prefix}.norm.linear.lora_A.weight"]
lora_up = qmodel.lora[f"{prefix}.norm.linear.lora_B.weight"]
tensors[f"norm.linear.lora_down"] = lora_down
tensors[f"norm.linear.lora_up"] = reorder_adanorm_linear(lora_up)
merge_dict(tensors, dump_qkv_proj_svdq_lora(
qmodel,
(f"{prefix}.attn.to_q", f"{prefix}.attn.to_k", f"{prefix}.attn.to_v"),
f"{prefix}.attn.to_q",
f"{prefix}.proj_out.linears.0"
), "qkv_proj.")
merge_dict(tensors, linear(f"{prefix}.proj_mlp", key_smooth=f"{prefix}.attn.to_q"), "mlp_fc1.")
# TODO
out_dim = 3072
merge_dict(tensors, linear(f"{prefix}.proj_out",
key_branch=f"{prefix}.proj_out.linears.0",
key_smooth=None,
range_ic=slice(0, out_dim)), "out_proj.")
merge_dict(tensors, linear(f"{prefix}.proj_out",
key_branch=f"{prefix}.proj_out.linears.1.linear",
shift_bias=True,
range_ic=slice(out_dim, None)), "mlp_fc2.")
return tensors
@torch.inference_mode()
def dump_flux_svdq_lora(qmodel: DeepCompressorModel, **kwargs) -> TensorDict:
tensors = {}
for i in range(19):
merge_dict(tensors, dump_transformer_svdq_lora(qmodel, i, **kwargs), f"transformer_blocks.{i}.")
for i in range(38):
merge_dict(tensors, dump_single_transformer_svdq_lora(qmodel, i, **kwargs), f"single_transformer_blocks.{i}.")
return tensors
if __name__ == "__main__":
lora_name = "realism"
if lora_name == "sketch":
qmodel = load_svdq_lora("model-dev", "../third_party/FLUX.1-dev-LoRA-Collections/sketch.safetensors")
elif lora_name == "realism":
qmodel = load_svdq_lora("model-dev", "../third_party/FLUX.1-dev-LoRA-Collections/realism.safetensors")
elif lora_name == "anime":
qmodel = load_svdq_lora("model-dev", "../third_party/sonny-anime-fixed/araminta_k_sonnyanime_fluxd_fixed.safetensors")
elif lora_name == "ghibsky":
qmodel = load_svdq_lora("model-dev", "../third_party/flux-ghibsky-illustration/lora.safetensors")
elif lora_name == "yarn":
qmodel = load_svdq_lora("model-dev", "../third_party/yarn_art_Flux_LoRA/pytorch_lora_weights.safetensors")
elif lora_name == "sketch2image":
qmodel = load_svdq_lora("model-dev", "sketch2image.safetensors")
else:
raise NotImplementedError
tensors = dump_flux_svdq_lora(qmodel)
for k, v in tensors.items():
assert not v.isnan().any()
assert not v.isinf().any()
safetensors.torch.save_file(
tensors,
f"/tmp/flux-lora-{lora_name}-bf16.safetensors")
#!/bin/bash
rundir=$(date +"run-$(hostname -s)-%Y%m%d-%H%M%S")
mkdir -p $rundir
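# run() launches run_flux.py at 4, 25 and 50 steps with the given arguments, teeing
# stdout/stderr into per-config logs under $rundir; if the 50-step run succeeds, the
# 4-step run is profiled again under nsys.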
function run() {
echo config=$config
echo args=$@
python3 run_flux.py --steps 4 "$@" > >(tee $rundir/stdout-s4-$config.log) 2> >(tee $rundir/stderr-s4-$config.log)
python3 run_flux.py --steps 25 "$@" > >(tee $rundir/stdout-s25-$config.log) 2> >(tee $rundir/stderr-s25-$config.log)
python3 run_flux.py --steps 50 "$@" > >(tee $rundir/stdout-s50-$config.log) 2> >(tee $rundir/stderr-s50-$config.log)
if [ $? -eq 0 ]; then
nsys profile --cuda-memory-usage true -o $rundir/report-$config.nsys-rep python3 run_flux.py --steps 4 "$@"
fi
}
config=bf16-compile
run --config bf16 --compile
config=bf16-t5-compile
run --config bf16-t5 --compile
config=int8dq-compile
run --config bf16 --torchao --compile
config=int8dq-t5-compile
run --config bf16-t5 --torchao --compile
config=int8dq-nocompile
run --config bf16 --torchao
config=int8dq-t5-nocompile
run --config bf16-t5 --torchao
for cfg in svdq svdq-t5 w4a4 w4a4-t5 bf16 bf16-t5 nf4 nf4-t5; do
config=$cfg
run --config $cfg
config=$cfg-ol1
run --config $cfg --offload 1
config=$cfg-ol2
run --config $cfg --offload 2
done
import torch
from torch.nn import functional as F
from dump_flux import group_scale
def compare(
ref: torch.Tensor,
v: torch.Tensor,
refname: str,
vname: str,
list_diff: bool = False):
print(f"== COMPARE v={vname} vs ref={refname}")
diff = v - ref
print(f" - diff = {diff}")
if list_diff:
print(f" - diffs at {diff.nonzero()}")
mse = diff.square().mean()
print(f" - mse = {mse}")
nmse = mse / ref.square().mean()
print(f" - nmse = {nmse}")
print(f" - mean(v/ref)={v.mean()}/{ref.mean()}")
print(f" - var(v/ref)={v.var()}/{ref.var()}")
print(f"== ")
print()
def print_debug_results(debug_results: dict[str, torch.Tensor], is_ref: bool = False):
ref = 'REF' if is_ref else ''
for k, v in debug_results.items():
has_nan = v.isnan().any()
has_inf = v.isinf().any()
if v.dtype.is_floating_point:
print(f" {ref} {k}: {v.shape} ({v.dtype}) has_nan={has_nan} has_inf={has_inf} max={v.max()} min={v.min()} mean={v.mean()} var={v.var()}")
else:
print(f" {ref} {k}: {v.shape} ({v.dtype})")
if has_nan:
cnt = v.isnan().count_nonzero()
print(f" {ref} -- {cnt} ({cnt / v.numel() * 100}%) nans at {v.isnan().nonzero()[0:10]}")
if has_inf:
cnt = v.isinf().count_nonzero()
print(f" {ref} -- {cnt} ({cnt / v.numel() * 100}%) infs at {v.isinf().nonzero()[0:10]}")
print(f" {ref} -- {v}")
print()
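# Simulate the W4A4 group-quantized linear in two ways for cross-checking:
# outref_q dequantizes weight and activation back to floating point and calls F.linear,
# while outref_q2 multiplies the int4 values group by group (a bmm over ic//group_size
# groups), rescales each partial sum by wscale*ascale and accumulates in the activation
# dtype, which is presumably closer to the CUDA kernel's order of operations.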
def fakequant(
act: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
group_size: int = 64,
force_cuda: bool = False,
):
oc, ic = weight.shape
batch_size = act.shape[0]
assert act.shape[1] == ic
# [oc, ic // group_size]
wscales = group_scale(weight, num_bits=4, group_size=group_size)
qweight = weight.reshape(oc, ic // group_size, group_size).to(dtype=torch.float32) / wscales[..., None]
qweight = qweight.round().clamp(-8, 7)
qweight_i = qweight.int()
qweight = qweight * wscales[..., None]
qweight = qweight.to(weight.dtype)
qweight = qweight.reshape(oc, ic)
# print(f"qweight = {qweight}")
print_debug_results({"qweight": qweight})
# [batch_size, ic // group_size]
ascales = group_scale(act, num_bits=4, group_size=group_size).to(dtype=weight.dtype)
qact = act.reshape(batch_size, ic // group_size, group_size).to(dtype=torch.float32) / ascales[..., None]
qact = qact.round().clamp(-8, 7)
qact_i = qact.int()
print_debug_results({"qact_i": qact_i})
qact = qact * ascales[..., None]
qact = qact.to(act.dtype)
qact = qact.reshape(batch_size, ic)
# print(f"qact = {qact}")
print_debug_results({"qact": qact})
outref_q = F.linear(qact.to(qweight.dtype), qweight, bias)
# print(f"outref_q = {outref_q}")
print_debug_results({"outref_q": outref_q})
###
if force_cuda:
qweight_i = qweight_i.to("cuda")
qact_i = qact_i.to("cuda")
wscales = wscales.to("cuda")
ascales = ascales.to("cuda")
bias = bias.to("cuda")
qweight = qweight_i
qact = qact_i
qweight = qweight.reshape(oc, ic // group_size, group_size).transpose(0, 1).transpose(1, 2)
qact = qact.reshape(batch_size, ic // group_size, group_size).transpose(0, 1)
# [ic // group_size, batch_size, oc]
psum = torch.bmm(qact.float(), qweight.float())
print(f"psum_i ({psum.shape}) = {psum} ")
# print(psum[:, 0, 23])
# print(f"ascales = {ascales}")
print_debug_results({"ascales": ascales})
print(f"ascales[0:16] = {ascales[0:16, 0]}")
ws1 = wscales.transpose(0, 1).reshape(ic // group_size, 1, oc).repeat(1, batch_size, 1)
as1 = ascales.transpose(0, 1).reshape(ic // group_size, batch_size, 1).repeat(1, 1, oc)
scales = ws1 * as1
print(f"scales = {scales}")
# print(scales[:, 0, 23])
psum = psum.to(dtype=act.dtype).float()
psum = psum * scales
print(f"psum ({psum.shape}) = {psum} ")
# print(psum[:, 0, 23])
# outref_q2 = psum.sum(dim=0) # .to(layer.weight.dtype)
outref_q2 = torch.zeros_like(psum[0])
for i in range(psum.shape[0]):
outref_q2 = (outref_q2 + psum[i]).to(act.dtype)
outref_q2 += bias[None, ...]
# print(f"outref_q2 = {outref_q2}")
print_debug_results({"outref_q2": outref_q2})
# print(outref_q2[0, 23])
if force_cuda:
outref_q2 = outref_q2.to(act.device)
return outref_q, outref_q2
from safetensors.torch import safe_open, save_file
def main():
input_path1 = "app/i2i/pretrained/converted/sketch.safetensors"
input_path2 = "app/i2i/pretrained/original/flux-lora-sketch2image-bf16.safetensors"
sd1 = {}
with safe_open(input_path1, framework="pt") as f:
for k in f.keys():
sd1[k] = f.get_tensor(k)
sd2 = {}
with safe_open(input_path2, framework="pt") as f:
for k in f.keys():
sd2[k] = f.get_tensor(k)
for k in sd1.keys():
if "lora" not in k:
print(k)
sd2[k.replace("transformer.", "")] = sd1[k]
save_file(sd2, "svdq-flux.1-pix2pix-turbo-sketch2image.safetensors")
if __name__ == "__main__":
main()
import math
import torch
import torch.nn as nn
def make_divisible(c, divisor):
return (c + divisor - 1) // divisor
def calculate_zeros_width(in_features, group_size=128, pack_num=8):
if group_size >= 128:
size_multiplier = 1
elif group_size == 64:
size_multiplier = 2
elif group_size == 32:
size_multiplier = 4
else:
raise NotImplementedError
base_width = make_divisible(in_features // group_size, pack_num)
base_width = make_divisible(base_width, size_multiplier) * size_multiplier
return base_width
def pack_intweight(unpacked_qweight, interleave, kstride):
# unpacked_qweight: [N, K]
N = unpacked_qweight.shape[0]
K = unpacked_qweight.shape[1]
Packed_Kernel = unpacked_qweight.cpu().numpy().reshape(N, K // 32, 32)
# np.arange(32).reshape(4, 4, 2).transpose(1, 0, 2) => [0, 1, 8, 9, 16, 17, 24, 25, ...]
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 3, 2, 4)
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 32)
# reorder each 8 weights for fast dequantization
# [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7]
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 8)
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 2, 4, 3)
Packed_Kernel = Packed_Kernel.reshape(N, K)
# interleaving every four rows
Packed_Kernel = Packed_Kernel.reshape(
N // interleave, interleave, K // kstride, kstride
)
# N // 4, K // 64, 4, 64
Packed_Kernel = Packed_Kernel.transpose(0, 2, 1, 3)
Packed_Kernel = Packed_Kernel.reshape(
N // interleave, K // kstride, kstride, interleave
)
# Packing -> (N // 4, K // 64, 64)
Packed_Kernel = (
Packed_Kernel[..., 0]
| (Packed_Kernel[..., 1] << 4)
| (Packed_Kernel[..., 2] << 8)
| (Packed_Kernel[..., 3] << 12)
)
# reshape to (N // 4, K), FP16 format
Packed_Kernel = Packed_Kernel.reshape(N // interleave, K)
qweight = (
torch.tensor(Packed_Kernel.astype("int16"))
.to(unpacked_qweight.device)
.contiguous()
)
return qweight
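# Compute per-group quantization parameters only; the weight itself is quantized later
# in dump_linear_awq. With zero_point=True this is asymmetric quantization:
# scale = (max - min) / (2**n_bit - 1) with an integer zero point; otherwise symmetric
# around zero.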
def pseudo_quantize_tensor(
w, n_bit=8, zero_point=True, q_group_size=-1,
) -> tuple[torch.Tensor, torch.Tensor]:
org_w_shape = w.shape
if q_group_size > 0:
assert org_w_shape[-1] % q_group_size == 0
w = w.reshape(-1, q_group_size)
assert w.dim() == 2
if zero_point:
max_val = w.amax(dim=1, keepdim=True)
min_val = w.amin(dim=1, keepdim=True)
max_int = 2**n_bit - 1
min_int = 0
scales = (max_val - min_val).clamp(min=1e-5) / max_int
zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
else: # we actually never used this
# assert min_val is None
max_val = w.abs().amax(dim=1, keepdim=True)
max_val = max_val.clamp(min=1e-5)
max_int = 2 ** (n_bit - 1) - 1
min_int = -max_int
scales = max_val / max_int
zeros = torch.full_like(scales, fill_value=-min_int)
assert torch.isnan(scales).sum() == 0
assert torch.isnan(w).sum() == 0
w = w.reshape(org_w_shape)
return scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
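# Pack one linear layer into the AWQ/TinyChat GEMV layout: wscales and wzeros are stored
# transposed and padded to the width from calculate_zeros_width, wzeros holds
# -(scale * zero_point) so dequantization becomes a fused multiply-add, and qweight is
# the uint4 weight packed by pack_intweight with interleave=4, kstride=64.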
def dump_linear_awq(
weight: torch.Tensor,
bias: torch.Tensor,
w_bit: int,
group_size: int,
zero_point: bool = True
) -> dict[str, torch.Tensor]:
scales, zeros = pseudo_quantize_tensor(weight, w_bit, zero_point, group_size)
print(scales.shape)
print(zeros.shape)
tensors = {}
dtype = weight.dtype
oc, ic = weight.shape
# need scales and zeros info for real quantization
assert scales is not None and zeros is not None
scale_zeros = zeros * scales
pack_num = 32 // w_bit
qscales = torch.zeros(
(
scales.shape[0],
calculate_zeros_width(ic, group_size) * pack_num,
),
dtype=dtype,
device=scales.device,
)
qscales[:, : scales.shape[1]] = scales
# awq_linear.scales = scales.clone().half()
tensors["wscales"] = qscales.transpose(1, 0).contiguous()
if bias is not None:
tensors["bias"] = bias.clone()
if False:
intweight = []
for idx in range(ic):
intweight.append(
torch.round(
(weight.data[:, idx] + scale_zeros[:, idx // group_size])
/ qscales[:, idx // group_size]
).clamp(0, 15 if zero_point else 14).to(torch.int)[:, None]
)
print(intweight[0].shape)
intweight = torch.cat(intweight, dim=1)
print(intweight.shape)
intweight_ref = intweight
# intweight = intweight.t().contiguous()
assert ic % group_size == 0
intweight = weight.reshape(oc, ic // group_size, group_size)
# print(f"{intweight.shape} {scale_zeros[..., None].shape} {qscales[..., None].shape}")
intweight = (intweight + scale_zeros[..., None]) / qscales[..., None]
intweight = intweight.round_()
intweight = intweight.clamp_(0, 15 if zero_point else 14)
intweight = intweight.to(dtype=torch.int32)
intweight = intweight.reshape(oc, ic)
if False:
print(intweight_ref - intweight)
assert not (intweight_ref - intweight != 0).any()
tensors["qweight"] = pack_intweight(
intweight.contiguous(), interleave=4, kstride=64
)
zeros = zeros.to(dtype=torch.int32)
scaled_zeros = torch.zeros_like(qscales)
# scaled_zeros[:, :scales.shape[1]] = -(qscales[:, :scales.shape[1]] * (zeros.to(torch.float32) - 8.0)).to(torch.float16)
scaled_zeros[:, : scales.shape[1]] = -(
qscales[:, : scales.shape[1]] * (zeros.to(torch.float32))
).to(dtype)
tensors["wzeros"] = scaled_zeros.transpose(1, 0).contiguous()
return tensors
import time
import argparse
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
import nunchaku.pipelines.flux
def get_pipe(config: str, dev: bool) -> FluxPipeline:
version = "dev" if dev else "schnell"
dtype = torch.bfloat16
qencoder_path = "/NFS/raid0/user/zhangzk/models/flux-t5-tinychat-v2.pt"
if config.startswith("svdq"):
pipe = nunchaku.pipelines.flux.from_pretrained(
f"black-forest-labs/FLUX.1-{version}",
torch_dtype=dtype,
qmodel_path=f"/NFS/raid0/user/zhangzk/models/flux{'-dev' if dev else ''}-svdq-19-38-divsmooth-shift-ada-bf16.safetensors",
qencoder_path=qencoder_path if config == "svdq-t5" else None
)
elif config.startswith("w4a4"):
pipe = nunchaku.pipelines.flux.from_pretrained(
f"black-forest-labs/FLUX.1-{version}",
torch_dtype=dtype,
qmodel_path=f"/NFS/raid0/user/zhangzk/models/flux{'-dev' if dev else ''}-divsmooth-shift-ada-bf16.safetensors",
qencoder_path=qencoder_path if config == "w4a4-t5" else None
)
elif config.startswith("bf16"):
pipe = FluxPipeline.from_pretrained(
f"black-forest-labs/FLUX.1-{version}",
torch_dtype=dtype,
)
if config == "bf16-t5":
nunchaku.pipelines.flux.quantize_t5(pipe, qencoder_path)
elif config.startswith("nf4"):
from accelerate.utils import set_module_tensor_to_device, compute_module_sizes
from accelerate import init_empty_weights
from convert_nf4_flux import _replace_with_bnb_linear, create_quantized_param, check_quantized_param
converted_state_dict = torch.load(f"/NFS/raid0/user/zhangzk/models/flux1-{version}-nf4.pt")
with init_empty_weights():
model_config = FluxTransformer2DModel.load_config(f"black-forest-labs/flux.1-{version}", subfolder="transformer")
model = FluxTransformer2DModel.from_config(model_config).to(dtype)
_replace_with_bnb_linear(model, "nf4")
for param_name, param in converted_state_dict.items():
param = param.to(dtype)
print(f"{param_name}: {param.shape} check_quantized_param={check_quantized_param(model, param_name)}")
if not check_quantized_param(model, param_name):
set_module_tensor_to_device(model, param_name, device=0, value=param)
else:
create_quantized_param(model, param, param_name, target_device=0)
pipe = FluxPipeline.from_pretrained(f"black-forest-labs/flux.1-{version}", transformer=model, torch_dtype=dtype)
if config == "nf4-t5":
nunchaku.pipelines.flux.quantize_t5(pipe, qencoder_path)
else:
raise NotImplementedError
return pipe
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="svdq", choices=["svdq", "svdq-t5", "w4a4", "w4a4-t5", "bf16", "bf16-t5", "nf4", "nf4-t5"])
parser.add_argument("--offload", type=int, default=0)
parser.add_argument("--steps", type=int, default=50)
parser.add_argument("--dev", action="store_true")
parser.add_argument("--torchao", action="store_true")
parser.add_argument("--compile", action="store_true")
args = parser.parse_args()
print(f"Use config {args.config}")
if args.offload > 0:
print(f"Use offloading level {args.offload}")
pipe = get_pipe(args.config, args.dev)
print(pipe)
if args.torchao:
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
# pipe.transformer = autoquant(pipe.transformer, error_on_unseen=False)
quantize_(pipe.transformer, int8_dynamic_activation_int8_weight())
if args.offload == 2:
pipe.enable_sequential_cpu_offload()
elif args.offload == 1:
pipe.enable_model_cpu_offload()
elif args.offload == 0:
pipe.to("cuda:0")
else:
raise NotImplementedError
# assert isinstance(pipe, FluxPipeline)
if args.compile:
pipe.transformer.to(memory_format=torch.channels_last)
pipe.transformer = torch.compile(
pipe.transformer, mode="max-autotune", fullgraph=True
)
prompt = "A cat holding a sign that says hello world"
print(f"Using prompt '{prompt}'")
print(f"Run {args.steps} steps")
latencies = []
for i in range(5):
start_time = time.time()
out = pipe(
prompt=prompt,
guidance_scale=0,
num_inference_steps=args.steps,
generator=torch.Generator(device="cpu").manual_seed(233),
).images[0]
end_time = time.time()
latencies.append(end_time - start_time)
torch.cuda.empty_cache()
latencies = sorted(latencies)
latencies = latencies[1:-1]
out.save("output.png")
print(f"Elapsed: {sum(latencies) / len(latencies)} seconds")
print(f"Torch max_memory_allocated={torch.cuda.max_memory_allocated()}")
import time
import torch
import diffusers
from diffusers import FluxPipeline
import nunchaku.pipelines.flux
if __name__ == "__main__":
QUANT = False
SEED = 1
DEV = True
LORA_NAME = "anime"
pipe = nunchaku.pipelines.flux.from_pretrained(
f"black-forest-labs/FLUX.1-{'dev' if DEV else 'schnell'}",
torch_dtype=torch.bfloat16,
qmodel_path=f"/NFS/raid0/user/zhangzk/models/flux{'-dev' if DEV else ''}-svdq-19-38-divsmooth-shift-ada-bf16.safetensors",
qencoder_path="/NFS/raid0/user/zhangzk/models/flux-t5-tinychat-v2.pt" if QUANT else None,
)
if LORA_NAME:
pipe.transformer.nunchaku_update_params(f"/tmp/flux-lora-{LORA_NAME}-bf16.safetensors")
pipe.transformer.nunchaku_set_lora_scale(0.4)
print("Moving model to CUDA")
pipe.to("cuda:0")
print("Done")
# prompt = "A cat holding a sign that says hello world"
# prompt = "A cyberpunk cat holding a huge neon sign that says \"SVDQuant is lite and fast\""
prompt = "girl, neck tuft, white hair ,sheep horns, blue eyes, nm22 style"
# prompt = "GHIBSKY style, the most beautiful place in the universe"
# prompt = "the joker, yarn art style"
print(f"Using prompt '{prompt}'")
latencies = []
diffusers.training_utils.set_seed(SEED)
start_time = time.time()
out = pipe(
prompt=prompt,
guidance_scale=3.5 if DEV else 0,
num_inference_steps=50 if DEV else 4,
generator=torch.Generator(device="cpu").manual_seed(SEED),
).images[0]
end_time = time.time()
latencies.append(end_time - start_time)
out.save(f"output{'-dev' if DEV else ''}-{SEED}-{'quant' if QUANT else 'noquant'}.png")
print(f"Elapsed: {sum(latencies) / len(latencies)} seconds")
import torch
from nunchaku.pipelines import flux as nunchaku_flux
pipeline = nunchaku_flux.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16,
qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-schnell.safetensors", # download from Huggingface
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0).images[0]
image.save("example.png")
__version__ = "0.0.0beta0"
#pragma once
#include "interop/torch.h"
#include "FluxModel.h"
#include "Serialization.h"
#include "debug.h"
#include "Linear.h"
class QuantizedFluxModel { // : public torch::CustomClassHolder {
public:
void init(bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedFluxModel");
net = std::make_unique<FluxModel>(bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
void reset() {
debugContext.reset();
net.reset();
Tensor::synchronizeDevice();
trimMemory();
Tensor::synchronizeDevice();
}
void load(std::string path, bool partial = false) {
checkModel();
spdlog::info("{} weights from {}", partial ? "Loading partial" : "Loading", path);
std::shared_ptr<SafeTensors> provider = std::make_shared<SafeTensors>(path);
net->loadParams(*provider, partial);
Tensor::synchronizeDevice();
spdlog::info("Done.");
}
torch::Tensor forward(
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor temb,
torch::Tensor rotary_emb_img,
torch::Tensor rotary_emb_context,
torch::Tensor rotary_emb_single)
{
checkModel();
spdlog::debug("QuantizedFluxModel forward");
hidden_states = hidden_states.contiguous();
encoder_hidden_states = encoder_hidden_states.contiguous();
temb = temb.contiguous();
rotary_emb_img = rotary_emb_img.contiguous();
rotary_emb_context = rotary_emb_context.contiguous();
rotary_emb_single = rotary_emb_single.contiguous();
Tensor result = net->forward(
from_torch(hidden_states),
from_torch(encoder_hidden_states),
from_torch(temb),
from_torch(rotary_emb_img),
from_torch(rotary_emb_context),
from_torch(rotary_emb_single)
);
torch::Tensor output = to_torch(result);
Tensor::synchronizeDevice();
return output;
}
std::tuple<torch::Tensor, torch::Tensor> forward_layer(
int64_t idx,
torch::Tensor hidden_states,
torch::Tensor encoder_hidden_states,
torch::Tensor temb,
torch::Tensor rotary_emb_img,
torch::Tensor rotary_emb_context)
{
spdlog::debug("QuantizedFluxModel forward_layer {}", idx);
hidden_states = hidden_states.contiguous();
encoder_hidden_states = encoder_hidden_states.contiguous();
temb = temb.contiguous();
rotary_emb_img = rotary_emb_img.contiguous();
rotary_emb_context = rotary_emb_context.contiguous();
auto &&[result_img, result_txt] = net->transformer_blocks.at(idx)->forward(
from_torch(hidden_states),
from_torch(encoder_hidden_states),
from_torch(temb),
from_torch(rotary_emb_img),
from_torch(rotary_emb_context),
0.0f
);
hidden_states = to_torch(result_img);
encoder_hidden_states = to_torch(result_txt);
Tensor::synchronizeDevice();
return { hidden_states, encoder_hidden_states };
}
torch::Tensor forward_single_layer(
int64_t idx,
torch::Tensor hidden_states,
torch::Tensor temb,
torch::Tensor rotary_emb_single)
{
spdlog::debug("QuantizedFluxModel forward_single_layer {}", idx);
hidden_states = hidden_states.contiguous();
temb = temb.contiguous();
rotary_emb_single = rotary_emb_single.contiguous();
Tensor result = net->single_transformer_blocks.at(idx)->forward(
from_torch(hidden_states),
from_torch(temb),
from_torch(rotary_emb_single)
);
hidden_states = to_torch(result);
Tensor::synchronizeDevice();
return hidden_states;
}
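// Raise the default CUDA memory pool's release threshold to UINT64_MAX so freed
// blocks stay cached in the pool instead of being returned to the driver at each
// synchronization point; trimMemory() below releases the cached memory again.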
void disableMemoryAutoRelease() {
int device;
checkCUDA(cudaGetDevice(&device));
cudaMemPool_t mempool;
checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
uint64_t threshold = UINT64_MAX;
checkCUDA(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &threshold));
}
void trimMemory() {
int device;
checkCUDA(cudaGetDevice(&device));
cudaMemPool_t mempool;
checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
size_t bytesToKeep = 0;
checkCUDA(cudaMemPoolTrimTo(mempool, bytesToKeep));
}
void startDebug() {
debugContext = std::make_unique<DebugContext>();
}
void stopDebug() {
debugContext.reset();
}
auto getDebugResults() {
// c10::Dict<std::string, torch::Tensor> result;
std::map<std::string, torch::Tensor> result;
if (debugContext) {
for (auto &&[key, value] : debugContext->tensors) {
// result.insert(key, to_torch(value));
result[key] = to_torch(value);
}
}
return result;
}
// must be called after loading lora
// skip specific ranks in W4A4 layers
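// lora_scales holds one scale per 16 ranks: the first skipRanks/16 groups keep a
// scale of 1.0 (presumably the low-rank branch baked in at quantization time, which
// is concatenated first on the Python side), while the remaining groups get the user scale.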
void setLoraScale(int skipRanks, float scale) {
if (skipRanks % 16 != 0) {
throw std::invalid_argument("skipRanks must be a multiple of 16");
}
spdlog::info("Set lora scale to {} (skip {} ranks)", scale, skipRanks);
net->traverse([&](Module *module) {
if (auto *m = dynamic_cast<GEMV_AWQ *>(module)) {
m->lora_scale = scale;
} else if (auto *m = dynamic_cast<GEMM_W4A4 *>(module)) {
for (int i = 0; i < skipRanks / 16; i++) {
m->lora_scales[i] = 1.0f;
}
for (int i = skipRanks / 16; i < m->lora_scales.size(); i++) {
m->lora_scales[i] = scale;
}
}
});
}
private:
void checkModel() {
if (!net) {
throw std::runtime_error("Model not initialized");
}
}
private:
std::unique_ptr<FluxModel> net;
std::unique_ptr<DebugContext> debugContext;
};
#pragma once
#include "interop/torch.h"
#include "Serialization.h"
#include "Linear.h"
#include "debug.h"
#include "kernels/gemm_w4a4.h"
#include "kernels/awq/gemv_awq.h"
class QuantizedGEMM { // : public torch::CustomClassHolder {
public:
void init(int64_t in_features, int64_t out_features, bool bias, bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedGEMM");
size_t val = 0;
checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
spdlog::debug("Stack={}", val);
net = std::make_unique<GEMM_W4A4>((int)in_features, (int)out_features, bias, bf16 ? Tensor::BF16 : Tensor::FP16, Device::cuda((int)deviceId));
}
void reset() {
debugContext.reset();
net.reset();
Tensor::synchronizeDevice();
}
void load(std::string path) {
checkModel();
spdlog::info("Loading weights from {}", path);
std::shared_ptr<SafeTensors> provider = std::make_shared<SafeTensors>(path);
net->loadParams(*provider);
Tensor::synchronizeDevice();
}
torch::Tensor forward(torch::Tensor x) {
checkModel();
std::cerr << "QuantizedGEMM forward" << std::endl;
x = x.contiguous();
Tensor result = std::get<Tensor>(net->forward(
from_torch(x)
));
torch::Tensor output = to_torch(result);
Tensor::synchronizeDevice();
return output;
}
std::string dumpTensorBF16(Tensor x) {
std::stringstream ss;
for (int i = 0; i < 256; i++) {
ss << spdlog::fmt_lib::format("{:.3f} ", (float)(x.data_ptr<__nv_bfloat16>()[i]));
}
ss << std::endl;
return ss.str();
}
std::string dumpTensorINT4(Tensor x) {
using spdlog::fmt_lib::format;
const int M = x.shape[0];
const int K = x.shape[1] * 2;
assert(x.dtype() == Tensor::INT8);
// activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t (uint4)
constexpr int BLOCK_M = 256;
constexpr int WARP_K = 64;
constexpr int NUM_WARPS = 8;
constexpr int WARP_M_TILES = 2;
constexpr int WARP_SIZE = 32;
std::stringstream ss;
for (int bm = 0; bm < M / BLOCK_M; bm++) {
for (int bn = 0; bn < K / WARP_K; bn++) {
for (int warpId = 0; warpId < NUM_WARPS; warpId++) {
ss << format("[bm={},bn={},warp={}] ", bm, bn, warpId);
const int offset = ((bm * (K / WARP_K) + bn) * NUM_WARPS + warpId) * WARP_M_TILES * WARP_SIZE * 4;
for (int i = 0; i < 16; i++) {
assert(offset + i < x.numel() / 4);
uint32_t val = x.data_ptr<uint32_t>()[offset + i];
ss << "{";
for (int j = 0; j < 8; j++) {
int i4val = (val >> (j * 4)) & 0xf;
if (i4val & 0x8) {
i4val = -((~i4val & 0x7) + 1);
}
ss << format("{} ", i4val);
}
ss << format("}} {:x} ", val);
}
ss << std::endl;
}
}
}
ss << std::endl;
return ss.str();
}
void quantize(torch::Tensor x) {
checkModel();
spdlog::debug("QuantizedGEMM quantize");
x = x.contiguous();
auto qout = net->quantize(
from_torch(x)
);
Tensor act = qout.act.copy(Device::cpu());
Tensor ascales = qout.ascales.copy(Device::cpu());
Tensor lora_act = qout.lora_act.copy(Device::cpu());
Tensor::synchronizeDevice();
spdlog::debug("act = {}", dumpTensorINT4(act));
spdlog::debug("ascales = {}", dumpTensorBF16(ascales));
}
void gemm(
c10::optional<torch::Tensor> act, // packed act [M, K / 2]
c10::optional<torch::Tensor> wgt, // packed wgt [N, K / 2]
c10::optional<torch::Tensor> out, // linear [M, N]
c10::optional<torch::Tensor> qout, // packed act [M, N / 2]
c10::optional<torch::Tensor> ascales, // packed as [K / 64, M]
c10::optional<torch::Tensor> wscales, // packed ws [K / 64, N]
c10::optional<torch::Tensor> oscales, // packed as [N / 64, M]
c10::optional<torch::Tensor> poolout, // linear [M / PoolSize, N]
c10::optional<torch::Tensor> lora_act_in, // packed lora_act [M, R]
c10::optional<torch::Tensor> lora_up, // packed lora_wgt [N, R]
c10::optional<torch::Tensor> lora_down, // packed lora_wgt [N, R]
c10::optional<torch::Tensor> lora_act_out, // packed lora_act [M, R]
c10::optional<torch::Tensor> norm_q, // linear [HEAD_DIM]
c10::optional<torch::Tensor> norm_k, // linear [HEAD_DIM]
c10::optional<torch::Tensor> rotary_emb, // linear [M, HEAD_DIM / 2, 2, 2]
c10::optional<torch::Tensor> bias, // packed ws [N]
c10::optional<torch::Tensor> smooth_factor, // packed ws [N], for quantization of the next layer
bool act_unsigned,
std::vector<float> lora_scales
) {
std::cerr << "running gemm_w4a4: " << std::endl;
auto getTensor = [](c10::optional<torch::Tensor> &t) {
Tensor ret = t.has_value() ? from_torch(t.value()) : Tensor{};
if (ret.valid()) {
std::cerr << " " << ret.shape.str() << std::endl;
} else {
std::cerr << " <invalid>" << std::endl;
}
return ret;
};
gemm_w4a4(
getTensor(act ),
getTensor(wgt ),
getTensor(out ),
getTensor(qout ),
getTensor(ascales ),
getTensor(wscales ),
getTensor(oscales ),
getTensor(poolout ),
getTensor(lora_act_in ),
getTensor(lora_up ),
getTensor(lora_down ),
getTensor(lora_act_out ),
getTensor(norm_q ),
getTensor(norm_k ),
getTensor(rotary_emb ),
getTensor(bias ),
getTensor(smooth_factor),
act_unsigned,
lora_scales
);
Tensor::synchronizeDevice();
}
torch::Tensor gemv_awq(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int64_t m,
int64_t n,
int64_t k,
int64_t group_size)
{
Tensor result = ::gemv_awq(
from_torch(_in_feats.contiguous()),
from_torch(_kernel.contiguous()),
from_torch(_scaling_factors.contiguous()),
from_torch(_zeros.contiguous()),
(int)m,
(int)n,
(int)k,
(int)group_size
);
torch::Tensor output = to_torch(result);
Tensor::synchronizeDevice();
return output;
}
void startDebug() {
debugContext = std::make_unique<DebugContext>();
}
void stopDebug() {
debugContext.reset();
}
auto getDebugResults() {
// c10::Dict<std::string, torch::Tensor> result;
std::map<std::string, torch::Tensor> result;
if (debugContext) {
for (auto &&[key, value] : debugContext->tensors) {
// result.insert(key, to_torch(value));
result[key] = to_torch(value);
}
}
return result;
}
private:
void checkModel() {
if (!net) {
throw std::runtime_error("Model not initialized");
}
}
private:
std::unique_ptr<GEMM_W4A4> net;
std::unique_ptr<DebugContext> debugContext;
};
#include "gemm.h"
#include "flux.h"
#include <pybind11/pybind11.h>
// TORCH_LIBRARY(diffuxer, m) {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel")
// .def(torch::init<>())
.def(py::init<>())
.def("init", &QuantizedFluxModel::init,
py::arg("bf16"),
py::arg("deviceId")
)
.def("reset", &QuantizedFluxModel::reset)
.def("load", &QuantizedFluxModel::load,
py::arg("path"),
py::arg("partial") = false
)
.def("forward", &QuantizedFluxModel::forward)
.def("forward_layer", &QuantizedFluxModel::forward_layer)
.def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
.def("disableMemoryAutoRelease", &QuantizedFluxModel::disableMemoryAutoRelease)
.def("trimMemory", &QuantizedFluxModel::trimMemory)
.def("startDebug", &QuantizedFluxModel::startDebug)
.def("stopDebug", &QuantizedFluxModel::stopDebug)
.def("getDebugResults", &QuantizedFluxModel::getDebugResults)
.def("setLoraScale", &QuantizedFluxModel::setLoraScale)
;
py::class_<QuantizedGEMM>(m, "QuantizedGEMM")
// .def(torch::init<>())
.def(py::init<>())
.def("init", &QuantizedGEMM::init)
.def("reset", &QuantizedGEMM::reset)
.def("load", &QuantizedGEMM::load)
.def("forward", &QuantizedGEMM::forward)
.def("quantize", &QuantizedGEMM::quantize)
.def("gemm", &QuantizedGEMM::gemm)
.def("gemv_awq", &QuantizedGEMM::gemv_awq)
.def("startDebug", &QuantizedGEMM::startDebug)
.def("stopDebug", &QuantizedGEMM::stopDebug)
.def("getDebugResults", &QuantizedGEMM::getDebugResults)
;
}
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
safety_check_template = """You are a policy expert trying to help determine whether a user
prompt is in violation of the defined safety policies.
<start_of_turn>
Human Question: {user_prompt}
<end_of_turn>
Our safety principle is defined in the below:
The prompt must not seek to generate harmful, abusive, hateful, sexually explicit, or otherwise inappropriate content targeting individuals or protected groups.
Does the human question violate the above principle? Your answer must start
with 'Yes' or 'No'. And then walk through step by step to be sure we answer
correctly.
"""
class SafetyChecker:
def __init__(self, device: str | torch.device, disabled: bool = False):
if not disabled:
self.device = device
self.tokenizer = AutoTokenizer.from_pretrained("google/shieldgemma-2b")
self.llm = AutoModelForCausalLM.from_pretrained("google/shieldgemma-2b", torch_dtype=torch.bfloat16).to(
device
)
self.disabled = disabled
def __call__(self, user_prompt: str, threshold: float = 0.2) -> bool:
if self.disabled:
return True
device = self.device
inputs = self.tokenizer(safety_check_template.format(user_prompt=user_prompt), return_tensors="pt").to(device)
with torch.no_grad():
logits = self.llm(**inputs).logits
# Extract the logits for the Yes and No tokens
vocab = self.tokenizer.get_vocab()
selected_logits = logits[0, -1, [vocab["Yes"], vocab["No"]]]
# Convert these logits to a probability with softmax
probabilities = torch.softmax(selected_logits, dim=0)
# Return probability of 'Yes'
score = probabilities[0].item()
return score < threshold