Unverified commit beacaa55, authored by Aryan and committed by GitHub

[core] Layerwise Upcasting (#10347)
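In short: layerwise casting stores a model's weights in a low-precision dtype (e.g. `torch.float8_e4m3fn`) and upcasts them to a higher-precision compute dtype on the fly inside each layer's forward pass, while precision-sensitive submodules are skipped. A minimal usage sketch based on the API exercised by the tests in this diff (the checkpoint id is illustrative, not part of the PR):

```python
import torch
from diffusers import FluxTransformer2DModel

# Illustrative checkpoint; any ModelMixin subclass touched by this PR works the same way.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Weights are stored in fp8 and upcast to bf16 just-in-time during forward.
# Modules matching _skip_layerwise_casting_patterns (here "pos_embed" and "norm") stay in bf16.
transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
```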



* update

* update

* make style

* remove dynamo disable

* add coauthor
Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

* update

* update

* update

* update mixin

* add some basic tests

* update

* update

* non_blocking

* improvements

* update

* norm.* -> norm

* apply suggestions from review

* add example

* update hook implementation to the latest changes from pyramid attention broadcast

* deinitialize should raise an error

* update doc page

* Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update docs

* update

* refactor

* fix _always_upcast_modules for asym ae and vq_model

* fix lumina embedding forward to not depend on weight dtype

* refactor tests

* add simple lora inference tests

* _always_upcast_modules -> _precision_sensitive_module_patterns

* remove todo comments about review; revert changes to self.dtype in unets because .dtype on ModelMixin should be able to handle fp8 weight case

* check layer dtypes in lora test

* fix UNet1DModelTests::test_layerwise_upcasting_inference

* _precision_sensitive_module_patterns -> _skip_layerwise_casting_patterns based on feedback

* skip test in NCSNppModelTests

* skip tests for AutoencoderTinyTests

* skip tests for AutoencoderOobleckTests

* skip tests for UNet1DModelTests - unsupported pytorch operations

* layerwise_upcasting -> layerwise_casting

* skip tests for UNetRLModelTests; needs next pytorch release for currently unimplemented operation support

* add layerwise fp8 pipeline test

* use xfail

* Apply suggestions from code review
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* add assertion with fp32 comparison; add tolerance to fp8-fp32 vs fp32-fp32 comparison (required for a few models' test to pass)

* add note about memory consumption on tesla CI runner for failing test

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent a6476822
@@ -166,6 +166,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]
    _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]

    @register_to_config
......
@@ -262,6 +262,7 @@ class FluxTransformer2DModel(
    _supports_gradient_checkpointing = True
    _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

    @register_to_config
    def __init__(
......
@@ -542,6 +542,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
    _no_split_modules = [
        "HunyuanVideoTransformerBlock",
        "HunyuanVideoSingleTransformerBlock",
......
@@ -295,6 +295,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["norm"]

    @register_to_config
    def __init__(
......
@@ -336,6 +336,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
    _supports_gradient_checkpointing = True
    _no_split_modules = ["MochiTransformerBlock"]
    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]

    @register_to_config
    def __init__(
......
@@ -127,6 +127,7 @@ class SD3Transformer2DModel(
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

    @register_to_config
    def __init__(
......
@@ -67,6 +67,8 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
            The maximum length of the sequence over which to apply positional embeddings.
    """

    _skip_layerwise_casting_patterns = ["norm"]

    @register_to_config
    def __init__(
        self,
......
@@ -71,6 +71,8 @@ class UNet1DModel(ModelMixin, ConfigMixin):
            Experimental feature for using a UNet without upsampling.
    """

    _skip_layerwise_casting_patterns = ["norm"]

    @register_to_config
    def __init__(
        self,
@@ -223,7 +225,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
        timestep_embed = self.time_proj(timesteps)
        if self.config.use_timestep_embedding:
-            timestep_embed = self.time_mlp(timestep_embed)
+            timestep_embed = self.time_mlp(timestep_embed.to(sample.dtype))
        else:
            timestep_embed = timestep_embed[..., None]
            timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
......
@@ -90,6 +90,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["norm"]

    @register_to_config
    def __init__(
......
@@ -166,6 +166,7 @@ class UNet2DConditionModel(
    _supports_gradient_checkpointing = True
    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
    _skip_layerwise_casting_patterns = ["norm"]

    @register_to_config
    def __init__(
......
@@ -97,6 +97,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
    """

    _supports_gradient_checkpointing = False
    _skip_layerwise_casting_patterns = ["norm", "time_embedding"]

    @register_to_config
    def __init__(
......
@@ -1301,6 +1301,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["norm"]

    @register_to_config
    def __init__(
......
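All of the model classes above opt precision-sensitive submodules out of the storage dtype via `_skip_layerwise_casting_patterns`. As the tests further down show, each entry is treated as a regular expression matched against fully qualified submodule names (on top of `DEFAULT_SKIP_MODULES_PATTERN`). A minimal sketch of that matching logic on a toy module (illustrative only, not the hook implementation itself):

```python
import re
import torch.nn as nn

# Toy stand-in for a diffusers model; only the pattern matching mirrors the PR's behaviour.
model = nn.ModuleDict({"norm_out": nn.LayerNorm(8), "proj_out": nn.Linear(8, 8)})
skip_patterns = ("norm",)  # e.g. a model's _skip_layerwise_casting_patterns

for name, submodule in model.named_modules():
    if not isinstance(submodule, (nn.Linear, nn.LayerNorm)):
        continue
    skipped = any(re.search(pattern, name) for pattern in skip_patterns)
    print(name, "-> kept in compute dtype" if skipped else "-> stored in storage dtype (fp8)")
```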
@@ -14,6 +14,7 @@
# limitations under the License.
import inspect
import os
import re
import tempfile
import unittest
from itertools import product
@@ -2098,3 +2099,61 @@ class PeftLoraLoaderMixinTests:
        lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
        self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))

    def test_layerwise_casting_inference_denoiser(self):
        from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS

        def check_linear_dtype(module, storage_dtype, compute_dtype):
            patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
            if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
                patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
            for name, submodule in module.named_modules():
                if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
                    continue
                dtype_to_check = storage_dtype
                if "lora" in name or any(re.search(pattern, name) for pattern in patterns_to_check):
                    dtype_to_check = compute_dtype
                if getattr(submodule, "weight", None) is not None:
                    self.assertEqual(submodule.weight.dtype, dtype_to_check)
                if getattr(submodule, "bias", None) is not None:
                    self.assertEqual(submodule.bias.dtype, dtype_to_check)

        def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device, dtype=compute_dtype)
            pipe.set_progress_bar_config(disable=None)

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            if storage_dtype is not None:
                denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
                check_linear_dtype(denoiser, storage_dtype, compute_dtype)

            return pipe

        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe_fp32 = initialize_pipeline(storage_dtype=None)
        pipe_fp32(**inputs, generator=torch.manual_seed(0))[0]

        pipe_float8_e4m3_fp32 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)
        pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0]

        pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
        pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]
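Put together, the test above corresponds to the following user-facing flow: load a pipeline, attach a LoRA adapter, then enable layerwise casting on the denoiser so that base weights live in fp8 while LoRA layers and skipped patterns stay in the compute dtype. A hedged end-to-end sketch; the model and LoRA ids are placeholders, not from this PR:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("some/model-id", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("some/lora-id")  # placeholder LoRA repo

# Base weights of the denoiser go to fp8 storage; LoRA layers are excluded
# (see the "lora" substring check in the test above).
pipe.transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
image = pipe("a prompt").images[0]
```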
@@ -114,6 +114,24 @@ class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCa
    def test_set_attn_processor_for_determinism(self):
        return

    @unittest.skip(
        "The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not "
        "cast the module weights to compute_dtype (as required by forward pass). As a result, forward pass errors out. To fix:\n"
        "1. Make sure `nn::Module::to` works with `torch.nn.utils.weight_norm` wrapped convolution layer.\n"
        "2. Unskip this test."
    )
    def test_layerwise_casting_inference(self):
        pass

    @unittest.skip(
        "The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not "
        "cast the module weights to compute_dtype (as required by forward pass). As a result, forward pass errors out. To fix:\n"
        "1. Make sure `nn::Module::to` works with `torch.nn.utils.weight_norm` wrapped convolution layer.\n"
        "2. Unskip this test."
    )
    def test_layerwise_casting_memory(self):
        pass


@slow
class AutoencoderOobleckIntegrationTests(unittest.TestCase):
......
@@ -173,6 +173,22 @@ class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase)
                continue
            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=3e-2))

    @unittest.skip(
        "The forward pass of AutoencoderTiny creates a torch.float32 tensor. This causes inference in compute_dtype=torch.bfloat16 to fail. To fix:\n"
        "1. Change the forward pass to be dtype agnostic.\n"
        "2. Unskip this test."
    )
    def test_layerwise_casting_inference(self):
        pass

    @unittest.skip(
        "The forward pass of AutoencoderTiny creates a torch.float32 tensor. This causes inference in compute_dtype=torch.bfloat16 to fail. To fix:\n"
        "1. Change the forward pass to be dtype agnostic.\n"
        "2. Unskip this test."
    )
    def test_layerwise_casting_memory(self):
        pass


@slow
class AutoencoderTinyIntegrationTests(unittest.TestCase):
......
@@ -14,9 +14,11 @@
# limitations under the License.
import copy
import gc
import inspect
import json
import os
import re
import tempfile
import traceback
import unittest
@@ -56,9 +58,11 @@ from diffusers.utils.testing_utils import (
    CaptureLogger,
    get_python_version,
    is_torch_compile,
    numpy_cosine_similarity_distance,
    require_torch_2,
    require_torch_accelerator,
    require_torch_accelerator_with_training,
    require_torch_gpu,
    require_torch_multi_gpu,
    run_test_in_subprocess,
    torch_all_close,
@@ -181,6 +185,16 @@ def compute_module_persistent_sizes(
    return module_sizes


def cast_maybe_tensor_dtype(maybe_tensor, current_dtype, target_dtype):
    # Recursively cast tensors of `current_dtype` to `target_dtype`; other values pass through unchanged.
    if torch.is_tensor(maybe_tensor):
        return maybe_tensor.to(target_dtype) if maybe_tensor.dtype == current_dtype else maybe_tensor
    if isinstance(maybe_tensor, dict):
        return {k: cast_maybe_tensor_dtype(v, current_dtype, target_dtype) for k, v in maybe_tensor.items()}
    if isinstance(maybe_tensor, list):
        return [cast_maybe_tensor_dtype(v, current_dtype, target_dtype) for v in maybe_tensor]
    return maybe_tensor


class ModelUtilsTest(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
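For clarity, a short illustration of what the `cast_maybe_tensor_dtype` helper above does with a typical dummy-input dict (illustrative values, not taken from the test suite):

```python
import torch

inputs = {"sample": torch.randn(1, 4, 8, 8), "timestep": torch.tensor([1]), "scale": 1.0}
inputs_bf16 = cast_maybe_tensor_dtype(inputs, torch.float32, torch.bfloat16)

assert inputs_bf16["sample"].dtype == torch.bfloat16  # float32 tensors are cast
assert inputs_bf16["timestep"].dtype == torch.int64   # non-float32 tensors pass through
assert inputs_bf16["scale"] == 1.0                    # non-tensors pass through unchanged
```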
@@ -1332,6 +1346,93 @@ class ModelTesterMixin:
            # Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors
            assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files)

    def test_layerwise_casting_inference(self):
        from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS

        torch.manual_seed(0)
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**config).eval()
        model = model.to(torch_device)
        base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy()

        def check_linear_dtype(module, storage_dtype, compute_dtype):
            patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
            if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
                patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
            for name, submodule in module.named_modules():
                if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
                    continue
                dtype_to_check = storage_dtype
                if any(re.search(pattern, name) for pattern in patterns_to_check):
                    dtype_to_check = compute_dtype
                if getattr(submodule, "weight", None) is not None:
                    self.assertEqual(submodule.weight.dtype, dtype_to_check)
                if getattr(submodule, "bias", None) is not None:
                    self.assertEqual(submodule.bias.dtype, dtype_to_check)

        def test_layerwise_casting(storage_dtype, compute_dtype):
            torch.manual_seed(0)
            config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
            inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
            model = self.model_class(**config).eval()
            model = model.to(torch_device, dtype=compute_dtype)
            model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
            check_linear_dtype(model, storage_dtype, compute_dtype)
            output = model(**inputs_dict)[0].float().flatten().detach().cpu().numpy()

            # The precision test is not very important for fast tests. In most cases, the outputs will not be the same.
            # We just want to make sure that the layerwise casting is working as expected.
            self.assertTrue(numpy_cosine_similarity_distance(base_slice, output) < 1.0)

        test_layerwise_casting(torch.float16, torch.float32)
        test_layerwise_casting(torch.float8_e4m3fn, torch.float32)
        test_layerwise_casting(torch.float8_e5m2, torch.float32)
        test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16)
    @require_torch_gpu
    def test_layerwise_casting_memory(self):
        MB_TOLERANCE = 0.2

        def reset_memory_stats():
            gc.collect()
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

        def get_memory_usage(storage_dtype, compute_dtype):
            torch.manual_seed(0)
            config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
            inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
            model = self.model_class(**config).eval()
            model = model.to(torch_device, dtype=compute_dtype)
            model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)

            reset_memory_stats()
            model(**inputs_dict)
            model_memory_footprint = model.get_memory_footprint()
            peak_inference_memory_allocated_mb = torch.cuda.max_memory_allocated() / 1024**2

            return model_memory_footprint, peak_inference_memory_allocated_mb

        fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32)
        fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32)
        fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage(
            torch.float8_e4m3fn, torch.bfloat16
        )

        self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint)
        # NOTE: the following assertion will fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
        # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes.
        self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory)
        # On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
        # bytes. This only happens for some models, so we allow a small tolerance.
        # For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
        self.assertTrue(
            fp8_e4m3_fp32_max_memory < fp32_max_memory
            or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
        )
@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
......
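As a rough sanity check on the footprint ordering asserted in `test_layerwise_casting_memory`: fp8 storage uses 1 byte per parameter versus 4 bytes for fp32, so for a model dominated by castable linear/conv weights the fp8 footprint should approach a quarter of the fp32 one (modules matching the skip patterns stay in the compute dtype, so the ratio is never exact). Back-of-the-envelope arithmetic, purely illustrative:

```python
# Illustrative arithmetic only, not part of the test suite.
num_params = 1_000_000_000                # 1B castable parameters
fp32_gib = num_params * 4 / 1024**3       # ~3.73 GiB at 4 bytes/param
fp8_gib = num_params * 1 / 1024**3        # ~0.93 GiB at 1 byte/param
print(f"fp32: {fp32_gib:.2f} GiB, fp8 storage: {fp8_gib:.2f} GiB")
```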
@@ -15,6 +15,7 @@
import unittest

import pytest
import torch

from diffusers import UNet1DModel
@@ -152,6 +153,28 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
        assert (output_sum - 224.0896).abs() < 0.5
        assert (output_max - 0.0607).abs() < 4e-4

    @pytest.mark.xfail(
        reason=(
            "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
            "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
            "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
            "2. Unskip this test."
        ),
    )
    def test_layerwise_casting_inference(self):
        super().test_layerwise_casting_inference()

    @pytest.mark.xfail(
        reason=(
            "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
            "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
            "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
            "2. Unskip this test."
        ),
    )
    def test_layerwise_casting_memory(self):
        pass


class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = UNet1DModel
@@ -274,3 +297,25 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    def test_forward_with_norm_groups(self):
        # Not implemented yet for this UNet
        pass

    @pytest.mark.xfail(
        reason=(
            "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
            "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
            "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
            "2. Unskip this test."
        ),
    )
    def test_layerwise_casting_inference(self):
        pass

    @pytest.mark.xfail(
        reason=(
            "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
            "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
            "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
            "2. Unskip this test."
        ),
    )
    def test_layerwise_casting_memory(self):
        pass
@@ -401,3 +401,15 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    def test_effective_gradient_checkpointing(self):
        super().test_effective_gradient_checkpointing(skip={"time_proj.weight"})

    @unittest.skip(
        "To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
    )
    def test_layerwise_casting_inference(self):
        pass

    @unittest.skip(
        "To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
    )
    def test_layerwise_casting_memory(self):
        pass
@@ -57,6 +57,7 @@ class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        ]
    )
    test_xformers_attention = False
    test_layerwise_casting = True

    def get_dummy_components(self):
        torch.manual_seed(0)
......
@@ -38,6 +38,7 @@ class AmusedPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = AmusedPipeline
    params = TEXT_TO_IMAGE_PARAMS | {"encoder_hidden_states", "negative_encoder_hidden_states"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    test_layerwise_casting = True

    def get_dummy_components(self):
        torch.manual_seed(0)
......