Unverified Commit 0ade163c authored by ZIAN HU's avatar ZIAN HU Committed by GitHub
Browse files

feat: upgrade the 4-bit quantized T5 encoder (#320)



* Updating quantized t5 encoder

* Fix formatting based on pre-commit hook

* Update test cases

* Fixing linter issue

* Fix linter reformatting

* support fp4

* style: make linter happy

* update the fp4 lpips

* Prevent downloading original t5 model

* Make sure model in eval mode

---------
Co-authored-by: muyangli <lmxyy1999@foxmail.com>
parent 212fd278
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
from nunchaku.utils import get_precision
def main():
    """Run 4-bit quantized FLUX.1-schnell with the quantized T5 text encoder and save one image."""
    # Use the quantized T5 encoder as text_encoder_2 instead of the
    # pipeline's default full-precision encoder.
    quantized_t5 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
    extra_init_kwargs = {"text_encoder_2": quantized_t5}

    # Auto-detect whether this GPU should use the 'int4' or 'fp4' checkpoint.
    precision = get_precision()
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-schnell")

    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
        **extra_init_kwargs,
    ).to("cuda")

    # 4 inference steps with guidance disabled (guidance_scale=0).
    output = pipeline(
        "A cat holding a sign that says hello world",
        width=1024,
        height=1024,
        num_inference_steps=4,
        guidance_scale=0,
    )
    output.images[0].save(f"flux.1-schnell-qencoder-{precision}.png")


if __name__ == "__main__":
    main()
