Unverified Commit 60ffa842 authored by Sayak Paul, committed by GitHub

[bitsandbytes] follow-ups (#9730)

* bnb follow ups.

* add a warning when dtypes mismatch.

* fx-copies

* clear cache.

* check_if_quantized_param

* add a check on shape.

* updates

* docs

* improve readability.

* resources.

* fix
parent 0f079b93
@@ -59,19 +59,7 @@ model_8bit = FluxTransformer2DModel.from_pretrained(
model_8bit.transformer_blocks.layers[-1].norm2.weight.dtype
```

Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. You can also save the serialized 4-bit models locally with [`~ModelMixin.save_pretrained`].
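For example, a minimal sketch (the repository id and local path below are hypothetical):

```py
# Push the quantized weights to the Hub (hypothetical repo id), then write a
# serialized copy to a local directory.
model_8bit.push_to_hub("your-username/FLUX.1-dev-bnb-8bit")
model_8bit.save_pretrained("./flux-transformer-8bit")
```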
</hfoption>
<hfoption id="4-bit">
@@ -131,7 +119,7 @@ from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model_4bit = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
)
```
@@ -264,4 +252,9 @@ double_quant_model = SD3Transformer2DModel.from_pretrained(
    quantization_config=double_quant_config,
)
double_quant_model.dequantize()
```

## Resources

* [End-to-end notebook showing Flux.1 Dev inference in a free-tier Colab](https://gist.github.com/sayakpaul/c76bd845b48759e11687ac550b99d8b4)
* [Training](https://gist.github.com/sayakpaul/05afd428bc089b47af7c016e42004527)
@@ -211,21 +211,28 @@ def load_model_dict_into_meta(
            set_module_kwargs["dtype"] = dtype

        # bnb params are flattened.
        if empty_state_dict[param_name].shape != param.shape:
            if (
                is_quant_method_bnb
                and hf_quantizer.pre_quantized
                and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
            ):
                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
            elif not is_quant_method_bnb:
                model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
                raise ValueError(
                    f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
                )

        if is_quantized and (
            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
        ):
            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
        else:
            if accepts_dtype:
                set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
            else:
                set_module_tensor_to_device(model, param_name, device, value=param)

    return unexpected_keys
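As an editor's illustration (not part of the diff) of why the bnb branch defers to the quantizer's shape check: a prequantized 4-bit checkpoint serializes each weight as a flattened, packed tensor, so its shape never matches the shape the meta model expects.

```py
# Hypothetical shapes when loading a prequantized bnb 4-bit checkpoint: a
# (3072, 3072) weight has 9_437_184 elements but is stored packed, two 4-bit
# values per byte, as a (4_718_592, 1) uint8 tensor.
expected_shape = (3072, 3072)            # shape on the meta model
loaded_shape = (3072 * 3072 // 2, 1)     # flattened 4-bit payload
assert loaded_shape != expected_shape    # mismatch is expected, not an error
```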
@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .auto import DiffusersAutoQuantizer
from .base import DiffusersQuantizer
@@ -33,10 +33,10 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
}


class DiffusersAutoQuantizer:
    """
    The auto diffusers quantizer class that takes care of automatically instantiating the correct
    `DiffusersQuantizer` given the `QuantizationConfig`.
    """

    @classmethod
@@ -60,31 +60,11 @@ class DiffusersAutoQuantizationConfig:
        target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]
        return target_cls.from_dict(quantization_config_dict)

    @classmethod
    def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict], **kwargs):
        # Convert it to a QuantizationConfig if the q_config is a dict
        if isinstance(quantization_config, dict):
            quantization_config = cls.from_dict(quantization_config)

        quant_method = quantization_config.quant_method
@@ -107,7 +87,16 @@ class DiffusersAutoQuantizer:
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        model_config = cls.load_config(pretrained_model_name_or_path, **kwargs)
        if getattr(model_config, "quantization_config", None) is None:
            raise ValueError(
                f"Did not find a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
            )
        quantization_config_dict = model_config.quantization_config
        quantization_config = cls.from_dict(quantization_config_dict)
        # Update with potential kwargs that are passed through from_pretrained.
        quantization_config.update(kwargs)
        return cls.from_config(quantization_config)

    @classmethod
@@ -129,7 +118,7 @@ class DiffusersAutoQuantizer:
        warning_msg = ""

        if isinstance(quantization_config, dict):
            quantization_config = cls.from_dict(quantization_config)

        if warning_msg != "":
            warnings.warn(warning_msg)
@@ -134,7 +134,7 @@ class DiffusersQuantizer(ABC):
        """adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization"""
        return max_memory

    def check_if_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
@@ -152,10 +152,13 @@ class DiffusersQuantizer(ABC):
        """
        takes needed components from state_dict and creates quantized param.
        """
        return

    def check_quantized_param_shape(self, *args, **kwargs):
        """
        checks if the quantized param has expected shape.
        """
        return True

    def validate_environment(self, *args, **kwargs):
        """
@@ -106,7 +106,7 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
        else:
            raise ValueError(f"Wrong `target_dtype` ({target_dtype}) provided.")

    def check_if_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
@@ -204,6 +204,16 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
            module._parameters[tensor_name] = new_value

    def check_quantized_param_shape(self, param_name, current_param_shape, loaded_param_shape):
        n = current_param_shape.numel()
        inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
        if loaded_param_shape != inferred_shape:
            raise ValueError(
                f"Expected the flattened shape of the current param ({param_name}) to be {inferred_shape}, but is {loaded_param_shape}."
            )
        else:
            return True
    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        # need more space for buffers that are created during quantization
        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
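For intuition about `check_quantized_param_shape` above, a small worked example (editor's illustration; the weight dimensions are made up):

```py
import torch

# bitsandbytes packs two 4-bit values per uint8 byte, so a parameter with n
# elements is serialized as a ((n + 1) // 2, 1) tensor.
current_param_shape = torch.Size([64, 32])
n = current_param_shape.numel()      # 2048
inferred_shape = ((n + 1) // 2, 1)   # (1024, 1)
print(n, inferred_shape)
```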
@@ -330,7 +340,6 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
        if self.quantization_config.llm_int8_skip_modules is not None:
            self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules

    def validate_environment(self, *args, **kwargs):
        if not torch.cuda.is_available():
            raise RuntimeError("No GPU found. A GPU is needed for quantization.")
@@ -404,7 +413,7 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
        logger.info(f"target_dtype {target_dtype} is replaced by `torch.int8` for 8-bit BnB quantization")
        return torch.int8

    def check_if_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import os
import tempfile
import unittest

import numpy as np
import safetensors.torch

from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
from diffusers.utils import logging

@@ -118,6 +120,9 @@ class Base4bitTests(unittest.TestCase):
class BnB4BitBasicTests(Base4bitTests):
    def setUp(self):
        gc.collect()
        torch.cuda.empty_cache()

        # Models
        self.model_fp16 = SD3Transformer2DModel.from_pretrained(
            self.model_name, subfolder="transformer", torch_dtype=torch.float16
@@ -232,7 +237,7 @@ class BnB4BitBasicTests(Base4bitTests):
    def test_config_from_pretrained(self):
        transformer_4bit = FluxTransformer2DModel.from_pretrained(
            "hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
        )
        linear = get_some_linear_layer(transformer_4bit)
        self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
@@ -312,9 +317,42 @@ class BnB4BitBasicTests(Base4bitTests):
        with self.assertRaises(ValueError):
            _ = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_storage="add")
    def test_bnb_4bit_errors_loading_incorrect_state_dict(self):
        r"""
        Test if loading with an incorrect state dict raises an error.
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            nf4_config = BitsAndBytesConfig(load_in_4bit=True)
            model_4bit = SD3Transformer2DModel.from_pretrained(
                self.model_name, subfolder="transformer", quantization_config=nf4_config
            )
            model_4bit.save_pretrained(tmpdirname)
            del model_4bit

            with self.assertRaises(ValueError) as err_context:
                state_dict = safetensors.torch.load_file(
                    os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors")
                )

                # corrupt the state dict
                key_to_target = "context_embedder.weight"  # can be other keys too.
                compatible_param = state_dict[key_to_target]
                corrupted_param = torch.randn(compatible_param.shape[0] - 1, 1)
                state_dict[key_to_target] = bnb.nn.Params4bit(corrupted_param, requires_grad=False)
                safetensors.torch.save_file(
                    state_dict, os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors")
                )

                _ = SD3Transformer2DModel.from_pretrained(tmpdirname)

        assert key_to_target in str(err_context.exception)

class BnB4BitTrainingTests(Base4bitTests):
    def setUp(self):
        gc.collect()
        torch.cuda.empty_cache()

        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
@@ -360,6 +398,9 @@ class BnB4BitTrainingTests(Base4bitTests):
@require_transformers_version_greater("4.44.0")
class SlowBnb4BitTests(Base4bitTests):
    def setUp(self) -> None:
        gc.collect()
        torch.cuda.empty_cache()

        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
@@ -447,8 +488,10 @@ class SlowBnb4BitTests(Base4bitTests):
@require_transformers_version_greater("4.44.0")
class SlowBnb4BitFluxTests(Base4bitTests):
    def setUp(self) -> None:
        gc.collect()
        torch.cuda.empty_cache()

        model_id = "hf-internal-testing/flux.1-dev-nf4-pkg"
        t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
        transformer_4bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
        self.pipeline_4bit = DiffusionPipeline.from_pretrained(
@@ -117,6 +117,9 @@ class Base8bitTests(unittest.TestCase):
class BnB8bitBasicTests(Base8bitTests):
    def setUp(self):
        gc.collect()
        torch.cuda.empty_cache()

        # Models
        self.model_fp16 = SD3Transformer2DModel.from_pretrained(
            self.model_name, subfolder="transformer", torch_dtype=torch.float16
@@ -238,7 +241,7 @@ class BnB8bitBasicTests(Base8bitTests):
    def test_config_from_pretrained(self):
        transformer_8bit = FluxTransformer2DModel.from_pretrained(
            "hf-internal-testing/flux.1-dev-int8-pkg", subfolder="transformer"
        )
        linear = get_some_linear_layer(transformer_8bit)
        self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
@@ -296,6 +299,9 @@ class BnB8bitBasicTests(Base8bitTests):
class BnB8bitTrainingTests(Base8bitTests):
    def setUp(self):
        gc.collect()
        torch.cuda.empty_cache()

        mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
        self.model_8bit = SD3Transformer2DModel.from_pretrained(
            self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
@@ -337,6 +343,9 @@ class BnB8bitTrainingTests(Base8bitTests):
@require_transformers_version_greater("4.44.0")
class SlowBnb8bitTests(Base8bitTests):
    def setUp(self) -> None:
        gc.collect()
        torch.cuda.empty_cache()

        mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
        model_8bit = SD3Transformer2DModel.from_pretrained(
            self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
@@ -427,8 +436,10 @@ class SlowBnb8bitTests(Base8bitTests):
@require_transformers_version_greater("4.44.0")
class SlowBnb8bitFluxTests(Base8bitTests):
    def setUp(self) -> None:
        gc.collect()
        torch.cuda.empty_cache()

        model_id = "hf-internal-testing/flux.1-dev-int8-pkg"
        t5_8bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
        transformer_8bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
        self.pipeline_8bit = DiffusionPipeline.from_pretrained(
@@ -466,6 +477,9 @@ class SlowBnb8bitFluxTests(Base8bitTests):
@slow
class BaseBnb8bitSerializationTests(Base8bitTests):
    def setUp(self):
        gc.collect()
        torch.cuda.empty_cache()

        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
        )