Unverified Commit e24941b2 authored by Dhruv Nair, committed by GitHub

[Single File] Add GGUF support (#9964)



* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update src/diffusers/quantizers/gguf/utils.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update docs/source/en/quantization/gguf.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update

* update

* update

* update

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent f9d5a932
...@@ -357,6 +357,8 @@ jobs: ...@@ -357,6 +357,8 @@ jobs:
config: config:
- backend: "bitsandbytes" - backend: "bitsandbytes"
test_location: "bnb" test_location: "bnb"
- backend: "gguf"
test_location: "gguf"
runs-on: runs-on:
group: aws-g6e-xlarge-plus group: aws-g6e-xlarge-plus
container: container:
......
...@@ -157,6 +157,8 @@ ...@@ -157,6 +157,8 @@
title: Getting Started title: Getting Started
- local: quantization/bitsandbytes - local: quantization/bitsandbytes
title: bitsandbytes title: bitsandbytes
- local: quantization/gguf
title: gguf
- local: quantization/torchao - local: quantization/torchao
title: torchao title: torchao
title: Quantization Methods title: Quantization Methods
......
...@@ -28,6 +28,9 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui ...@@ -28,6 +28,9 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
[[autodoc]] BitsAndBytesConfig [[autodoc]] BitsAndBytesConfig
## GGUFQuantizationConfig
[[autodoc]] GGUFQuantizationConfig
## TorchAoConfig ## TorchAoConfig
[[autodoc]] TorchAoConfig [[autodoc]] TorchAoConfig
......
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# GGUF
The GGUF file format is typically used to store models for inference with [GGML](https://github.com/ggerganov/ggml) and supports a variety of block-wise quantization options. Diffusers supports loading prequantized checkpoints saved in the GGUF format via `from_single_file` loading with Model classes. Loading GGUF checkpoints via Pipelines is currently not supported.
The following example loads the [FLUX.1 DEV](https://huggingface.co/black-forest-labs/FLUX.1-dev) transformer model using the GGUF Q2_K quantization variant.
Before starting, install gguf in your environment:
```shell
pip install -U gguf
```
Since GGUF is a single file format, use [`~FromSingleFileMixin.from_single_file`] to load the model and pass in the [`GGUFQuantizationConfig`].
When using GGUF checkpoints, the quantized weights remain in a low-memory `dtype` (typically `torch.uint8`) and are dynamically dequantized and cast to the configured `compute_dtype` during each module's forward pass through the model. The `GGUFQuantizationConfig` allows you to set the `compute_dtype`.
The functions used for dynamic dequantization are based on the great work done by [city96](https://github.com/city96/ComfyUI-GGUF), who created the PyTorch ports of the original [`numpy`](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/quants.py) implementation by [compilade](https://github.com/compilade).
```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
ckpt_path = (
    "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
)
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
image.save("flux-gguf.png")
```
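To confirm that the weights really stay in their packed form, you can inspect one of the quantized layers after loading. This is a minimal sketch that assumes the model was loaded as in the example above and that the first attention query projection (`transformer_blocks[0].attn.to_q`) is one of the quantized tensors.

```python
# Minimal sketch: inspect a quantized layer of the transformer loaded above.
layer = transformer.transformer_blocks[0].attn.to_q
print(layer.weight.dtype)       # expected to be torch.uint8 (packed GGUF bytes)
print(layer.weight.quant_type)  # the GGML quantization type of this tensor
```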
## Supported Quantization Types
- BF16
- Q4_0
- Q4_1
- Q5_0
- Q5_1
- Q8_0
- Q2_K
- Q3_K
- Q4_K
- Q5_K
- Q6_K
...@@ -17,7 +17,7 @@ Quantization techniques focus on representing data with less information while a ...@@ -17,7 +17,7 @@ Quantization techniques focus on representing data with less information while a
<Tip> <Tip>
Interested in adding a new quantization method to Transformers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method. Interested in adding a new quantization method to Diffusers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method.
</Tip> </Tip>
...@@ -32,4 +32,9 @@ If you are new to the quantization field, we recommend you to check out these be ...@@ -32,4 +32,9 @@ If you are new to the quantization field, we recommend you to check out these be
## When to use what? ## When to use what?
Diffusers supports [bitsandbytes](https://huggingface.co/docs/bitsandbytes/main/en/index) and [torchao](https://github.com/pytorch/ao). Refer to this [table](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) to help you determine which quantization backend to use. Diffusers currently supports the following quantization methods.
\ No newline at end of file - [BitsandBytes]()
- [TorchAO]()
- [GGUF]()
[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
...@@ -31,7 +31,7 @@ _import_structure = { ...@@ -31,7 +31,7 @@ _import_structure = {
"loaders": ["FromOriginalModelMixin"], "loaders": ["FromOriginalModelMixin"],
"models": [], "models": [],
"pipelines": [], "pipelines": [],
"quantizers.quantization_config": ["BitsAndBytesConfig", "TorchAoConfig"], "quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
"schedulers": [], "schedulers": [],
"utils": [ "utils": [
"OptionalDependencyNotAvailable", "OptionalDependencyNotAvailable",
...@@ -569,7 +569,7 @@ else: ...@@ -569,7 +569,7 @@ else:
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .configuration_utils import ConfigMixin from .configuration_utils import ConfigMixin
from .quantizers.quantization_config import BitsAndBytesConfig, TorchAoConfig from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig
try: try:
if not is_onnx_available(): if not is_onnx_available():
......
...@@ -17,8 +17,10 @@ import re ...@@ -17,8 +17,10 @@ import re
from contextlib import nullcontext from contextlib import nullcontext
from typing import Optional from typing import Optional
import torch
from huggingface_hub.utils import validate_hf_hub_args from huggingface_hub.utils import validate_hf_hub_args
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging from ..utils import deprecate, is_accelerate_available, logging
from .single_file_utils import ( from .single_file_utils import (
SingleFileComponentError, SingleFileComponentError,
...@@ -214,6 +216,8 @@ class FromOriginalModelMixin: ...@@ -214,6 +216,8 @@ class FromOriginalModelMixin:
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None) torch_dtype = kwargs.pop("torch_dtype", None)
quantization_config = kwargs.pop("quantization_config", None)
device = kwargs.pop("device", None)
if isinstance(pretrained_model_link_or_path_or_dict, dict): if isinstance(pretrained_model_link_or_path_or_dict, dict):
checkpoint = pretrained_model_link_or_path_or_dict checkpoint = pretrained_model_link_or_path_or_dict
...@@ -227,6 +231,12 @@ class FromOriginalModelMixin: ...@@ -227,6 +231,12 @@ class FromOriginalModelMixin:
local_files_only=local_files_only, local_files_only=local_files_only,
revision=revision, revision=revision,
) )
if quantization_config is not None:
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
hf_quantizer.validate_environment()
else:
hf_quantizer = None
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name] mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
...@@ -309,8 +319,36 @@ class FromOriginalModelMixin: ...@@ -309,8 +319,36 @@ class FromOriginalModelMixin:
with ctx(): with ctx():
model = cls.from_config(diffusers_model_config) model = cls.from_config(diffusers_model_config)
# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
)
if use_keep_in_fp32_modules:
keep_in_fp32_modules = cls._keep_in_fp32_modules
if not isinstance(keep_in_fp32_modules, list):
keep_in_fp32_modules = [keep_in_fp32_modules]
else:
keep_in_fp32_modules = []
if hf_quantizer is not None:
hf_quantizer.preprocess_model(
model=model,
device_map=None,
state_dict=diffusers_format_checkpoint,
keep_in_fp32_modules=keep_in_fp32_modules,
)
if is_accelerate_available(): if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) param_device = torch.device(device) if device else torch.device("cpu")
unexpected_keys = load_model_dict_into_meta(
model,
diffusers_format_checkpoint,
dtype=torch_dtype,
device=param_device,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
)
else: else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False) _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
...@@ -324,7 +362,11 @@ class FromOriginalModelMixin: ...@@ -324,7 +362,11 @@ class FromOriginalModelMixin:
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
) )
if torch_dtype is not None: if hf_quantizer is not None:
hf_quantizer.postprocess_model(model)
model.hf_quantizer = hf_quantizer
if torch_dtype is not None and hf_quantizer is None:
model.to(torch_dtype) model.to(torch_dtype)
model.eval() model.eval()
......
...@@ -81,8 +81,14 @@ CHECKPOINT_KEY_NAMES = { ...@@ -81,8 +81,14 @@ CHECKPOINT_KEY_NAMES = {
"open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight", "open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
"stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight", "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
"stable_cascade_stage_c": "clip_txt_mapper.weight", "stable_cascade_stage_c": "clip_txt_mapper.weight",
"sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias", "sd3": [
"sd35_large": "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight", "joint_blocks.0.context_block.adaLN_modulation.1.bias",
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
],
"sd35_large": [
"joint_blocks.37.x_block.mlp.fc1.weight",
"model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
],
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe", "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
"animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias", "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight", "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
...@@ -542,13 +548,20 @@ def infer_diffusers_model_type(checkpoint): ...@@ -542,13 +548,20 @@ def infer_diffusers_model_type(checkpoint):
): ):
model_type = "stable_cascade_stage_b" model_type = "stable_cascade_stage_b"
elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216: elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd3"]) and any(
if checkpoint["model.diffusion_model.pos_embed"].shape[1] == 36864: checkpoint[key].shape[-1] == 9216 if key in checkpoint else False for key in CHECKPOINT_KEY_NAMES["sd3"]
):
if "model.diffusion_model.pos_embed" in checkpoint:
key = "model.diffusion_model.pos_embed"
else:
key = "pos_embed"
if checkpoint[key].shape[1] == 36864:
model_type = "sd3" model_type = "sd3"
elif checkpoint["model.diffusion_model.pos_embed"].shape[1] == 147456: elif checkpoint[key].shape[1] == 147456:
model_type = "sd35_medium" model_type = "sd35_medium"
elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint: elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd35_large"]):
model_type = "sd35_large" model_type = "sd35_large"
elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint: elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import importlib import importlib
import inspect import inspect
import os import os
from array import array
from collections import OrderedDict from collections import OrderedDict
from pathlib import Path from pathlib import Path
from typing import List, Optional, Union from typing import List, Optional, Union
...@@ -26,6 +27,7 @@ import torch ...@@ -26,6 +27,7 @@ import torch
from huggingface_hub.utils import EntryNotFoundError from huggingface_hub.utils import EntryNotFoundError
from ..utils import ( from ..utils import (
GGUF_FILE_EXTENSION,
SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION, SAFETENSORS_FILE_EXTENSION,
WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME,
...@@ -33,6 +35,8 @@ from ..utils import ( ...@@ -33,6 +35,8 @@ from ..utils import (
_get_model_file, _get_model_file,
deprecate, deprecate,
is_accelerate_available, is_accelerate_available,
is_gguf_available,
is_torch_available,
is_torch_version, is_torch_version,
logging, logging,
) )
...@@ -139,6 +143,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[ ...@@ -139,6 +143,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
file_extension = os.path.basename(checkpoint_file).split(".")[-1] file_extension = os.path.basename(checkpoint_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION: if file_extension == SAFETENSORS_FILE_EXTENSION:
return safetensors.torch.load_file(checkpoint_file, device="cpu") return safetensors.torch.load_file(checkpoint_file, device="cpu")
elif file_extension == GGUF_FILE_EXTENSION:
return load_gguf_checkpoint(checkpoint_file)
else: else:
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {} weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
return torch.load( return torch.load(
...@@ -211,13 +217,14 @@ def load_model_dict_into_meta( ...@@ -211,13 +217,14 @@ def load_model_dict_into_meta(
set_module_kwargs["dtype"] = dtype set_module_kwargs["dtype"] = dtype
# bnb params are flattened. # bnb params are flattened.
# gguf quants have a different shape based on the type of quantization applied
if empty_state_dict[param_name].shape != param.shape: if empty_state_dict[param_name].shape != param.shape:
if ( if (
is_quantized is_quantized
and hf_quantizer.pre_quantized and hf_quantizer.pre_quantized
and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device) and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
): ):
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape) hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
else: else:
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
raise ValueError( raise ValueError(
...@@ -396,3 +403,78 @@ def _fetch_index_file_legacy( ...@@ -396,3 +403,78 @@ def _fetch_index_file_legacy(
index_file = None index_file = None
return index_file return index_file
def _gguf_parse_value(_value, data_type):
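# The numeric codes mirror gguf.GGUFValueType: 0-5 and 10/11 are integer types, 6 and 12 are floats,
# 7 is bool, 8 is a UTF-8 string, and 9 marks an array whose element type follows in the list.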
if not isinstance(data_type, list):
data_type = [data_type]
if len(data_type) == 1:
data_type = data_type[0]
array_data_type = None
else:
if data_type[0] != 9:
raise ValueError("Received multiple types, therefore expected the first type to indicate an array.")
data_type, array_data_type = data_type
if data_type in [0, 1, 2, 3, 4, 5, 10, 11]:
_value = int(_value[0])
elif data_type in [6, 12]:
_value = float(_value[0])
elif data_type in [7]:
_value = bool(_value[0])
elif data_type in [8]:
_value = array("B", list(_value)).tobytes().decode()
elif data_type in [9]:
_value = _gguf_parse_value(_value, array_data_type)
return _value
def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
"""
Load a GGUF file and return a dictionary of parsed parameters mapping tensor names to tensors.
Args:
gguf_checkpoint_path (`str`):
The path to the GGUF file to load.
return_tensors (`bool`, defaults to `False`):
Whether to read the tensors from the file and return them. Not doing so is faster and only loads the
metadata in memory.
"""
if is_gguf_available() and is_torch_available():
import gguf
from gguf import GGUFReader
from ..quantizers.gguf.utils import SUPPORTED_GGUF_QUANT_TYPES, GGUFParameter
else:
logger.error(
"Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
"https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
)
raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")
reader = GGUFReader(gguf_checkpoint_path)
parsed_parameters = {}
for tensor in reader.tensors:
name = tensor.name
quant_type = tensor.tensor_type
# if the tensor is a torch supported dtype do not use GGUFParameter
is_gguf_quant = quant_type not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]
if is_gguf_quant and quant_type not in SUPPORTED_GGUF_QUANT_TYPES:
_supported_quants_str = "\n".join([str(type) for type in SUPPORTED_GGUF_QUANT_TYPES])
raise ValueError(
(
f"{name} has a quantization type: {str(quant_type)} which is unsupported."
"\n\nCurrently the following quantization types are supported: \n\n"
f"{_supported_quants_str}"
"\n\nTo request support for this quantization type please open an issue here: https://github.com/huggingface/diffusers"
)
)
weights = torch.from_numpy(tensor.data.copy())
parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
return parsed_parameters
...@@ -1038,14 +1038,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -1038,14 +1038,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
dtype_present_in_args = True dtype_present_in_args = True
break break
# Checks if the model has been loaded in 4-bit or 8-bit with BNB if getattr(self, "is_quantized", False):
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
if dtype_present_in_args: if dtype_present_in_args:
raise ValueError( raise ValueError(
"You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the" "Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please "
" desired `dtype` by passing the correct `torch_dtype` argument." "use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`"
) )
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
if getattr(self, "is_loaded_in_8bit", False): if getattr(self, "is_loaded_in_8bit", False):
raise ValueError( raise ValueError(
"`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the" "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
......
...@@ -524,7 +524,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig ...@@ -524,7 +524,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
) )
else: else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks): for index_block, block in enumerate(self.single_transformer_blocks):
......
...@@ -15,23 +15,33 @@ ...@@ -15,23 +15,33 @@
Adapted from Adapted from
https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/auto.py https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/auto.py
""" """
import warnings import warnings
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod, TorchAoConfig from .gguf import GGUFQuantizer
from .quantization_config import (
BitsAndBytesConfig,
GGUFQuantizationConfig,
QuantizationConfigMixin,
QuantizationMethod,
TorchAoConfig,
)
from .torchao import TorchAoHfQuantizer from .torchao import TorchAoHfQuantizer
AUTO_QUANTIZER_MAPPING = { AUTO_QUANTIZER_MAPPING = {
"bitsandbytes_4bit": BnB4BitDiffusersQuantizer, "bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
"bitsandbytes_8bit": BnB8BitDiffusersQuantizer, "bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
"gguf": GGUFQuantizer,
"torchao": TorchAoHfQuantizer, "torchao": TorchAoHfQuantizer,
} }
AUTO_QUANTIZATION_CONFIG_MAPPING = { AUTO_QUANTIZATION_CONFIG_MAPPING = {
"bitsandbytes_4bit": BitsAndBytesConfig, "bitsandbytes_4bit": BitsAndBytesConfig,
"bitsandbytes_8bit": BitsAndBytesConfig, "bitsandbytes_8bit": BitsAndBytesConfig,
"gguf": GGUFQuantizationConfig,
"torchao": TorchAoConfig, "torchao": TorchAoConfig,
} }
......
...@@ -204,7 +204,10 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer): ...@@ -204,7 +204,10 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
module._parameters[tensor_name] = new_value module._parameters[tensor_name] = new_value
def check_quantized_param_shape(self, param_name, current_param_shape, loaded_param_shape): def check_quantized_param_shape(self, param_name, current_param, loaded_param):
current_param_shape = current_param.shape
loaded_param_shape = loaded_param.shape
n = current_param_shape.numel() n = current_param_shape.numel()
inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1) inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
if loaded_param_shape != inferred_shape: if loaded_param_shape != inferred_shape:
......
from .gguf_quantizer import GGUFQuantizer
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from ..base import DiffusersQuantizer
if TYPE_CHECKING:
from ...models.modeling_utils import ModelMixin
from ...utils import (
get_module_from_name,
is_accelerate_available,
is_accelerate_version,
is_gguf_available,
is_gguf_version,
is_torch_available,
logging,
)
if is_torch_available() and is_gguf_available():
import torch
from .utils import (
GGML_QUANT_SIZES,
GGUFParameter,
_dequantize_gguf_and_restore_linear,
_quant_shape_from_byte_shape,
_replace_with_gguf_linear,
)
logger = logging.get_logger(__name__)
class GGUFQuantizer(DiffusersQuantizer):
use_keep_in_fp32_modules = True
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
self.compute_dtype = quantization_config.compute_dtype
self.pre_quantized = quantization_config.pre_quantized
self.modules_to_not_convert = quantization_config.modules_to_not_convert
if not isinstance(self.modules_to_not_convert, list):
self.modules_to_not_convert = [self.modules_to_not_convert]
def validate_environment(self, *args, **kwargs):
if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
raise ImportError(
"Loading GGUF Parameters requires `accelerate` installed in your enviroment: `pip install 'accelerate>=0.26.0'`"
)
if not is_gguf_available() or is_gguf_version("<", "0.10.0"):
raise ImportError(
"To load GGUF format files you must have `gguf` installed in your environment: `pip install gguf>=0.10.0`"
)
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.adjust_max_memory
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
# need more space for buffers that are created during quantization
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
return max_memory
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
if target_dtype != torch.uint8:
logger.info(f"target_dtype {target_dtype} is replaced by `torch.uint8` for GGUF quantization")
return torch.uint8
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
torch_dtype = self.compute_dtype
return torch_dtype
def check_quantized_param_shape(self, param_name, current_param, loaded_param):
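# GGUF stores quantized weights as flat byte blocks; recover the logical element shape from the
# per-type (block_size, type_size) pair and compare it against the shape of the meta parameter.
# For example, Q4_K packs 256 weights into 144 bytes, so a stored last dim of 1728 bytes maps to 3072 elements.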
loaded_param_shape = loaded_param.shape
current_param_shape = current_param.shape
quant_type = loaded_param.quant_type
block_size, type_size = GGML_QUANT_SIZES[quant_type]
inferred_shape = _quant_shape_from_byte_shape(loaded_param_shape, type_size, block_size)
if inferred_shape != current_param_shape:
raise ValueError(
f"{param_name} has an expected quantized shape of: {inferred_shape}, but receieved shape: {loaded_param_shape}"
)
return True
def check_if_quantized_param(
self,
model: "ModelMixin",
param_value: Union["GGUFParameter", "torch.Tensor"],
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
if isinstance(param_value, GGUFParameter):
return True
return False
def create_quantized_param(
self,
model: "ModelMixin",
param_value: Union["GGUFParameter", "torch.Tensor"],
param_name: str,
target_device: "torch.device",
state_dict: Optional[Dict[str, Any]] = None,
unexpected_keys: Optional[List[str]] = None,
):
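# GGUF checkpoints are always pre-quantized, so the packed GGUFParameter is assigned to the module
# as-is (only moved to the target device) rather than being re-quantized here.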
module, tensor_name = get_module_from_name(model, param_name)
if tensor_name not in module._parameters and tensor_name not in module._buffers:
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
if tensor_name in module._parameters:
module._parameters[tensor_name] = param_value.to(target_device)
if tensor_name in module._buffers:
module._buffers[tensor_name] = param_value.to(target_device)
def _process_model_before_weight_loading(
self,
model: "ModelMixin",
device_map,
keep_in_fp32_modules: List[str] = [],
**kwargs,
):
state_dict = kwargs.get("state_dict", None)
self.modules_to_not_convert.extend(keep_in_fp32_modules)
self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
_replace_with_gguf_linear(
model, self.compute_dtype, state_dict, modules_to_not_convert=self.modules_to_not_convert
)
def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
return model
@property
def is_serializable(self):
return False
@property
def is_trainable(self) -> bool:
return False
def _dequantize(self, model):
is_model_on_cpu = model.device.type == "cpu"
if is_model_on_cpu:
logger.info(
"Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
)
model.to(torch.cuda.current_device())
model = _dequantize_gguf_and_restore_linear(model, self.modules_to_not_convert)
if is_model_on_cpu:
model.to("cpu")
return model
# Copyright 2024 The HuggingFace Team and City96. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from contextlib import nullcontext
import gguf
import torch
import torch.nn as nn
from ...utils import is_accelerate_available
if is_accelerate_available():
import accelerate
from accelerate import init_empty_weights
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
def _create_accelerate_new_hook(old_hook):
r"""
Creates a new hook based on the old hook. Use it only if you know what you are doing ! This method is a copy of:
https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245 with
some changes
"""
old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
old_hook_attr = old_hook.__dict__
filtered_old_hook_attr = {}
old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
for k in old_hook_attr.keys():
if k in old_hook_init_signature.parameters:
filtered_old_hook_attr[k] = old_hook_attr[k]
new_hook = old_hook_cls(**filtered_old_hook_attr)
return new_hook
def _replace_with_gguf_linear(model, compute_dtype, state_dict, prefix="", modules_to_not_convert=[]):
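# Recursively walk the module tree and swap every nn.Linear whose checkpoint weight is a GGUFParameter
# for a GGUFLinear, so the packed weights can be dequantized on the fly during the forward pass.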
def _should_convert_to_gguf(state_dict, prefix):
weight_key = prefix + "weight"
return weight_key in state_dict and isinstance(state_dict[weight_key], GGUFParameter)
has_children = list(model.children())
if not has_children:
return
for name, module in model.named_children():
module_prefix = prefix + name + "."
_replace_with_gguf_linear(module, compute_dtype, state_dict, module_prefix, modules_to_not_convert)
if (
isinstance(module, nn.Linear)
and _should_convert_to_gguf(state_dict, module_prefix)
and name not in modules_to_not_convert
):
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
model._modules[name] = GGUFLinear(
module.in_features,
module.out_features,
module.bias is not None,
compute_dtype=compute_dtype,
)
model._modules[name].source_cls = type(module)
# Force requires_grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
return model
def _dequantize_gguf_and_restore_linear(model, modules_to_not_convert=[]):
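# Inverse of the replacement above: dequantize each GGUFLinear weight, restore a plain nn.Linear,
# and re-attach any accelerate hook so offloading keeps working.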
for name, module in model.named_children():
if isinstance(module, GGUFLinear) and name not in modules_to_not_convert:
device = module.weight.device
bias = getattr(module, "bias", None)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
new_module = nn.Linear(
module.in_features,
module.out_features,
module.bias is not None,
device=device,
)
new_module.weight = nn.Parameter(dequantize_gguf_tensor(module.weight))
if bias is not None:
new_module.bias = bias
# Create a new hook and attach it in case we use accelerate
if hasattr(module, "_hf_hook"):
old_hook = module._hf_hook
new_hook = _create_accelerate_new_hook(old_hook)
remove_hook_from_module(module)
add_hook_to_module(new_module, new_hook)
new_module.to(device)
model._modules[name] = new_module
has_children = list(module.children())
if has_children:
_dequantize_gguf_and_restore_linear(module, modules_to_not_convert)
return model
# dequantize operations based on torch ports of GGUF dequantize_functions
# from City96
# more info: https://github.com/city96/ComfyUI-GGUF/blob/main/dequant.py
QK_K = 256
K_SCALE_SIZE = 12
def to_uint32(x):
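# Reassemble four little-endian uint8 columns into a single uint32 value per row.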
x = x.view(torch.uint8).to(torch.int32)
return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)
def split_block_dims(blocks, *args):
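# Split the packed per-block bytes into the requested field widths; the last field receives the remaining bytes.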
n_max = blocks.shape[1]
dims = list(args) + [n_max - sum(args)]
return torch.split(blocks, dims, dim=1)
def get_scale_min(scales):
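# Unpack the 6-bit scales and mins stored in the 12-byte scale block used by the K-quant formats (Q4_K/Q5_K).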
n_blocks = scales.shape[0]
scales = scales.view(torch.uint8)
scales = scales.reshape((n_blocks, 3, 4))
d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)
sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)
return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))
def dequantize_blocks_Q8_0(blocks, block_size, type_size, dtype=None):
d, x = split_block_dims(blocks, 2)
d = d.view(torch.float16).to(dtype)
x = x.view(torch.int8)
return d * x
def dequantize_blocks_Q5_1(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
d, m, qh, qs = split_block_dims(blocks, 2, 2, 4)
d = d.view(torch.float16).to(dtype)
m = m.view(torch.float16).to(dtype)
qh = to_uint32(qh)
qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
[0, 4], device=d.device, dtype=torch.uint8
).reshape(1, 1, 2, 1)
qh = (qh & 1).to(torch.uint8)
ql = (ql & 0x0F).reshape((n_blocks, -1))
qs = ql | (qh << 4)
return (d * qs) + m
def dequantize_blocks_Q5_0(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
d, qh, qs = split_block_dims(blocks, 2, 4)
d = d.view(torch.float16).to(dtype)
qh = to_uint32(qh)
qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor(
[0, 4], device=d.device, dtype=torch.uint8
).reshape(1, 1, 2, 1)
qh = (qh & 1).to(torch.uint8)
ql = (ql & 0x0F).reshape(n_blocks, -1)
qs = (ql | (qh << 4)).to(torch.int8) - 16
return d * qs
def dequantize_blocks_Q4_1(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
d, m, qs = split_block_dims(blocks, 2, 2)
d = d.view(torch.float16).to(dtype)
m = m.view(torch.float16).to(dtype)
qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
[0, 4], device=d.device, dtype=torch.uint8
).reshape(1, 1, 2, 1)
qs = (qs & 0x0F).reshape(n_blocks, -1)
return (d * qs) + m
def dequantize_blocks_Q4_0(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
d, qs = split_block_dims(blocks, 2)
d = d.view(torch.float16).to(dtype)
qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
[0, 4], device=d.device, dtype=torch.uint8
).reshape((1, 1, 2, 1))
qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
return d * qs
def dequantize_blocks_Q6_K(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
(
ql,
qh,
scales,
d,
) = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16)
scales = scales.view(torch.int8).to(dtype)
d = d.view(torch.float16).to(dtype)
d = (d * scales).reshape((n_blocks, QK_K // 16, 1))
ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
(1, 1, 2, 1)
)
ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
(1, 1, 4, 1)
)
qh = (qh & 0x03).reshape((n_blocks, -1, 32))
q = (ql | (qh << 4)).to(torch.int8) - 32
q = q.reshape((n_blocks, QK_K // 16, -1))
return (d * q).reshape((n_blocks, QK_K))
def dequantize_blocks_Q5_K(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)
d = d.view(torch.float16).to(dtype)
dmin = dmin.view(torch.float16).to(dtype)
sc, m = get_scale_min(scales)
d = (d * sc).reshape((n_blocks, -1, 1))
dm = (dmin * m).reshape((n_blocks, -1, 1))
ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
(1, 1, 2, 1)
)
qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.arange(0, 8, device=d.device, dtype=torch.uint8).reshape(
(1, 1, 8, 1)
)
ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
qh = (qh & 0x01).reshape((n_blocks, -1, 32))
q = ql | (qh << 4)
return (d * q - dm).reshape((n_blocks, QK_K))
def dequantize_blocks_Q4_K(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
d = d.view(torch.float16).to(dtype)
dmin = dmin.view(torch.float16).to(dtype)
sc, m = get_scale_min(scales)
d = (d * sc).reshape((n_blocks, -1, 1))
dm = (dmin * m).reshape((n_blocks, -1, 1))
qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
(1, 1, 2, 1)
)
qs = (qs & 0x0F).reshape((n_blocks, -1, 32))
return (d * qs - dm).reshape((n_blocks, QK_K))
def dequantize_blocks_Q3_K(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12)
d = d.view(torch.float16).to(dtype)
lscales, hscales = scales[:, :8], scales[:, 8:]
lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
(1, 2, 1)
)
lscales = lscales.reshape((n_blocks, 16))
hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor(
[0, 2, 4, 6], device=d.device, dtype=torch.uint8
).reshape((1, 4, 1))
hscales = hscales.reshape((n_blocks, 16))
scales = (lscales & 0x0F) | ((hscales & 0x03) << 4)
scales = scales.to(torch.int8) - 32
dl = (d * scales).reshape((n_blocks, 16, 1))
ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
(1, 1, 4, 1)
)
qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.arange(0, 8, device=d.device, dtype=torch.uint8).reshape(
(1, 1, 8, 1)
)
ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3
qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1
q = ql.to(torch.int8) - (qh << 2).to(torch.int8)
return (dl * q).reshape((n_blocks, QK_K))
def dequantize_blocks_Q2_K(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2)
d = d.view(torch.float16).to(dtype)
dmin = dmin.view(torch.float16).to(dtype)
# (n_blocks, 16, 1)
dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))
shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))
qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3
qs = qs.reshape((n_blocks, QK_K // 16, 16))
qs = dl * qs - ml
return qs.reshape((n_blocks, -1))
def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None):
return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)
GGML_QUANT_SIZES = gguf.GGML_QUANT_SIZES
dequantize_functions = {
gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1,
gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K,
gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K,
gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K,
gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K,
gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
}
SUPPORTED_GGUF_QUANT_TYPES = list(dequantize_functions.keys())
def _quant_shape_from_byte_shape(shape, type_size, block_size):
return (*shape[:-1], shape[-1] // type_size * block_size)
def dequantize_gguf_tensor(tensor):
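# Tensors without a quant_type (plain F16/F32 weights) pass through unchanged; quantized tensors are viewed
# as bytes, reshaped into blocks, and dequantized with the kernel registered for their quantization type.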
if not hasattr(tensor, "quant_type"):
return tensor
quant_type = tensor.quant_type
dequant_fn = dequantize_functions[quant_type]
block_size, type_size = GGML_QUANT_SIZES[quant_type]
tensor = tensor.view(torch.uint8)
shape = _quant_shape_from_byte_shape(tensor.shape, type_size, block_size)
n_blocks = tensor.numel() // type_size
blocks = tensor.reshape((n_blocks, type_size))
dequant = dequant_fn(blocks, block_size, type_size)
dequant = dequant.reshape(shape)
return dequant.as_tensor()
class GGUFParameter(torch.nn.Parameter):
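# A torch.nn.Parameter subclass that carries the raw packed GGUF bytes together with their quant_type,
# so downstream code (e.g. GGUFLinear) knows how to dequantize them.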
def __new__(cls, data, requires_grad=False, quant_type=None):
data = data if data is not None else torch.empty(0)
self = torch.Tensor._make_subclass(cls, data, requires_grad)
self.quant_type = quant_type
return self
def as_tensor(self):
return torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
result = super().__torch_function__(func, types, args, kwargs)
# When converting from original format checkpoints we often use splits, cats etc on tensors
# this method ensures that the returned tensor type from those operations remains GGUFParameter
# so that we preserve quant_type information
quant_type = None
for arg in args:
if isinstance(arg, list) and len(arg) > 0 and isinstance(arg[0], GGUFParameter):
quant_type = arg[0].quant_type
break
if isinstance(arg, GGUFParameter):
quant_type = arg.quant_type
break
if isinstance(result, torch.Tensor):
return cls(result, quant_type=quant_type)
# Handle tuples and lists
elif isinstance(result, (tuple, list)):
# Preserve the original type (tuple or list)
wrapped = [cls(x, quant_type=quant_type) if isinstance(x, torch.Tensor) else x for x in result]
return type(result)(wrapped)
else:
return result
class GGUFLinear(nn.Linear):
def __init__(
self,
in_features,
out_features,
bias=False,
compute_dtype=None,
device=None,
) -> None:
super().__init__(in_features, out_features, bias, device)
self.compute_dtype = compute_dtype
def forward(self, inputs):
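# Weights stay packed in GGUF format; dequantize and cast to compute_dtype on every forward pass.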
weight = dequantize_gguf_tensor(self.weight)
weight = weight.to(self.compute_dtype)
bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
output = torch.nn.functional.linear(inputs, weight, bias)
return output
...@@ -43,6 +43,7 @@ logger = logging.get_logger(__name__) ...@@ -43,6 +43,7 @@ logger = logging.get_logger(__name__)
class QuantizationMethod(str, Enum): class QuantizationMethod(str, Enum):
BITS_AND_BYTES = "bitsandbytes" BITS_AND_BYTES = "bitsandbytes"
GGUF = "gguf"
TORCHAO = "torchao" TORCHAO = "torchao"
...@@ -394,6 +395,29 @@ class BitsAndBytesConfig(QuantizationConfigMixin): ...@@ -394,6 +395,29 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
return serializable_config_dict return serializable_config_dict
@dataclass
class GGUFQuantizationConfig(QuantizationConfigMixin):
"""This is a config class for GGUF Quantization techniques.
Args:
compute_dtype (`torch.dtype`, defaults to `torch.float32`):
This sets the computational type, which might be different from the input type. For example, inputs might be
fp32, but computation can be set to bf16 for speedups.
"""
def __init__(self, compute_dtype: Optional["torch.dtype"] = None):
self.quant_method = QuantizationMethod.GGUF
self.compute_dtype = compute_dtype
self.pre_quantized = True
# TODO: (Dhruv) Add this as an init argument when we can support loading unquantized checkpoints.
self.modules_to_not_convert = None
if self.compute_dtype is None:
self.compute_dtype = torch.float32
@dataclass @dataclass
class TorchAoConfig(QuantizationConfigMixin): class TorchAoConfig(QuantizationConfigMixin):
"""This is a config class for torchao quantization/sparsity techniques. """This is a config class for torchao quantization/sparsity techniques.
......
...@@ -23,6 +23,7 @@ from .constants import ( ...@@ -23,6 +23,7 @@ from .constants import (
DEPRECATED_REVISION_ARGS, DEPRECATED_REVISION_ARGS,
DIFFUSERS_DYNAMIC_MODULE_NAME, DIFFUSERS_DYNAMIC_MODULE_NAME,
FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_NAME,
GGUF_FILE_EXTENSION,
HF_MODULES_CACHE, HF_MODULES_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT, HUGGINGFACE_CO_RESOLVE_ENDPOINT,
MIN_PEFT_VERSION, MIN_PEFT_VERSION,
...@@ -66,6 +67,8 @@ from .import_utils import ( ...@@ -66,6 +67,8 @@ from .import_utils import (
is_bs4_available, is_bs4_available,
is_flax_available, is_flax_available,
is_ftfy_available, is_ftfy_available,
is_gguf_available,
is_gguf_version,
is_google_colab, is_google_colab,
is_inflect_available, is_inflect_available,
is_invisible_watermark_available, is_invisible_watermark_available,
......
...@@ -34,6 +34,7 @@ ONNX_WEIGHTS_NAME = "model.onnx" ...@@ -34,6 +34,7 @@ ONNX_WEIGHTS_NAME = "model.onnx"
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json" SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json"
SAFETENSORS_FILE_EXTENSION = "safetensors" SAFETENSORS_FILE_EXTENSION = "safetensors"
GGUF_FILE_EXTENSION = "gguf"
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb" ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co") HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
......
...@@ -339,6 +339,14 @@ if _imageio_available: ...@@ -339,6 +339,14 @@ if _imageio_available:
except importlib_metadata.PackageNotFoundError: except importlib_metadata.PackageNotFoundError:
_imageio_available = False _imageio_available = False
_is_gguf_available = importlib.util.find_spec("gguf") is not None
if _is_gguf_available:
try:
_gguf_version = importlib_metadata.version("gguf")
logger.debug(f"Successfully import gguf version {_gguf_version}")
except importlib_metadata.PackageNotFoundError:
_is_gguf_available = False
_is_torchao_available = importlib.util.find_spec("torchao") is not None _is_torchao_available = importlib.util.find_spec("torchao") is not None
if _is_torchao_available: if _is_torchao_available:
...@@ -469,6 +477,10 @@ def is_imageio_available(): ...@@ -469,6 +477,10 @@ def is_imageio_available():
return _imageio_available return _imageio_available
def is_gguf_available():
return _is_gguf_available
def is_torchao_available(): def is_torchao_available():
return _is_torchao_available return _is_torchao_available
...@@ -607,8 +619,13 @@ IMAGEIO_IMPORT_ERROR = """ ...@@ -607,8 +619,13 @@ IMAGEIO_IMPORT_ERROR = """
""" """
# docstyle-ignore # docstyle-ignore
GGUF_IMPORT_ERROR = """
{0} requires the gguf library but it was not found in your environment. You can install it with pip: `pip install gguf`
"""
TORCHAO_IMPORT_ERROR = """ TORCHAO_IMPORT_ERROR = """
{0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install torchao` {0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install
torchao`
""" """
BACKENDS_MAPPING = OrderedDict( BACKENDS_MAPPING = OrderedDict(
...@@ -636,6 +653,7 @@ BACKENDS_MAPPING = OrderedDict( ...@@ -636,6 +653,7 @@ BACKENDS_MAPPING = OrderedDict(
("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)), ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)), ("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)),
("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)),
("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)), ("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)),
] ]
) )
...@@ -793,6 +811,21 @@ def is_bitsandbytes_version(operation: str, version: str): ...@@ -793,6 +811,21 @@ def is_bitsandbytes_version(operation: str, version: str):
return compare_versions(parse(_bitsandbytes_version), operation, version) return compare_versions(parse(_bitsandbytes_version), operation, version)
def is_gguf_version(operation: str, version: str):
"""
Compares the current gguf version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _is_gguf_available:
return False
return compare_versions(parse(_gguf_version), operation, version)
def is_k_diffusion_version(operation: str, version: str): def is_k_diffusion_version(operation: str, version: str):
""" """
Compares the current k-diffusion version to a given reference with an operation. Compares the current k-diffusion version to a given reference with an operation.
......