Unverified Commit 60ffa842 authored by Sayak Paul, committed by GitHub

[bitsandbytes] follow-ups (#9730)

* bnb follow ups.

* add a warning when dtypes mismatch.

* fix-copies

* clear cache.

* check_if_quantized_param

* add a check on shape.

* updates

* docs

* improve readability.

* resources.

* fix
parent 0f079b93
......@@ -59,19 +59,7 @@ model_8bit = FluxTransformer2DModel.from_pretrained(
model_8bit.transformer_blocks.layers[-1].norm2.weight.dtype
```
Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights.
```py
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quantization_config
)
```
Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. You can also save the serialized 4-bit models locally with [`~ModelMixin.save_pretrained`].
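For example, a minimal sketch of both options, reusing `model_8bit` from the snippet above (the repository id is a placeholder to replace with your own):

```py
# Save the quantized weights and the quantization config locally ...
model_8bit.save_pretrained("flux.1-dev-8bit")

# ... or push them to the Hub (requires being logged in, e.g. via `huggingface-cli login`).
model_8bit.push_to_hub("your-username/flux.1-dev-8bit")
```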
</hfoption>
<hfoption id="4-bit">
......@@ -131,7 +119,7 @@ from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model_4bit = FluxTransformer2DModel.from_pretrained(
"sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer"
"hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
)
```
......@@ -264,4 +252,9 @@ double_quant_model = SD3Transformer2DModel.from_pretrained(
quantization_config=double_quant_config,
)
model.dequantize()
```
\ No newline at end of file
```
## Resources
* [End-to-end notebook showing Flux.1 Dev inference in a free-tier Colab](https://gist.github.com/sayakpaul/c76bd845b48759e11687ac550b99d8b4)
* [Training](https://gist.github.com/sayakpaul/05afd428bc089b47af7c016e42004527)
\ No newline at end of file
......@@ -211,21 +211,28 @@ def load_model_dict_into_meta(
set_module_kwargs["dtype"] = dtype
# bnb params are flattened.
if not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape:
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
raise ValueError(
f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
)
if empty_state_dict[param_name].shape != param.shape:
if (
is_quant_method_bnb
and hf_quantizer.pre_quantized
and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
):
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
elif not is_quant_method_bnb:
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
raise ValueError(
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
)
if not is_quantized or (
not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device)
if is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
):
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
else:
if accepts_dtype:
set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
else:
set_module_tensor_to_device(model, param_name, device, value=param)
else:
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
return unexpected_keys
......
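The new branch above delegates shape validation to the quantizer because pre-quantized bitsandbytes checkpoints store parameters flattened and packed, so their serialized shape never matches the meta model's expectation. A rough illustration of the 4-bit packing, assuming a CUDA device and `bitsandbytes` installed (quantization is triggered when the `Params4bit` is moved to the GPU):

```py
import torch
import bitsandbytes as bnb

weight = torch.randn(1536, 1536, dtype=torch.float16)
packed = bnb.nn.Params4bit(weight, requires_grad=False).cuda()  # quantizes on device transfer

# Two 4-bit values are packed per byte, so the stored tensor is flattened to
# ((numel + 1) // 2, 1) -- the same shape that `check_quantized_param_shape` expects.
print(weight.numel())     # 2359296
print(packed.data.shape)  # expected: torch.Size([1179648, 1])
```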
......@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .auto import DiffusersAutoQuantizationConfig, DiffusersAutoQuantizer
from .auto import DiffusersAutoQuantizer
from .base import DiffusersQuantizer
......@@ -33,10 +33,10 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
}
class DiffusersAutoQuantizationConfig:
class DiffusersAutoQuantizer:
"""
The auto diffusers quantization config class that takes care of automatically dispatching to the correct
quantization config given a quantization config stored in a dictionary.
The auto diffusers quantizer class that takes care of automatically instantiating to the correct
`DiffusersQuantizer` given the `QuantizationConfig`.
"""
@classmethod
......@@ -60,31 +60,11 @@ class DiffusersAutoQuantizationConfig:
target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]
return target_cls.from_dict(quantization_config_dict)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
model_config = cls.load_config(pretrained_model_name_or_path, **kwargs)
if getattr(model_config, "quantization_config", None) is None:
raise ValueError(
f"Did not found a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
)
quantization_config_dict = model_config.quantization_config
quantization_config = cls.from_dict(quantization_config_dict)
# Update with potential kwargs that are passed through from_pretrained.
quantization_config.update(kwargs)
return quantization_config
class DiffusersAutoQuantizer:
"""
The auto diffusers quantizer class that takes care of automatically instantiating to the correct
`DiffusersQuantizer` given the `QuantizationConfig`.
"""
@classmethod
def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict], **kwargs):
# Convert it to a QuantizationConfig if the q_config is a dict
if isinstance(quantization_config, dict):
quantization_config = DiffusersAutoQuantizationConfig.from_dict(quantization_config)
quantization_config = cls.from_dict(quantization_config)
quant_method = quantization_config.quant_method
......@@ -107,7 +87,16 @@ class DiffusersAutoQuantizer:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
quantization_config = DiffusersAutoQuantizationConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
model_config = cls.load_config(pretrained_model_name_or_path, **kwargs)
if getattr(model_config, "quantization_config", None) is None:
raise ValueError(
f"Did not found a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
)
quantization_config_dict = model_config.quantization_config
quantization_config = cls.from_dict(quantization_config_dict)
# Update with potential kwargs that are passed through from_pretrained.
quantization_config.update(kwargs)
return cls.from_config(quantization_config)
@classmethod
......@@ -129,7 +118,7 @@ class DiffusersAutoQuantizer:
warning_msg = ""
if isinstance(quantization_config, dict):
quantization_config = DiffusersAutoQuantizationConfig.from_dict(quantization_config)
quantization_config = cls.from_dict(quantization_config)
if warning_msg != "":
warnings.warn(warning_msg)
......
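With `DiffusersAutoQuantizationConfig` folded into `DiffusersAutoQuantizer`, both config resolution and quantizer dispatch now go through a single class. A small usage sketch under that assumption (the resulting class name is inferred from the 4-bit quantizer shown later in this diff):

```py
from diffusers import BitsAndBytesConfig
from diffusers.quantizers import DiffusersAutoQuantizer

bnb_config = BitsAndBytesConfig(load_in_4bit=True)

# `from_config` accepts a QuantizationConfigMixin instance or a plain dict and
# instantiates the matching DiffusersQuantizer subclass.
quantizer = DiffusersAutoQuantizer.from_config(bnb_config)
print(type(quantizer).__name__)  # expected: BnB4BitDiffusersQuantizer
```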
......@@ -134,7 +134,7 @@ class DiffusersQuantizer(ABC):
"""adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization"""
return max_memory
def check_quantized_param(
def check_if_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
......@@ -152,10 +152,13 @@ class DiffusersQuantizer(ABC):
"""
takes needed components from state_dict and creates quantized param.
"""
if not hasattr(self, "check_quantized_param"):
raise AttributeError(
f"`.create_quantized_param()` method is not supported by quantizer class {self.__class__.__name__}."
)
return
def check_quantized_param_shape(self, *args, **kwargs):
"""
checks if the quantized param has expected shape.
"""
return True
def validate_environment(self, *args, **kwargs):
"""
......
......@@ -106,7 +106,7 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
else:
raise ValueError(f"Wrong `target_dtype` ({target_dtype}) provided.")
def check_quantized_param(
def check_if_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
......@@ -204,6 +204,16 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
module._parameters[tensor_name] = new_value
def check_quantized_param_shape(self, param_name, current_param_shape, loaded_param_shape):
n = current_param_shape.numel()
inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
if loaded_param_shape != inferred_shape:
raise ValueError(
f"Expected the flattened shape of the current param ({param_name}) to be {loaded_param_shape} but is {inferred_shape}."
)
else:
return True
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
# need more space for buffers that are created during quantization
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
......@@ -330,7 +340,6 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
if self.quantization_config.llm_int8_skip_modules is not None:
self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.validate_environment with 4-bit->8-bit
def validate_environment(self, *args, **kwargs):
if not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
......@@ -404,7 +413,7 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
logger.info("target_dtype {target_dtype} is replaced by `torch.int8` for 8-bit BnB quantization")
return torch.int8
def check_quantized_param(
def check_if_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
......
......@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
import tempfile
import unittest
import numpy as np
import safetensors.torch
from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
from diffusers.utils import logging
......@@ -118,6 +120,9 @@ class Base4bitTests(unittest.TestCase):
class BnB4BitBasicTests(Base4bitTests):
def setUp(self):
gc.collect()
torch.cuda.empty_cache()
# Models
self.model_fp16 = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", torch_dtype=torch.float16
......@@ -232,7 +237,7 @@ class BnB4BitBasicTests(Base4bitTests):
def test_config_from_pretrained(self):
transformer_4bit = FluxTransformer2DModel.from_pretrained(
"sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer"
"hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
)
linear = get_some_linear_layer(transformer_4bit)
self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
......@@ -312,9 +317,42 @@ class BnB4BitBasicTests(Base4bitTests):
with self.assertRaises(ValueError):
_ = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_storage="add")
def test_bnb_4bit_errors_loading_incorrect_state_dict(self):
r"""
Test if loading with an incorrect state dict raises an error.
"""
with tempfile.TemporaryDirectory() as tmpdirname:
nf4_config = BitsAndBytesConfig(load_in_4bit=True)
model_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config
)
model_4bit.save_pretrained(tmpdirname)
del model_4bit
with self.assertRaises(ValueError) as err_context:
state_dict = safetensors.torch.load_file(
os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors")
)
# corrupt the state dict
key_to_target = "context_embedder.weight" # can be other keys too.
compatible_param = state_dict[key_to_target]
corrupted_param = torch.randn(compatible_param.shape[0] - 1, 1)
state_dict[key_to_target] = bnb.nn.Params4bit(corrupted_param, requires_grad=False)
safetensors.torch.save_file(
state_dict, os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors")
)
_ = SD3Transformer2DModel.from_pretrained(tmpdirname)
assert key_to_target in str(err_context.exception)
class BnB4BitTrainingTests(Base4bitTests):
def setUp(self):
gc.collect()
torch.cuda.empty_cache()
nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
......@@ -360,6 +398,9 @@ class BnB4BitTrainingTests(Base4bitTests):
@require_transformers_version_greater("4.44.0")
class SlowBnb4BitTests(Base4bitTests):
def setUp(self) -> None:
gc.collect()
torch.cuda.empty_cache()
nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
......@@ -447,8 +488,10 @@ class SlowBnb4BitTests(Base4bitTests):
@require_transformers_version_greater("4.44.0")
class SlowBnb4BitFluxTests(Base4bitTests):
def setUp(self) -> None:
# TODO: Copy sayakpaul/flux.1-dev-nf4-pkg to testing repo.
model_id = "sayakpaul/flux.1-dev-nf4-pkg"
gc.collect()
torch.cuda.empty_cache()
model_id = "hf-internal-testing/flux.1-dev-nf4-pkg"
t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
transformer_4bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
self.pipeline_4bit = DiffusionPipeline.from_pretrained(
......
......@@ -117,6 +117,9 @@ class Base8bitTests(unittest.TestCase):
class BnB8bitBasicTests(Base8bitTests):
def setUp(self):
gc.collect()
torch.cuda.empty_cache()
# Models
self.model_fp16 = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", torch_dtype=torch.float16
......@@ -238,7 +241,7 @@ class BnB8bitBasicTests(Base8bitTests):
def test_config_from_pretrained(self):
transformer_8bit = FluxTransformer2DModel.from_pretrained(
"sayakpaul/flux.1-dev-int8-pkg", subfolder="transformer"
"hf-internal-testing/flux.1-dev-int8-pkg", subfolder="transformer"
)
linear = get_some_linear_layer(transformer_8bit)
self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
......@@ -296,6 +299,9 @@ class BnB8bitBasicTests(Base8bitTests):
class BnB8bitTrainingTests(Base8bitTests):
def setUp(self):
gc.collect()
torch.cuda.empty_cache()
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
self.model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
......@@ -337,6 +343,9 @@ class BnB8bitTrainingTests(Base8bitTests):
@require_transformers_version_greater("4.44.0")
class SlowBnb8bitTests(Base8bitTests):
def setUp(self) -> None:
gc.collect()
torch.cuda.empty_cache()
mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
......@@ -427,8 +436,10 @@ class SlowBnb8bitTests(Base8bitTests):
@require_transformers_version_greater("4.44.0")
class SlowBnb8bitFluxTests(Base8bitTests):
def setUp(self) -> None:
# TODO: Copy sayakpaul/flux.1-dev-int8-pkg to testing repo.
model_id = "sayakpaul/flux.1-dev-int8-pkg"
gc.collect()
torch.cuda.empty_cache()
model_id = "hf-internal-testing/flux.1-dev-int8-pkg"
t5_8bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
transformer_8bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
self.pipeline_8bit = DiffusionPipeline.from_pretrained(
......@@ -466,6 +477,9 @@ class SlowBnb8bitFluxTests(Base8bitTests):
@slow
class BaseBnb8bitSerializationTests(Base8bitTests):
def setUp(self):
gc.collect()
torch.cuda.empty_cache()
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
)
......