Unverified commit 1334f0ea authored by Muyang Li, committed by GitHub

Merge pull request #406 from mit-han-lab/dev

Better 4-bit t5
Fix LoRA removal
Add nightly windows wheels
parents 5ed26dc2 06d90947
@@ -103,3 +103,45 @@ jobs:
      run: |
        cd ..
        rm -rf *nunchaku*
  windows-wheels:
    name: Build the windows nightly wheels
    runs-on: [self-hosted, windows-build]
    needs: tag
    if: needs.tag.outputs.is_dev == 'true'
    steps:
      - name: Checkout to the tag
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
          ref: ${{ needs.tag.outputs.tag_name }}
          submodules: true
      - name: Show current commit
        run: git log -1 --oneline
      - name: Build wheels
        run: |
          C:\Users\muyang\miniconda3\condabin\activate.bat activate
          scripts\build_all_windows_wheels.cmd
      - name: Upload wheels to GitHub Release
        uses: softprops/action-gh-release@v2
        with:
          files: dist/*.whl
          name: Nunchaku Nightly ${{ needs.tag.outputs.tag_name }}
          tag_name: ${{ needs.tag.outputs.tag_name }}
          prerelease: true
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
  windows-clean-up:
    if: always()
    needs: [ windows-wheels ]
    runs-on: [ self-hosted, windows-build ]
    steps:
      - name: Clean up
        run: |
          cd ..
          powershell -Command "Remove-Item -Recurse -Force *nunchaku*"
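The release step publishes every built wheel as an asset on a prerelease tagged with the nightly version. Installing a nightly on Windows is then a plain pip install of the downloaded asset; a sketch with a placeholder filename, since the exact version and Python tags vary per build:

pip install nunchaku-<nightly-version>-<python-tag>-win_amd64.whl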
import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel
from nunchaku.utils import get_precision


def main():
    pipeline_init_kwargs = {}
    # Load the 4-bit quantized T5 text encoder
    text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/svdq-flux.1-t5")
    pipeline_init_kwargs["text_encoder_2"] = text_encoder_2

    precision = get_precision()  # auto-detects whether your GPU needs the 'int4' or 'fp4' model
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-schnell")
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16, **pipeline_init_kwargs
    ).to("cuda")
    image = pipeline(
        "A cat holding a sign that says hello world", width=1024, height=1024, num_inference_steps=4, guidance_scale=0
    ).images[0]
    image.save(f"flux.1-schnell-qencoder-{precision}.png")


if __name__ == "__main__":
    main()
@@ -4,6 +4,7 @@ import warnings

import torch
from diffusers.loaders import FluxLoraLoaderMixin
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
from safetensors.torch import save_file

from .utils import load_state_dict_in_safetensors

@@ -21,6 +22,7 @@ def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | N
        tensors[k] = v.to(torch.bfloat16)

    new_tensors, alphas = FluxLoraLoaderMixin.lora_state_dict(tensors, return_alphas=True)
    new_tensors = convert_unet_state_dict_to_peft(new_tensors)
    if alphas is not None and len(alphas) > 0:
        warnings.warn("Alpha values are not used in the conversion to diffusers format.")
...
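The added convert_unet_state_dict_to_peft call rewrites diffusers-style LoRA keys into the PEFT lora_A/lora_B naming before the converted tensors are saved. A minimal usage sketch of the converter, assuming to_diffusers is importable from nunchaku's FLUX LoRA module (the import path below is a guess):

from nunchaku.lora.flux import to_diffusers  # hypothetical import path

# Convert a LoRA safetensors file to the diffusers/PEFT key layout and save it
to_diffusers("my_lora.safetensors", "my_lora.diffusers.safetensors")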
@@ -81,6 +81,8 @@ class W4Linear(nn.Module):
                self.group_size,
            )
        else:
            if self.group_size != 128:
                raise NotImplementedError("Kernel currently only supports group_size=128.")
            out = gemm_awq(x, self.qweight, self.scales, self.scaled_zeros)
        out = out + self.bias if self.bias is not None else out
        return out
...
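The new group_size check makes this path fail loudly instead of feeding an unsupported layout to the AWQ GEMM kernel. A minimal sketch of quantizing one layer with the supported group size, assuming W4Linear lives at the module path implied by the relative import in the T5 encoder module below:

import torch
from torch import nn

from nunchaku.models.text_encoders.linear import W4Linear  # path assumed from `from .linear import W4Linear`

layer = nn.Linear(4096, 4096, dtype=torch.float16, device="cuda")
qlayer = W4Linear.from_linear(layer, group_size=128, init_only=False)  # only group_size=128 hits the kernel
out = qlayer(torch.randn(2, 4096, dtype=torch.float16, device="cuda"))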
import os

import torch
from accelerate import init_empty_weights
from huggingface_hub import constants, hf_hub_download
from safetensors.torch import load_file
from torch import nn
from transformers import PretrainedConfig, T5Config, T5EncoderModel

from .linear import W4Linear
def quantize_t5_encoder(
    t5_encoder: nn.Module,
    pretrained_model_name_or_path: str | os.PathLike,
    cache_dir: str | os.PathLike | None = None,
    force_download: bool = False,
    local_files_only: bool = False,
    token: str | bool | None = None,
    revision: str = "main",
    **kwargs,
):
    subfolder = kwargs.get("subfolder", None)
    if os.path.exists(pretrained_model_name_or_path):
        dirname = (
            pretrained_model_name_or_path
            if subfolder is None
            else os.path.join(pretrained_model_name_or_path, subfolder)
        )
        qmodel_path = os.path.join(dirname, "svdq-t5.safetensors")
    else:
        qmodel_path = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="svdq-t5.safetensors",
            subfolder=subfolder,
            repo_type="model",
            revision=revision,
            library_name=kwargs.get("library_name", None),
            library_version=kwargs.get("library_version", None),
            cache_dir=cache_dir,
            local_dir=kwargs.get("local_dir", None),
            user_agent=kwargs.get("user_agent", None),
            force_download=force_download,
            proxies=kwargs.get("proxies", None),
            etag_timeout=kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT),
            token=token,
            local_files_only=local_files_only,
            headers=kwargs.get("headers", None),
            endpoint=kwargs.get("endpoint", None),
            resume_download=kwargs.get("resume_download", None),
            force_filename=kwargs.get("force_filename", None),
            local_dir_use_symlinks=kwargs.get("local_dir_use_symlinks", "auto"),
        )

    state_dict = load_file(qmodel_path)
    qlayer_suffix = tuple(kwargs.get("qlayer_suffix", (".q", ".k", ".v", ".o", ".wi_0")))

    named_modules = {}
    for name, module in t5_encoder.named_modules():
        assert isinstance(name, str)
        if isinstance(module, nn.Linear):
            if f"{name}.qweight" in state_dict and name.endswith(qlayer_suffix):
                print(f"Switching {name} to W4Linear")
                qmodule = W4Linear.from_linear(module, group_size=128, init_only=False)
                # qmodule.qweight.data.copy_(state_dict[f"{name}.qweight"])
                # if qmodule.bias is not None:
                #     qmodule.bias.data.copy_(state_dict[f"{name}.bias"])
                # qmodule.scales.data.copy_(state_dict[f"{name}.scales"])
                # qmodule.scaled_zeros.data.copy_(state_dict[f"{name}.scaled_zeros"])
                # modeling_t5.py: T5DenseGatedActDense needs dtype of weight
                qmodule.weight = torch.empty([1], dtype=module.weight.dtype, device=module.weight.device)
                parent_name, child_name = name.rsplit(".", 1)
                setattr(named_modules[parent_name], child_name, qmodule)
        else:
            named_modules[name] = module
    return t5_encoder
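quantize_t5_encoder patches an already-instantiated encoder in place: every nn.Linear whose name ends in one of the qlayer_suffix entries and has a matching qweight tensor in svdq-t5.safetensors is swapped for a W4Linear. A minimal sketch of calling it directly, assuming a full-precision FLUX text encoder that fits in memory (the subfolder name is an assumption):

import torch
from transformers import T5EncoderModel

# Full-precision T5 encoder from the FLUX.1-schnell repo (subfolder assumed)
t5 = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2", torch_dtype=torch.bfloat16
)
t5 = quantize_t5_encoder(t5, "mit-han-lab/svdq-flux.1-t5")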
class NunchakuT5EncoderModel(T5EncoderModel):
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str | os.PathLike,
        config: PretrainedConfig | str | os.PathLike | None = None,
        cache_dir: str | os.PathLike | None = None,
        force_download: bool = False,
        local_files_only: bool = False,
        token: str | bool | None = None,
        revision: str = "main",
        **kwargs,
    ):
        subfolder = kwargs.get("subfolder", None)
        if os.path.exists(pretrained_model_name_or_path):
            dirname = (
                pretrained_model_name_or_path
                if subfolder is None
                else os.path.join(pretrained_model_name_or_path, subfolder)
            )
            qmodel_path = os.path.join(dirname, "svdq-t5.safetensors")
            config_path = os.path.join(dirname, "config.json")
        else:
            shared_kwargs = {
                "repo_id": pretrained_model_name_or_path,
                "subfolder": subfolder,
                "repo_type": "model",
                "revision": revision,
                "library_name": kwargs.get("library_name"),
                "library_version": kwargs.get("library_version"),
                "cache_dir": cache_dir,
                "local_dir": kwargs.get("local_dir"),
                "user_agent": kwargs.get("user_agent"),
                "force_download": force_download,
                "proxies": kwargs.get("proxies"),
                "etag_timeout": kwargs.get("etag_timeout", constants.DEFAULT_ETAG_TIMEOUT),
                "token": token,
                "local_files_only": local_files_only,
                "headers": kwargs.get("headers"),
                "endpoint": kwargs.get("endpoint"),
                "resume_download": kwargs.get("resume_download"),
                "force_filename": kwargs.get("force_filename"),
                "local_dir_use_symlinks": kwargs.get("local_dir_use_symlinks", "auto"),
            }
            qmodel_path = hf_hub_download(filename="svdq-t5.safetensors", **shared_kwargs)
            config_path = hf_hub_download(filename="config.json", **shared_kwargs)

        # Load the config file
        config = T5Config.from_json_file(config_path)

        # Initialize model on 'meta' device (no memory allocation for weights)
        with init_empty_weights():
            t5_encoder = T5EncoderModel(config).to(kwargs.get("torch_dtype", torch.bfloat16))
        t5_encoder.eval()

        # Load the model weights from the safetensors file
        state_dict = load_file(qmodel_path)

        named_modules = {}
        for name, module in t5_encoder.named_modules():
            assert isinstance(name, str)
            if isinstance(module, nn.Linear):
                if f"{name}.qweight" in state_dict:
                    print(f"Switching {name} to W4Linear")
                    qmodule = W4Linear.from_linear(module, group_size=128, init_only=True)
                    # modeling_t5.py: T5DenseGatedActDense needs dtype of weight
                    qmodule.weight = torch.empty([1], dtype=module.weight.dtype, device=module.weight.device)
                    parent_name, child_name = name.rsplit(".", 1)
                    setattr(named_modules[parent_name], child_name, qmodule)
            else:
                named_modules[name] = module

        device = kwargs.get("device", "cuda")
        if isinstance(device, str):
            device = torch.device(device)
        t5_encoder.to_empty(device=device)
        t5_encoder.load_state_dict(state_dict, strict=True)
        return t5_encoder
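As the loader code shows, device and torch_dtype travel through **kwargs rather than named parameters, with defaults of "cuda" and torch.bfloat16, so placement can be overridden at call time. A small sketch:

# device and torch_dtype are read from **kwargs (defaults: "cuda", torch.bfloat16)
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
    "mit-han-lab/svdq-flux.1-t5", device="cuda:0", torch_dtype=torch.bfloat16
)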
import pytest

from nunchaku.utils import get_precision, is_turing

from .utils import run_test


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
@pytest.mark.parametrize(
    "height,width,use_qencoder,expected_lpips", [(1024, 1024, True, 0.136 if get_precision() == "int4" else 0.145)]
)
def test_flux_schnell_qencoder(height: int, width: int, use_qencoder: bool, expected_lpips: float):
    run_test(
        precision=get_precision(), height=height, width=width, use_qencoder=use_qencoder, expected_lpips=expected_lpips
    )
@@ -54,7 +54,7 @@ from .utils import already_generate, compute_lpips, offload_pipeline
            "waterfall",
            23,
            0.6,
            0.253 if get_precision() == "int4" else 0.254,
        ),
    ],
)
...