Unverified commit cd991d1e, authored by Aryan and committed by GitHub

Fix TorchAO related bugs; revert device_map changes (#10371)

* Revert "Add support for sharded models when TorchAO quantization is enabled (#10256)"

This reverts commit 41ba8c0b.

* update tests

* update

* update

* update

* update device map tests

* apply review suggestions

* update

* make style

* fix

* update docs

* update tests

* update workflow

* update

* improve tests

* allclose tolerance

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update tests/quantization/torchao/test_torchao.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* improve tests

* fix

* update correct slices

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 825979dd
@@ -359,6 +359,8 @@ jobs:
          test_location: "bnb"
        - backend: "gguf"
          test_location: "gguf"
+       - backend: "torchao"
+         test_location: "torchao"
    runs-on:
      group: aws-g6e-xlarge-plus
    container:
...
@@ -25,6 +25,7 @@ Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]
The example below only quantizes the weights to int8.
```python
+import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
model_id = "black-forest-labs/FLUX.1-dev"
@@ -44,6 +45,10 @@ pipe = FluxPipeline.from_pretrained(
)
pipe.to("cuda")
+
+# Without quantization: ~31.447 GB
+# With quantization: ~20.40 GB
+print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
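To reproduce the memory numbers printed in this hunk in isolation, a minimal sketch follows; the `reset_peak_memory_stats` and `empty_cache` calls are assumptions and are not part of the diff:

```python
# Sketch: isolate the peak-memory measurement used in the snippet above.
# Resetting the peak stats first ensures max_memory_reserved() only reflects
# the quantized pipeline loaded afterwards (assumption; not shown in the diff).
import torch

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

# ... load the quantized FluxPipeline here, as in the surrounding example ...

print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
```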
@@ -88,6 +93,63 @@ Some quantization methods are aliases (for example, `int8wo` is the commonly use
Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
## Serializing and Deserializing quantized models
To serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method.
```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
transformer.save_pretrained("/path/to/flux_int8wo", safe_serialization=False)
```
To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method.
```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
image.save("output.png")
```
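As a quick sanity check (not part of the documentation), you can verify that the reloaded weights are still torchao tensor subclasses rather than plain tensors; the traversal below is a sketch and the exact class you see depends on the torchao version:

```python
# Sketch: print the tensor class of the first weight to confirm the reloaded
# model still carries quantized (torchao tensor-subclass) weights.
for name, param in transformer.named_parameters():
    if name.endswith(".weight"):
        print(name, type(param.data))  # expected: a torchao quantized tensor subclass
        break
```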
Some quantization methods, such as `uint4wo`, save as expected but cannot be loaded directly and may raise an `UnpicklingError`. To work around this, load the state dict manually into the model. Note, however, that this requires passing `weights_only=False` to `torch.load`, so it should only be done if the weights were obtained from a trusted source.
```python
import torch
from accelerate import init_empty_weights
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
# Serialize the model
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=TorchAoConfig("uint4wo"),
torch_dtype=torch.bfloat16,
)
transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")
# ...
# Load the model
state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu")
with init_empty_weights():
transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json")
transformer.load_state_dict(state_dict, strict=True, assign=True)
```
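A hedged continuation of the snippet above (not part of the diff) showing the manually loaded transformer plugged into a pipeline; the dtype and device choices mirror the earlier examples:

```python
# Sketch: use the manually loaded uint4wo transformer with FluxPipeline,
# following the same pattern as the int8wo example earlier in this document.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/Flux.1-Dev",
    transformer=transformer,  # loaded via load_state_dict(..., assign=True) above
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

image = pipe("A cat holding a sign that says hello world", num_inference_steps=30).images[0]
image.save("output.png")
```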
## Resources
- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
...
@@ -718,10 +718,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            hf_quantizer = None

        if hf_quantizer is not None:
-           is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
-           if is_bnb_quantization_method and device_map is not None:
+           if device_map is not None:
                raise NotImplementedError(
-                   "Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
+                   "Currently, providing `device_map` is not supported for quantized models. Providing `device_map` as an input will be added in the future."
                )

            hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)

@@ -820,7 +819,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                revision=revision,
                subfolder=subfolder or "",
            )
-       if hf_quantizer is not None and is_bnb_quantization_method:
+       # TODO: https://github.com/huggingface/diffusers/issues/10013
+       if hf_quantizer is not None:
            model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
            logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
            is_sharded = False
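The user-facing effect of this revert, sketched below; the model id and config are illustrative assumptions, and the relevant part is combining `device_map` with a `quantization_config`:

```python
# Sketch: after this change, passing `device_map` together with any
# quantization_config raises NotImplementedError (the previous check applied
# only to bitsandbytes).
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

try:
    FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        quantization_config=TorchAoConfig("int8wo"),
        torch_dtype=torch.bfloat16,
        device_map="auto",  # not supported for quantized models
    )
except NotImplementedError as err:
    print(err)
```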
...
@@ -132,7 +132,7 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
    def update_torch_dtype(self, torch_dtype):
        quant_type = self.quantization_config.quant_type

-       if quant_type.startswith("int"):
+       if quant_type.startswith("int") or quant_type.startswith("uint"):
            if torch_dtype is not None and torch_dtype != torch.bfloat16:
                logger.warning(
                    f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
...