Unverified Commit e8da75df authored by Sayak Paul, committed by GitHub

[bitsandbytes] allow direct CUDA placement of pipelines loaded with bnb components (#9840)



* allow device placement when using bnb quantization.

* warning.

* tests

* fixes

* docs.

* require accelerate version.

* remove print.

* revert to()

* tests

* fixes

* fix: missing AutoencoderKL lora adapter (#9807)

* fix: missing AutoencoderKL lora adapter

* fix

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* fixes

* fix condition test

* updates

* updates

* remove is_offloaded.

* fixes

* better

* empty

---------
Co-authored-by: Emmanuel Benazera <emmanuel.benazera@jolibrain.com>
parent 8a450c3d
@@ -66,7 +66,6 @@ from ..utils.torch_utils import is_compiled_module
 if is_torch_npu_available():
     import torch_npu  # noqa: F401

 from .pipeline_loading_utils import (
     ALL_IMPORTABLE_CLASSES,
     CONNECTED_PIPES_KEYS,
@@ -388,6 +387,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         )

         device = device or device_arg
+        pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())

         # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
         def module_is_sequentially_offloaded(module):
@@ -410,10 +410,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         pipeline_is_sequentially_offloaded = any(
             module_is_sequentially_offloaded(module) for _, module in self.components.items()
         )
-        if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":
-            raise ValueError(
-                "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
-            )
+        if device and torch.device(device).type == "cuda":
+            if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
+                raise ValueError(
+                    "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
+                )
+            # PR: https://github.com/huggingface/accelerate/pull/3223/
+            elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
+                raise ValueError(
+                    "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
+                )

         is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
         if is_pipeline_device_mapped:
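In practice, this change lets a pipeline whose components were loaded with `bitsandbytes` quantization be moved to CUDA as a whole, as long as the installed `accelerate` includes the fix from https://github.com/huggingface/accelerate/pull/3223; older versions hit the new `ValueError` above. A minimal usage sketch mirroring the tests added below (the checkpoint id is a placeholder assumption, and only the transformer is quantized here):

```python
import torch
from diffusers import BitsAndBytesConfig, DiffusionPipeline, SD3Transformer2DModel

# Placeholder checkpoint id (assumption); the tests below read it from `self.model_name`.
model_id = "stabilityai/stable-diffusion-3-medium-diffusers"

# Quantize only the transformer to 4-bit NF4.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
transformer_4bit = SD3Transformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.float16,
)

# With accelerate >= 1.1.0.dev0, `.to("cuda")` is allowed even though a component
# is bnb-quantized; older accelerate versions raise the ValueError added above.
pipe = DiffusionPipeline.from_pretrained(
    model_id,
    transformer=transformer_4bit,
    torch_dtype=torch.float16,
).to("cuda")

image = pipe("a table", max_sequence_length=20, num_inference_steps=2).images[0]
```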
@@ -18,10 +18,11 @@ import tempfile
 import unittest

 import numpy as np
+import pytest
 import safetensors.torch

 from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
-from diffusers.utils import logging
+from diffusers.utils import is_accelerate_version, logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     is_bitsandbytes_available,
@@ -47,6 +48,7 @@ def get_some_linear_layer(model):
 if is_transformers_available():
+    from transformers import BitsAndBytesConfig as BnbConfig
     from transformers import T5EncoderModel

 if is_torch_available():
@@ -483,6 +485,47 @@ class SlowBnb4BitTests(Base4bitTests):
         assert "Pipelines loaded with `dtype=torch.float16`" in cap_logger.out

+    @pytest.mark.xfail(
+        condition=is_accelerate_version("<=", "1.1.1"),
+        reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
+        strict=True,
+    )
+    def test_pipeline_cuda_placement_works_with_nf4(self):
+        transformer_nf4_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
+        )
+        transformer_4bit = SD3Transformer2DModel.from_pretrained(
+            self.model_name,
+            subfolder="transformer",
+            quantization_config=transformer_nf4_config,
+            torch_dtype=torch.float16,
+        )
+        text_encoder_3_nf4_config = BnbConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
+        )
+        text_encoder_3_4bit = T5EncoderModel.from_pretrained(
+            self.model_name,
+            subfolder="text_encoder_3",
+            quantization_config=text_encoder_3_nf4_config,
+            torch_dtype=torch.float16,
+        )
+
+        # CUDA device placement works.
+        pipeline_4bit = DiffusionPipeline.from_pretrained(
+            self.model_name,
+            transformer=transformer_4bit,
+            text_encoder_3=text_encoder_3_4bit,
+            torch_dtype=torch.float16,
+        ).to("cuda")
+
+        # Check if inference works.
+        _ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)
+
+        del pipeline_4bit
+

 @require_transformers_version_greater("4.44.0")
 class SlowBnb4BitFluxTests(Base4bitTests):
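Because the xfail condition keys off the installed `accelerate`, downstream code that has to run on older releases can reuse the same version check before attempting placement. A minimal sketch; `maybe_move_to_cuda` is a hypothetical helper, not a diffusers API:

```python
from diffusers.utils import is_accelerate_version


def maybe_move_to_cuda(pipeline):
    # Hypothetical helper (assumption), not part of diffusers.
    # Direct CUDA placement of bnb-quantized pipelines needs the accelerate fix
    # from https://github.com/huggingface/accelerate/pull/3223 (>= 1.1.0.dev0).
    if is_accelerate_version("<", "1.1.0.dev0"):
        return pipeline  # keep components where they are; `.to("cuda")` would raise
    return pipeline.to("cuda")
```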
@@ -17,8 +17,10 @@ import tempfile
 import unittest

 import numpy as np
+import pytest

 from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
+from diffusers.utils import is_accelerate_version
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     is_bitsandbytes_available,
@@ -44,6 +46,7 @@ def get_some_linear_layer(model):
 if is_transformers_available():
+    from transformers import BitsAndBytesConfig as BnbConfig
     from transformers import T5EncoderModel

 if is_torch_available():
@@ -432,6 +435,39 @@ class SlowBnb8bitTests(Base8bitTests):
             output_type="np",
         ).images

+    @pytest.mark.xfail(
+        condition=is_accelerate_version("<=", "1.1.1"),
+        reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
+        strict=True,
+    )
+    def test_pipeline_cuda_placement_works_with_mixed_int8(self):
+        transformer_8bit_config = BitsAndBytesConfig(load_in_8bit=True)
+        transformer_8bit = SD3Transformer2DModel.from_pretrained(
+            self.model_name,
+            subfolder="transformer",
+            quantization_config=transformer_8bit_config,
+            torch_dtype=torch.float16,
+        )
+        text_encoder_3_8bit_config = BnbConfig(load_in_8bit=True)
+        text_encoder_3_8bit = T5EncoderModel.from_pretrained(
+            self.model_name,
+            subfolder="text_encoder_3",
+            quantization_config=text_encoder_3_8bit_config,
+            torch_dtype=torch.float16,
+        )
+
+        # CUDA device placement works.
+        pipeline_8bit = DiffusionPipeline.from_pretrained(
+            self.model_name,
+            transformer=transformer_8bit,
+            text_encoder_3=text_encoder_3_8bit,
+            torch_dtype=torch.float16,
+        ).to("cuda")
+
+        # Check if inference works.
+        _ = pipeline_8bit("table", max_sequence_length=20, num_inference_steps=2)
+
+        del pipeline_8bit
+

 @require_transformers_version_greater("4.44.0")
 class SlowBnb8bitFluxTests(Base8bitTests):