...@@ -81,6 +81,8 @@ class W4Linear(nn.Module): ...@@ -81,6 +81,8 @@ class W4Linear(nn.Module):
self.group_size, self.group_size,
) )
else: else:
if self.group_size != 128:
raise NotImplementedError("Kernel currently only supports group_size=128.")
out = gemm_awq(x, self.qweight, self.scales, self.scaled_zeros) out = gemm_awq(x, self.qweight, self.scales, self.scaled_zeros)
out = out + self.bias if self.bias is not None else out out = out + self.bias if self.bias is not None else out
return out return out
......
import os import os
import torch import torch
from accelerate import init_empty_weights
from huggingface_hub import constants, hf_hub_download from huggingface_hub import constants, hf_hub_download
from safetensors.torch import load_file from safetensors.torch import load_file
from torch import nn from torch import nn
from transformers import PretrainedConfig, T5EncoderModel from transformers import PretrainedConfig, T5Config, T5EncoderModel
from .linear import W4Linear from .linear import W4Linear
def quantize_t5_encoder( class NunchakuT5EncoderModel(T5EncoderModel):
t5_encoder: nn.Module, @classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str | os.PathLike, pretrained_model_name_or_path: str | os.PathLike,
config: PretrainedConfig | str | os.PathLike | None = None,
cache_dir: str | os.PathLike | None = None, cache_dir: str | os.PathLike | None = None,
force_download: bool = False, force_download: bool = False,
local_files_only: bool = False, local_files_only: bool = False,
token: str | bool | None = None, token: str | bool | None = None,
revision: str = "main", revision: str = "main",
**kwargs, **kwargs,
): ):
subfolder = kwargs.get("subfolder", None) subfolder = kwargs.get("subfolder", None)
if os.path.exists(pretrained_model_name_or_path): if os.path.exists(pretrained_model_name_or_path):
dirname = ( dirname = (
...@@ -27,46 +31,49 @@ def quantize_t5_encoder( ...@@ -27,46 +31,49 @@ def quantize_t5_encoder(
else os.path.join(pretrained_model_name_or_path, subfolder) else os.path.join(pretrained_model_name_or_path, subfolder)
) )
qmodel_path = os.path.join(dirname, "svdq-t5.safetensors") qmodel_path = os.path.join(dirname, "svdq-t5.safetensors")
config_path = os.path.join(dirname, "config.json")
else: else:
qmodel_path = hf_hub_download( shared_kwargs = {
repo_id=pretrained_model_name_or_path, "repo_id": pretrained_model_name_or_path,
filename="svdq-t5.safetensors", "subfolder": subfolder,
subfolder=subfolder, "repo_type": "model",
repo_type="model", "revision": revision,
revision=revision, "library_name": kwargs.get("library_name"),
library_name=kwargs.get("library_name", None), "library_version": kwargs.get("library_version"),
library_version=kwargs.get("library_version", None), "cache_dir": cache_dir,
cache_dir=cache_dir, "local_dir": kwargs.get("local_dir"),
local_dir=kwargs.get("local_dir", None), "user_agent": kwargs.get("user_agent"),
user_agent=kwargs.get("user_agent", None), "force_download": force_download,
force_download=force_download, "proxies": kwargs.get("proxies"),
proxies=kwargs.get("proxies", None), "etag_timeout": kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT),
etag_timeout=kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT), "token": token,
token=token, "local_files_only": local_files_only,
local_files_only=local_files_only, "headers": kwargs.get("headers"),
headers=kwargs.get("headers", None), "endpoint": kwargs.get("endpoint"),
endpoint=kwargs.get("endpoint", None), "resume_download": kwargs.get("resume_download"),
resume_download=kwargs.get("resume_download", None), "force_filename": kwargs.get("force_filename"),
force_filename=kwargs.get("force_filename", None), "local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
local_dir_use_symlinks=kwargs.get("local_dir_use_symlinks", "auto"), }
) qmodel_path = hf_hub_download(filename="svdq-t5.safetensors", **shared_kwargs)
config_path = hf_hub_download(filename="config.json", **shared_kwargs)
# Load the config file
config = T5Config.from_json_file(config_path)
# Initialize model on 'meta' device (no memory allocation for weights)
with init_empty_weights():
t5_encoder = T5EncoderModel(config).to(kwargs.get("torch_dtype", torch.bfloat16))
t5_encoder.eval()
# Load the model weights from the safetensors file
state_dict = load_file(qmodel_path) state_dict = load_file(qmodel_path)
qlayer_suffix = tuple(kwargs.get("qlayer_suffix", (".q", ".k", ".v", ".o", ".wi_0")))
named_modules = {} named_modules = {}
for name, module in t5_encoder.named_modules(): for name, module in t5_encoder.named_modules():
assert isinstance(name, str) assert isinstance(name, str)
if isinstance(module, nn.Linear): if isinstance(module, nn.Linear):
if f"{name}.qweight" in state_dict and name.endswith(qlayer_suffix): if f"{name}.qweight" in state_dict:
print(f"Switching {name} to W4Linear") print(f"Switching {name} to W4Linear")
qmodule = W4Linear.from_linear(module, group_size=128, init_only=False) qmodule = W4Linear.from_linear(module, group_size=128, init_only=True)
# qmodule.qweight.data.copy_(state_dict[f"{name}.qweight"])
# if qmodule.bias is not None:
# qmodule.bias.data.copy_(state_dict[f"{name}.bias"])
# qmodule.scales.data.copy_(state_dict[f"{name}.scales"])
# qmodule.scaled_zeros.data.copy_(state_dict[f"{name}.scaled_zeros"])
# modeling_t5.py: T5DenseGatedActDense needs dtype of weight # modeling_t5.py: T5DenseGatedActDense needs dtype of weight
qmodule.weight = torch.empty([1], dtype=module.weight.dtype, device=module.weight.device) qmodule.weight = torch.empty([1], dtype=module.weight.dtype, device=module.weight.device)
...@@ -74,52 +81,11 @@ def quantize_t5_encoder( ...@@ -74,52 +81,11 @@ def quantize_t5_encoder(
setattr(named_modules[parent_name], child_name, qmodule) setattr(named_modules[parent_name], child_name, qmodule)
else: else:
named_modules[name] = module named_modules[name] = module
return t5_encoder
device = kwargs.get("device", "cuda")
if isinstance(device, str):
device = torch.device(device)
t5_encoder.to_empty(device=device)
t5_encoder.load_state_dict(state_dict, strict=True)
class NunchakuT5EncoderModel(T5EncoderModel):
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str | os.PathLike,
*model_args,
config: PretrainedConfig | str | os.PathLike | None = None,
cache_dir: str | os.PathLike | None = None,
ignore_mismatched_sizes: bool = False,
force_download: bool = False,
local_files_only: bool = False,
token: str | bool | None = None,
revision: str = "main",
use_safetensors: bool = None,
weights_only: bool = True,
**kwargs,
):
t5_encoder = (
super(NunchakuT5EncoderModel, cls)
.from_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
cache_dir=cache_dir,
ignore_mismatched_sizes=ignore_mismatched_sizes,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
use_safetensors=use_safetensors,
weights_only=weights_only,
**kwargs,
)
.to(kwargs.get("torch_dtype", torch.bfloat16))
)
t5_encoder = quantize_t5_encoder(
t5_encoder=t5_encoder,
pretrained_model_name_or_path=pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
**kwargs,
)
return t5_encoder return t5_encoder
import pytest
from nunchaku.utils import get_precision, is_turing
from .utils import run_test
# Quantized-kernel tests are not supported on Turing GPUs, so skip there.
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
    # Expected LPIPS threshold differs by quantization precision:
    # 0.136 for int4, 0.145 for fp4 (precision detected at collection time).
    "height,width,use_qencoder,expected_lpips", [(1024, 1024, True, 0.136 if get_precision() == "int4" else 0.145)]
)
def test_flux_schnell_qencoder(height: int, width: int, use_qencoder: bool, expected_lpips: float):
    """End-to-end check of FLUX.1-schnell with the quantized T5 encoder against an LPIPS bound."""
    run_test(
        precision=get_precision(), height=height, width=width, use_qencoder=use_qencoder, expected_lpips=expected_lpips
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment