Unverified Commit beacaa55 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

[core] Layerwise Upcasting (#10347)



* update

* update

* make style

* remove dynamo disable

* add coauthor
Co-Authored-By: default avatarDhruv Nair <dhruv.nair@gmail.com>

* update

* update

* update

* update mixin

* add some basic tests

* update

* update

* non_blocking

* improvements

* update

* norm.* -> norm

* apply suggestions from review

* add example

* update hook implementation to the latest changes from pyramid attention broadcast

* deinitialize should raise an error

* update doc page

* Apply suggestions from code review
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* update docs

* update

* refactor

* fix _always_upcast_modules for asym ae and vq_model

* fix lumina embedding forward to not depend on weight dtype

* refactor tests

* add simple lora inference tests

* _always_upcast_modules -> _precision_sensitive_module_patterns

* remove todo comments about review; revert changes to self.dtype in unets because .dtype on ModelMixin should be able to handle fp8 weight case

* check layer dtypes in lora test

* fix UNet1DModelTests::test_layerwise_upcasting_inference

* _precision_sensitive_module_patterns -> _skip_layerwise_casting_patterns based on feedback

* skip test in NCSNppModelTests

* skip tests for AutoencoderTinyTests

* skip tests for AutoencoderOobleckTests

* skip tests for UNet1DModelTests - unsupported pytorch operations

* layerwise_upcasting -> layerwise_casting

* skip tests for UNetRLModelTests; needs next pytorch release for currently unimplemented operation support

* add layerwise fp8 pipeline test

* use xfail

* Apply suggestions from code review
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>

* add assertion with fp32 comparison; add tolerance to fp8-fp32 vs fp32-fp32 comparison (required for a few models' test to pass)

* add note about memory consumption on tesla CI runner for failing test

---------
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>
parent a6476822
...@@ -41,3 +41,7 @@ Utility and helper functions for working with 🤗 Diffusers. ...@@ -41,3 +41,7 @@ Utility and helper functions for working with 🤗 Diffusers.
## randn_tensor ## randn_tensor
[[autodoc]] utils.torch_utils.randn_tensor [[autodoc]] utils.torch_utils.randn_tensor
## apply_layerwise_casting
[[autodoc]] hooks.layerwise_casting.apply_layerwise_casting
...@@ -158,6 +158,43 @@ In order to properly offload models after they're called, it is required to run ...@@ -158,6 +158,43 @@ In order to properly offload models after they're called, it is required to run
</Tip> </Tip>
## FP8 layerwise weight-casting
PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
Typically, inference on most models is done with `torch.float16` or `torch.bfloat16` weight/computation precision. Layerwise weight-casting cuts down the memory footprint of the model weights by approximately half.
```python
import torch
from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel
from diffusers.utils import export_to_video
model_id = "THUDM/CogVideoX-5b"
# Load the model in bfloat16 and enable layerwise casting
transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
# Load the pipeline
pipe = CogVideoXPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = (
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
"atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```
In the above example, layerwise casting is enabled on the transformer component of the pipeline. By default, certain layers are skipped from the FP8 weight casting because it can lead to significant degradation of generation quality. The normalization and modulation related weight parameters are also skipped by default.
However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_casting.apply_layerwise_casting`] function instead of [`~ModelMixin.enable_layerwise_casting`].
## Channels-last memory format ## Channels-last memory format
The channels-last memory format is an alternative way of ordering NCHW tensors in memory to preserve dimension ordering. Channels-last tensors are ordered in such a way that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worst performance but you should still try and see if it works for your model. The channels-last memory format is an alternative way of ordering NCHW tensors in memory to preserve dimension ordering. Channels-last tensors are ordered in such a way that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worst performance but you should still try and see if it works for your model.
......
from ..utils import is_torch_available
if is_torch_available():
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import Any, Dict, Optional, Tuple
import torch
from ..utils.logging import get_logger
logger = get_logger(__name__) # pylint: disable=invalid-name
class ModelHook:
r"""
A hook that contains callbacks to be executed just before and after the forward method of a model.
"""
_is_stateful = False
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
Hook that is executed when a model is initialized.
Args:
module (`torch.nn.Module`):
The module attached to this hook.
"""
return module
def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
Hook that is executed when a model is deinitalized.
Args:
module (`torch.nn.Module`):
The module attached to this hook.
"""
module.forward = module._old_forward
del module._old_forward
return module
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
r"""
Hook that is executed just before the forward method of the model.
Args:
module (`torch.nn.Module`):
The module whose forward pass will be executed just after this event.
args (`Tuple[Any]`):
The positional arguments passed to the module.
kwargs (`Dict[Str, Any]`):
The keyword arguments passed to the module.
Returns:
`Tuple[Tuple[Any], Dict[Str, Any]]`:
A tuple with the treated `args` and `kwargs`.
"""
return args, kwargs
def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
r"""
Hook that is executed just after the forward method of the model.
Args:
module (`torch.nn.Module`):
The module whose forward pass been executed just before this event.
output (`Any`):
The output of the module.
Returns:
`Any`: The processed `output`.
"""
return output
def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
Hook that is executed when the hook is detached from a module.
Args:
module (`torch.nn.Module`):
The module detached from this hook.
"""
return module
def reset_state(self, module: torch.nn.Module):
if self._is_stateful:
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
return module
class HookRegistry:
def __init__(self, module_ref: torch.nn.Module) -> None:
super().__init__()
self.hooks: Dict[str, ModelHook] = {}
self._module_ref = module_ref
self._hook_order = []
def register_hook(self, hook: ModelHook, name: str) -> None:
if name in self.hooks.keys():
logger.warning(f"Hook with name {name} already exists, replacing it.")
if hasattr(self._module_ref, "_old_forward"):
old_forward = self._module_ref._old_forward
else:
old_forward = self._module_ref.forward
self._module_ref._old_forward = self._module_ref.forward
self._module_ref = hook.initialize_hook(self._module_ref)
if hasattr(hook, "new_forward"):
rewritten_forward = hook.new_forward
def new_forward(module, *args, **kwargs):
args, kwargs = hook.pre_forward(module, *args, **kwargs)
output = rewritten_forward(module, *args, **kwargs)
return hook.post_forward(module, output)
else:
def new_forward(module, *args, **kwargs):
args, kwargs = hook.pre_forward(module, *args, **kwargs)
output = old_forward(*args, **kwargs)
return hook.post_forward(module, output)
self._module_ref.forward = functools.update_wrapper(
functools.partial(new_forward, self._module_ref), old_forward
)
self.hooks[name] = hook
self._hook_order.append(name)
def get_hook(self, name: str) -> Optional[ModelHook]:
if name not in self.hooks.keys():
return None
return self.hooks[name]
def remove_hook(self, name: str, recurse: bool = True) -> None:
if name in self.hooks.keys():
hook = self.hooks[name]
self._module_ref = hook.deinitalize_hook(self._module_ref)
del self.hooks[name]
self._hook_order.remove(name)
if recurse:
for module_name, module in self._module_ref.named_modules():
if module_name == "":
continue
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook.remove_hook(name, recurse=False)
def reset_stateful_hooks(self, recurse: bool = True) -> None:
for hook_name in self._hook_order:
hook = self.hooks[hook_name]
if hook._is_stateful:
hook.reset_state(self._module_ref)
if recurse:
for module_name, module in self._module_ref.named_modules():
if module_name == "":
continue
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook.reset_stateful_hooks(recurse=False)
@classmethod
def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
if not hasattr(module, "_diffusers_hook"):
module._diffusers_hook = cls(module)
return module._diffusers_hook
def __repr__(self) -> str:
hook_repr = ""
for i, hook_name in enumerate(self._hook_order):
hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})"
if i < len(self._hook_order) - 1:
hook_repr += "\n"
return f"HookRegistry(\n{hook_repr}\n)"
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Optional, Tuple, Type, Union
import torch
from ..utils import get_logger
from .hooks import HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
# fmt: off
SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
torch.nn.Linear,
)
DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
# fmt: on
class LayerwiseCastingHook(ModelHook):
r"""
A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype
for storage. This process may lead to quality loss in the output, but can significantly reduce the memory
footprint.
"""
_is_stateful = False
def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
self.storage_dtype = storage_dtype
self.compute_dtype = compute_dtype
self.non_blocking = non_blocking
def initialize_hook(self, module: torch.nn.Module):
module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
return module
def deinitalize_hook(self, module: torch.nn.Module):
raise NotImplementedError(
"LayerwiseCastingHook does not support deinitalization. A model once enabled with layerwise casting will "
"have casted its weights to a lower precision dtype for storage. Casting this back to the original dtype "
"will lead to precision loss, which might have an impact on the model's generation quality. The model should "
"be re-initialized and loaded in the original dtype."
)
def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
return args, kwargs
def post_forward(self, module: torch.nn.Module, output):
module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
return output
def apply_layerwise_casting(
module: torch.nn.Module,
storage_dtype: torch.dtype,
compute_dtype: torch.dtype,
skip_modules_pattern: Union[str, Tuple[str, ...]] = "auto",
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
non_blocking: bool = False,
) -> None:
r"""
Applies layerwise casting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
nn.Module using diffusers layers or pytorch primitives.
Example:
```python
>>> import torch
>>> from diffusers import CogVideoXTransformer3DModel
>>> transformer = CogVideoXTransformer3DModel.from_pretrained(
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
... )
>>> apply_layerwise_casting(
... transformer,
... storage_dtype=torch.float8_e4m3fn,
... compute_dtype=torch.bfloat16,
... skip_modules_pattern=["patch_embed", "norm", "proj_out"],
... non_blocking=True,
... )
```
Args:
module (`torch.nn.Module`):
The module whose leaf modules will be cast to a high precision dtype for computation, and to a low
precision dtype for storage.
storage_dtype (`torch.dtype`):
The dtype to cast the module to before/after the forward pass for storage.
compute_dtype (`torch.dtype`):
The dtype to cast the module to during the forward pass for computation.
skip_modules_pattern (`Tuple[str, ...]`, defaults to `"auto"`):
A list of patterns to match the names of the modules to skip during the layerwise casting process. If set
to `"auto"`, the default patterns are used. If set to `None`, no modules are skipped. If set to `None`
alongside `skip_modules_classes` being `None`, the layerwise casting is applied directly to the module
instead of its internal submodules.
skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
A list of module classes to skip during the layerwise casting process.
non_blocking (`bool`, defaults to `False`):
If `True`, the weight casting operations are non-blocking.
"""
if skip_modules_pattern == "auto":
skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN
if skip_modules_classes is None and skip_modules_pattern is None:
apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
return
_apply_layerwise_casting(
module,
storage_dtype,
compute_dtype,
skip_modules_pattern,
skip_modules_classes,
non_blocking,
)
def _apply_layerwise_casting(
module: torch.nn.Module,
storage_dtype: torch.dtype,
compute_dtype: torch.dtype,
skip_modules_pattern: Optional[Tuple[str, ...]] = None,
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
non_blocking: bool = False,
_prefix: str = "",
) -> None:
should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
)
if should_skip:
logger.debug(f'Skipping layerwise casting for layer "{_prefix}"')
return
if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
logger.debug(f'Applying layerwise casting to layer "{_prefix}"')
apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
return
for name, submodule in module.named_children():
layer_name = f"{_prefix}.{name}" if _prefix else name
_apply_layerwise_casting(
submodule,
storage_dtype,
compute_dtype,
skip_modules_pattern,
skip_modules_classes,
non_blocking,
_prefix=layer_name,
)
def apply_layerwise_casting_hook(
module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool
) -> None:
r"""
Applies a `LayerwiseCastingHook` to a given module.
Args:
module (`torch.nn.Module`):
The module to attach the hook to.
storage_dtype (`torch.dtype`):
The dtype to cast the module to before the forward pass.
compute_dtype (`torch.dtype`):
The dtype to cast the module to during the forward pass.
non_blocking (`bool`):
If `True`, the weight casting operations are non-blocking.
"""
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = LayerwiseCastingHook(storage_dtype, compute_dtype, non_blocking)
registry.register_hook(hook, "layerwise_casting")
...@@ -60,6 +60,8 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): ...@@ -60,6 +60,8 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
""" """
_skip_layerwise_casting_patterns = ["decoder"]
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
......
...@@ -71,6 +71,8 @@ class VQModel(ModelMixin, ConfigMixin): ...@@ -71,6 +71,8 @@ class VQModel(ModelMixin, ConfigMixin):
Type of normalization layer to use. Can be one of `"group"` or `"spatial"`. Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
""" """
_skip_layerwise_casting_patterns = ["quantize"]
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
......
...@@ -1787,7 +1787,7 @@ class LuminaCombinedTimestepCaptionEmbedding(nn.Module): ...@@ -1787,7 +1787,7 @@ class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
def forward(self, timestep, caption_feat, caption_mask): def forward(self, timestep, caption_feat, caption_mask):
# timestep embedding: # timestep embedding:
time_freq = self.time_proj(timestep) time_freq = self.time_proj(timestep)
time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype)) time_embed = self.timestep_embedder(time_freq.to(dtype=caption_feat.dtype))
# caption condition embedding: # caption condition embedding:
caption_mask_float = caption_mask.float().unsqueeze(-1) caption_mask_float = caption_mask.float().unsqueeze(-1)
......
...@@ -23,7 +23,7 @@ import re ...@@ -23,7 +23,7 @@ import re
from collections import OrderedDict from collections import OrderedDict
from functools import partial, wraps from functools import partial, wraps
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import safetensors import safetensors
import torch import torch
...@@ -32,6 +32,7 @@ from huggingface_hub.utils import validate_hf_hub_args ...@@ -32,6 +32,7 @@ from huggingface_hub.utils import validate_hf_hub_args
from torch import Tensor, nn from torch import Tensor, nn
from .. import __version__ from .. import __version__
from ..hooks import apply_layerwise_casting
from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
from ..quantizers.quantization_config import QuantizationMethod from ..quantizers.quantization_config import QuantizationMethod
from ..utils import ( from ..utils import (
...@@ -48,6 +49,7 @@ from ..utils import ( ...@@ -48,6 +49,7 @@ from ..utils import (
is_accelerate_available, is_accelerate_available,
is_bitsandbytes_available, is_bitsandbytes_available,
is_bitsandbytes_version, is_bitsandbytes_version,
is_peft_available,
is_torch_version, is_torch_version,
logging, logging,
) )
...@@ -102,6 +104,17 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype: ...@@ -102,6 +104,17 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
""" """
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found. Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
""" """
# 1. Check if we have attached any dtype modifying hooks (eg. layerwise casting)
if isinstance(parameter, nn.Module):
for name, submodule in parameter.named_modules():
if not hasattr(submodule, "_diffusers_hook"):
continue
registry = submodule._diffusers_hook
hook = registry.get_hook("layerwise_casting")
if hook is not None:
return hook.compute_dtype
# 2. If no dtype modifying hooks are attached, return the dtype of the first floating point parameter/buffer
last_dtype = None last_dtype = None
for param in parameter.parameters(): for param in parameter.parameters():
last_dtype = param.dtype last_dtype = param.dtype
...@@ -150,6 +163,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -150,6 +163,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_keys_to_ignore_on_load_unexpected = None _keys_to_ignore_on_load_unexpected = None
_no_split_modules = None _no_split_modules = None
_keep_in_fp32_modules = None _keep_in_fp32_modules = None
_skip_layerwise_casting_patterns = None
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -314,6 +328,90 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -314,6 +328,90 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
""" """
self.set_use_memory_efficient_attention_xformers(False) self.set_use_memory_efficient_attention_xformers(False)
def enable_layerwise_casting(
self,
storage_dtype: torch.dtype = torch.float8_e4m3fn,
compute_dtype: Optional[torch.dtype] = None,
skip_modules_pattern: Optional[Tuple[str, ...]] = None,
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
non_blocking: bool = False,
) -> None:
r"""
Activates layerwise casting for the current model.
Layerwise casting is a technique that casts the model weights to a lower precision dtype for storage but
upcasts them on-the-fly to a higher precision dtype for computation. This process can significantly reduce the
memory footprint from model weights, but may lead to some quality degradation in the outputs. Most degradations
are negligible, mostly stemming from weight casting in normalization and modulation layers.
By default, most models in diffusers set the `_skip_layerwise_casting_patterns` attribute to ignore patch
embedding, positional embedding and normalization layers. This is because these layers are most likely
precision-critical for quality. If you wish to change this behavior, you can set the
`_skip_layerwise_casting_patterns` attribute to `None`, or call
[`~hooks.layerwise_casting.apply_layerwise_casting`] with custom arguments.
Example:
Using [`~models.ModelMixin.enable_layerwise_casting`]:
```python
>>> from diffusers import CogVideoXTransformer3DModel
>>> transformer = CogVideoXTransformer3DModel.from_pretrained(
... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
... )
>>> # Enable layerwise casting via the model, which ignores certain modules by default
>>> transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
```
Args:
storage_dtype (`torch.dtype`):
The dtype to which the model should be cast for storage.
compute_dtype (`torch.dtype`):
The dtype to which the model weights should be cast during the forward pass.
skip_modules_pattern (`Tuple[str, ...]`, *optional*):
A list of patterns to match the names of the modules to skip during the layerwise casting process. If
set to `None`, default skip patterns are used to ignore certain internal layers of modules and PEFT
layers.
skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, *optional*):
A list of module classes to skip during the layerwise casting process.
non_blocking (`bool`, *optional*, defaults to `False`):
If `True`, the weight casting operations are non-blocking.
"""
user_provided_patterns = True
if skip_modules_pattern is None:
from ..hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN
skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN
user_provided_patterns = False
if self._keep_in_fp32_modules is not None:
skip_modules_pattern += tuple(self._keep_in_fp32_modules)
if self._skip_layerwise_casting_patterns is not None:
skip_modules_pattern += tuple(self._skip_layerwise_casting_patterns)
skip_modules_pattern = tuple(set(skip_modules_pattern))
if is_peft_available() and not user_provided_patterns:
# By default, we want to skip all peft layers because they have a very low memory footprint.
# If users want to apply layerwise casting on peft layers as well, they can utilize the
# `~diffusers.hooks.layerwise_casting.apply_layerwise_casting` function which provides
# them with more flexibility and control.
from peft.tuners.loha.layer import LoHaLayer
from peft.tuners.lokr.layer import LoKrLayer
from peft.tuners.lora.layer import LoraLayer
for layer in (LoHaLayer, LoKrLayer, LoraLayer):
skip_modules_pattern += tuple(layer.adapter_layer_names)
if compute_dtype is None:
logger.info("`compute_dtype` not provided when enabling layerwise casting. Using dtype of the model.")
compute_dtype = self.dtype
apply_layerwise_casting(
self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking
)
def save_pretrained( def save_pretrained(
self, self,
save_directory: Union[str, os.PathLike], save_directory: Union[str, os.PathLike],
......
...@@ -276,6 +276,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin ...@@ -276,6 +276,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
""" """
_no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"] _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_supports_gradient_checkpointing = True _supports_gradient_checkpointing = True
@register_to_config @register_to_config
......
...@@ -212,6 +212,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ...@@ -212,6 +212,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
Scaling factor to apply in 3D positional embeddings across temporal dimensions. Scaling factor to apply in 3D positional embeddings across temporal dimensions.
""" """
_skip_layerwise_casting_patterns = ["patch_embed", "norm"]
_supports_gradient_checkpointing = True _supports_gradient_checkpointing = True
_no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"] _no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"]
......
...@@ -64,6 +64,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin): ...@@ -64,6 +64,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
A small constant added to the denominator in normalization layers to prevent division by zero. A small constant added to the denominator in normalization layers to prevent division by zero.
""" """
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_supports_gradient_checkpointing = True _supports_gradient_checkpointing = True
@register_to_config @register_to_config
......
...@@ -244,6 +244,8 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin): ...@@ -244,6 +244,8 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2 Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
""" """
_skip_layerwise_casting_patterns = ["pos_embed", "norm", "pooler"]
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
......
...@@ -65,6 +65,8 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin): ...@@ -65,6 +65,8 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
The number of frames in the video-like data. The number of frames in the video-like data.
""" """
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
......
...@@ -221,6 +221,8 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin): ...@@ -221,6 +221,8 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
overall scale of the model's operations. overall scale of the model's operations.
""" """
_skip_layerwise_casting_patterns = ["patch_embedder", "norm", "ffn_norm"]
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
......
...@@ -79,6 +79,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin): ...@@ -79,6 +79,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True _supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "PatchEmbed"] _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"]
@register_to_config @register_to_config
def __init__( def __init__(
......
...@@ -236,6 +236,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ...@@ -236,6 +236,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
_supports_gradient_checkpointing = True _supports_gradient_checkpointing = True
_no_split_modules = ["SanaTransformerBlock", "PatchEmbed", "SanaModulatedNorm"] _no_split_modules = ["SanaTransformerBlock", "PatchEmbed", "SanaModulatedNorm"]
_skip_layerwise_casting_patterns = ["patch_embed", "norm"]
@register_to_config @register_to_config
def __init__( def __init__(
......
...@@ -211,6 +211,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin): ...@@ -211,6 +211,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
""" """
_supports_gradient_checkpointing = True _supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["preprocess_conv", "postprocess_conv", "^proj_in$", "^proj_out$", "norm"]
@register_to_config @register_to_config
def __init__( def __init__(
......
...@@ -66,6 +66,7 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin): ...@@ -66,6 +66,7 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
_supports_gradient_checkpointing = True _supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock"] _no_split_modules = ["BasicTransformerBlock"]
_skip_layerwise_casting_patterns = ["latent_image_embedding", "norm"]
@register_to_config @register_to_config
def __init__( def __init__(
......
...@@ -222,6 +222,7 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin): ...@@ -222,6 +222,7 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
""" """
_supports_gradient_checkpointing = True _supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"]
@register_to_config @register_to_config
def __init__( def __init__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment