Unverified commit beacaa55, authored by Aryan and committed by GitHub

[core] Layerwise Upcasting (#10347)



* update

* update

* make style

* remove dynamo disable

* add coauthor
Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

* update

* update

* update

* update mixin

* add some basic tests

* update

* update

* non_blocking

* improvements

* update

* norm.* -> norm

* apply suggestions from review

* add example

* update hook implementation to the latest changes from pyramid attention broadcast

* deinitialize should raise an error

* update doc page

* Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update docs

* update

* refactor

* fix _always_upcast_modules for asym ae and vq_model

* fix lumina embedding forward to not depend on weight dtype

* refactor tests

* add simple lora inference tests

* _always_upcast_modules -> _precision_sensitive_module_patterns

* remove todo comments about review; revert changes to self.dtype in unets because .dtype on ModelMixin should be able to handle fp8 weight case

* check layer dtypes in lora test

* fix UNet1DModelTests::test_layerwise_upcasting_inference

* _precision_sensitive_module_patterns -> _skip_layerwise_casting_patterns based on feedback

* skip test in NCSNppModelTests

* skip tests for AutoencoderTinyTests

* skip tests for AutoencoderOobleckTests

* skip tests for UNet1DModelTests - unsupported pytorch operations

* layerwise_upcasting -> layerwise_casting

* skip tests for UNetRLModelTests; needs next pytorch release for currently unimplemented operation support

* add layerwise fp8 pipeline test

* use xfail

* Apply suggestions from code review
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* add assertion with fp32 comparison; add tolerance to fp8-fp32 vs fp32-fp32 comparison (required for a few models' test to pass)

* add note about memory consumption on tesla CI runner for failing test

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent a6476822
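For reference, the feature added in this commit can be enabled as in the following minimal sketch (the model class and checkpoint chosen here are illustrative; `enable_layerwise_casting(storage_dtype=..., compute_dtype=...)` is the API exercised by the new tests below). Weights stay in the low-precision storage dtype and each hooked layer is upcast to the compute dtype only for its own forward pass.

import torch
from diffusers import CogVideoXTransformer3DModel  # illustrative model choice

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
# Store parameters in FP8; upcast each layer to bfloat16 just-in-time for compute.
transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)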
@@ -166,6 +166,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
    """
    _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]
    _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]

    @register_to_config
......
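The `_skip_layerwise_casting_patterns` attributes added to the models in this commit are regular expressions matched against submodule names: matching modules (together with the library's default skip patterns) keep the compute dtype instead of being stored in FP8, since embeddings and normalization layers tend to be precision sensitive. A rough sketch of the matching logic, mirroring the test helper further down (module names here are made up):

import re
import torch.nn as nn

patterns = ("patch_embed", "norm")  # e.g. CogView3PlusTransformer2DModel above

model = nn.ModuleDict(
    {"patch_embed": nn.Linear(8, 8), "norm_out": nn.LayerNorm(8), "ff": nn.Linear(8, 8)}
)
for name, submodule in model.named_modules():
    if not isinstance(submodule, (nn.Linear, nn.LayerNorm)):
        continue
    if any(re.search(p, name) for p in patterns):
        print(name, "-> kept in compute_dtype (skipped from casting)")
    else:
        print(name, "-> stored in storage_dtype (e.g. torch.float8_e4m3fn)")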
@@ -262,6 +262,7 @@ class FluxTransformer2DModel(
    _supports_gradient_checkpointing = True
    _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
+    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

    @register_to_config
    def __init__(
......
@@ -542,6 +542,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
    """
    _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
    _no_split_modules = [
        "HunyuanVideoTransformerBlock",
        "HunyuanVideoSingleTransformerBlock",
......
@@ -295,6 +295,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
    """
    _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["norm"]

    @register_to_config
    def __init__(
......
@@ -336,6 +336,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
    _supports_gradient_checkpointing = True
    _no_split_modules = ["MochiTransformerBlock"]
+    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]

    @register_to_config
    def __init__(
......
@@ -127,6 +127,7 @@ class SD3Transformer2DModel(
    """
    _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

    @register_to_config
    def __init__(
......
@@ -67,6 +67,8 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
            The maximum length of the sequence over which to apply positional embeddings.
    """
+    _skip_layerwise_casting_patterns = ["norm"]

    @register_to_config
    def __init__(
        self,
......
@@ -71,6 +71,8 @@ class UNet1DModel(ModelMixin, ConfigMixin):
            Experimental feature for using a UNet without upsampling.
    """
+    _skip_layerwise_casting_patterns = ["norm"]

    @register_to_config
    def __init__(
        self,
@@ -223,7 +225,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
        timestep_embed = self.time_proj(timesteps)
        if self.config.use_timestep_embedding:
-            timestep_embed = self.time_mlp(timestep_embed)
+            timestep_embed = self.time_mlp(timestep_embed.to(sample.dtype))
        else:
            timestep_embed = timestep_embed[..., None]
            timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
......
@@ -90,6 +90,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
    """
    _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["norm"]

    @register_to_config
    def __init__(
......
@@ -166,6 +166,7 @@ class UNet2DConditionModel(
    _supports_gradient_checkpointing = True
    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
+    _skip_layerwise_casting_patterns = ["norm"]

    @register_to_config
    def __init__(
......
@@ -97,6 +97,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
    """
    _supports_gradient_checkpointing = False
+    _skip_layerwise_casting_patterns = ["norm", "time_embedding"]

    @register_to_config
    def __init__(
......
@@ -1301,6 +1301,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
    """
    _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["norm"]

    @register_to_config
    def __init__(
......
@@ -14,6 +14,7 @@
# limitations under the License.
import inspect
import os
+import re
import tempfile
import unittest
from itertools import product
@@ -2098,3 +2099,61 @@ class PeftLoraLoaderMixinTests:
        lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
        self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+
+    def test_layerwise_casting_inference_denoiser(self):
+        from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
+
+        def check_linear_dtype(module, storage_dtype, compute_dtype):
+            patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
+            if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
+                patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
+            for name, submodule in module.named_modules():
+                if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
+                    continue
+                dtype_to_check = storage_dtype
+                if "lora" in name or any(re.search(pattern, name) for pattern in patterns_to_check):
+                    dtype_to_check = compute_dtype
+                if getattr(submodule, "weight", None) is not None:
+                    self.assertEqual(submodule.weight.dtype, dtype_to_check)
+                if getattr(submodule, "bias", None) is not None:
+                    self.assertEqual(submodule.bias.dtype, dtype_to_check)
+
+        def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device, dtype=compute_dtype)
+            pipe.set_progress_bar_config(disable=None)
+
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config)
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )
+
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            denoiser.add_adapter(denoiser_lora_config)
+            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+            if self.has_two_text_encoders or self.has_three_text_encoders:
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    pipe.text_encoder_2.add_adapter(text_lora_config)
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
+
+            if storage_dtype is not None:
+                denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
+                check_linear_dtype(denoiser, storage_dtype, compute_dtype)
+
+            return pipe
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        pipe_fp32 = initialize_pipeline(storage_dtype=None)
+        pipe_fp32(**inputs, generator=torch.manual_seed(0))[0]
+
+        pipe_float8_e4m3_fp32 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)
+        pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0]
+
+        pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
+        pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]
@@ -114,6 +114,24 @@ class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCa
    def test_set_attn_processor_for_determinism(self):
        return

+    @unittest.skip(
+        "The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not "
+        "cast the module weights to compute_dtype (as required by forward pass). As a result, forward pass errors out. To fix:\n"
+        "1. Make sure `nn::Module::to` works with `torch.nn.utils.weight_norm` wrapped convolution layer.\n"
+        "2. Unskip this test."
+    )
+    def test_layerwise_casting_inference(self):
+        pass
+
+    @unittest.skip(
+        "The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not "
+        "cast the module weights to compute_dtype (as required by forward pass). As a result, forward pass errors out. To fix:\n"
+        "1. Make sure `nn::Module::to` works with `torch.nn.utils.weight_norm` wrapped convolution layer.\n"
+        "2. Unskip this test."
+    )
+    def test_layerwise_casting_memory(self):
+        pass


@slow
class AutoencoderOobleckIntegrationTests(unittest.TestCase):
......
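On the weight_norm issue described in the skip reason above: `torch.nn.utils.weight_norm` re-parametrizes a layer so that the trainable parameters are `weight_g` and `weight_v`, while `weight` is recomputed by weight_norm's own forward pre-hook on every call. A casting hook that only handles `weight` therefore misses the actual parameters. A small sketch of the re-parametrization (plain PyTorch, not diffusers code):

import torch.nn as nn
from torch.nn.utils import weight_norm

conv = weight_norm(nn.Conv1d(4, 4, kernel_size=3))
# The parameter list now contains weight_g and weight_v; plain `weight` is a
# derived tensor that weight_norm recomputes before each forward pass.
print(sorted(name for name, _ in conv.named_parameters()))  # ['bias', 'weight_g', 'weight_v']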
@@ -173,6 +173,22 @@ class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase)
                continue
            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=3e-2))

+    @unittest.skip(
+        "The forward pass of AutoencoderTiny creates a torch.float32 tensor. This causes inference in compute_dtype=torch.bfloat16 to fail. To fix:\n"
+        "1. Change the forward pass to be dtype agnostic.\n"
+        "2. Unskip this test."
+    )
+    def test_layerwise_casting_inference(self):
+        pass
+
+    @unittest.skip(
+        "The forward pass of AutoencoderTiny creates a torch.float32 tensor. This causes inference in compute_dtype=torch.bfloat16 to fail. To fix:\n"
+        "1. Change the forward pass to be dtype agnostic.\n"
+        "2. Unskip this test."
+    )
+    def test_layerwise_casting_memory(self):
+        pass


@slow
class AutoencoderTinyIntegrationTests(unittest.TestCase):
......
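The AutoencoderTiny skip reason above points at an intermediate tensor that is created in float32 regardless of the module dtype; anything promoted by such a tensor fails when fed into a bfloat16 layer. An illustrative sketch of the failure mode and the usual dtype-agnostic fix (not the actual AutoencoderTiny code):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4).to(torch.bfloat16)
x = torch.randn(1, 4, dtype=torch.bfloat16)

mask = torch.ones(1, 4)            # hard-coded float32 intermediate
# layer(x * mask) would fail: x * mask promotes to float32, which does not
# match the layer's bfloat16 weights.
out = layer(x * mask.to(x.dtype))  # fix: derive the dtype from the input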
@@ -14,9 +14,11 @@
# limitations under the License.
import copy
+import gc
import inspect
import json
import os
+import re
import tempfile
import traceback
import unittest
@@ -56,9 +58,11 @@ from diffusers.utils.testing_utils import (
    CaptureLogger,
    get_python_version,
    is_torch_compile,
+    numpy_cosine_similarity_distance,
    require_torch_2,
    require_torch_accelerator,
    require_torch_accelerator_with_training,
+    require_torch_gpu,
    require_torch_multi_gpu,
    run_test_in_subprocess,
    torch_all_close,
@@ -181,6 +185,16 @@
    return module_sizes


+def cast_maybe_tensor_dtype(maybe_tensor, current_dtype, target_dtype):
+    if torch.is_tensor(maybe_tensor):
+        return maybe_tensor.to(target_dtype) if maybe_tensor.dtype == current_dtype else maybe_tensor
+    if isinstance(maybe_tensor, dict):
+        return {k: cast_maybe_tensor_dtype(v, current_dtype, target_dtype) for k, v in maybe_tensor.items()}
+    if isinstance(maybe_tensor, list):
+        return [cast_maybe_tensor_dtype(v, current_dtype, target_dtype) for v in maybe_tensor]
+    return maybe_tensor
+
+
class ModelUtilsTest(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
@@ -1332,6 +1346,93 @@ class ModelTesterMixin:
        # Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors
        assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files)

+    def test_layerwise_casting_inference(self):
+        from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
+
+        torch.manual_seed(0)
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+        base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy()
+
+        def check_linear_dtype(module, storage_dtype, compute_dtype):
+            patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
+            if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
+                patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
+            for name, submodule in module.named_modules():
+                if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
+                    continue
+                dtype_to_check = storage_dtype
+                if any(re.search(pattern, name) for pattern in patterns_to_check):
+                    dtype_to_check = compute_dtype
+                if getattr(submodule, "weight", None) is not None:
+                    self.assertEqual(submodule.weight.dtype, dtype_to_check)
+                if getattr(submodule, "bias", None) is not None:
+                    self.assertEqual(submodule.bias.dtype, dtype_to_check)
+
+        def test_layerwise_casting(storage_dtype, compute_dtype):
+            torch.manual_seed(0)
+            config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
+            model = self.model_class(**config).eval()
+            model = model.to(torch_device, dtype=compute_dtype)
+            model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
+            check_linear_dtype(model, storage_dtype, compute_dtype)
+            output = model(**inputs_dict)[0].float().flatten().detach().cpu().numpy()
+
+            # The precision test is not very important for fast tests. In most cases, the outputs will not be the same.
+            # We just want to make sure that the layerwise casting is working as expected.
+            self.assertTrue(numpy_cosine_similarity_distance(base_slice, output) < 1.0)
+
+        test_layerwise_casting(torch.float16, torch.float32)
+        test_layerwise_casting(torch.float8_e4m3fn, torch.float32)
+        test_layerwise_casting(torch.float8_e5m2, torch.float32)
+        test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16)
+
+    @require_torch_gpu
+    def test_layerwise_casting_memory(self):
+        MB_TOLERANCE = 0.2
+
+        def reset_memory_stats():
+            gc.collect()
+            torch.cuda.synchronize()
+            torch.cuda.empty_cache()
+            torch.cuda.reset_peak_memory_stats()
+
+        def get_memory_usage(storage_dtype, compute_dtype):
+            torch.manual_seed(0)
+            config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
+            model = self.model_class(**config).eval()
+            model = model.to(torch_device, dtype=compute_dtype)
+            model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
+
+            reset_memory_stats()
+            model(**inputs_dict)
+            model_memory_footprint = model.get_memory_footprint()
+            peak_inference_memory_allocated_mb = torch.cuda.max_memory_allocated() / 1024**2
+
+            return model_memory_footprint, peak_inference_memory_allocated_mb
+
+        fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32)
+        fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32)
+        fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage(
+            torch.float8_e4m3fn, torch.bfloat16
+        )
+
+        self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint)
+        # NOTE: the following assertion will fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
+        # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes.
+        self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory)
+        # On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
+        # bytes. This only happens for some models, so we allow a small tolerance.
+        # For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
+        self.assertTrue(
+            fp8_e4m3_fp32_max_memory < fp32_max_memory
+            or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
+        )


@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
......
@@ -15,6 +15,7 @@
import unittest

+import pytest
import torch

from diffusers import UNet1DModel
@@ -152,6 +153,28 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
        assert (output_sum - 224.0896).abs() < 0.5
        assert (output_max - 0.0607).abs() < 4e-4

+    @pytest.mark.xfail(
+        reason=(
+            "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
+            "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
+            "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
+            "2. Unskip this test."
+        ),
+    )
+    def test_layerwise_casting_inference(self):
+        super().test_layerwise_casting_inference()
+
+    @pytest.mark.xfail(
+        reason=(
+            "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
+            "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
+            "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
+            "2. Unskip this test."
+        ),
+    )
+    def test_layerwise_casting_memory(self):
+        pass


class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = UNet1DModel
@@ -274,3 +297,25 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    def test_forward_with_norm_groups(self):
        # Not implemented yet for this UNet
        pass

+    @pytest.mark.xfail(
+        reason=(
+            "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
+            "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
+            "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
+            "2. Unskip this test."
+        ),
+    )
+    def test_layerwise_casting_inference(self):
+        pass
+
+    @pytest.mark.xfail(
+        reason=(
+            "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
+            "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
+            "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
+            "2. Unskip this test."
+        ),
+    )
+    def test_layerwise_casting_memory(self):
+        pass
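On the xfail reason used above: these tests run with `torch.use_deterministic_algorithms(True)`, under which PyTorch fills uninitialized memory, and on the affected releases that fill kernel is not implemented for the FP8 dtypes. A hedged minimal reproduction (whether this raises depends on the installed PyTorch version; see the linked issue):

import torch

torch.use_deterministic_algorithms(True)
# With deterministic mode on, torch.empty fills uninitialized memory, which on
# affected versions raises: RuntimeError: "fill_out" not implemented for 'Float8_e4m3fn'
t = torch.empty(16, dtype=torch.float8_e4m3fn)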
@@ -401,3 +401,15 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    def test_effective_gradient_checkpointing(self):
        super().test_effective_gradient_checkpointing(skip={"time_proj.weight"})

+    @unittest.skip(
+        "To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
+    )
+    def test_layerwise_casting_inference(self):
+        pass
+
+    @unittest.skip(
+        "To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
+    )
+    def test_layerwise_casting_memory(self):
+        pass
@@ -57,6 +57,7 @@ class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        ]
    )
    test_xformers_attention = False
+    test_layerwise_casting = True

    def get_dummy_components(self):
        torch.manual_seed(0)
......
@@ -38,6 +38,7 @@ class AmusedPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = AmusedPipeline
    params = TEXT_TO_IMAGE_PARAMS | {"encoder_hidden_states", "negative_encoder_hidden_states"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    test_layerwise_casting = True

    def get_dummy_components(self):
        torch.manual_seed(0)
......