Unverified Commit 81125d84 authored by Pedro Cuenca, committed by GitHub

Make dynamo-wrapped modules work with save_pretrained (#2726)



* Workaround for saving dynamo-wrapped models.

* Accept suggestion from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Apply workaround when overriding pipeline components.

* Ensure the correct config.json is saved to disk.

Instead of the dynamo class.

* Save correct module (not compiled one)

* Add test

* style

* fix docstrings

* Go back to using string comparisons.

CPU-only builds of PyTorch do not have _dynamo.

* Simple test for save_pretrained of compiled models.

* Helper function to test whether module is compiled.

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent d4f846fa
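For context, a minimal sketch of the behavior this commit works around (PyTorch >= 2.0 assumed; `OptimizedModule` and `_orig_mod` are private PyTorch internals, so this is illustrative rather than a stable API):

```python
import torch

model = torch.nn.Linear(4, 4)
compiled = torch.compile(model)

# torch.compile() returns a private wrapper class rather than the original module...
print(type(compiled))  # <class 'torch._dynamo.eval_frame.OptimizedModule'>

# ...but the wrapped module is still reachable through the private `_orig_mod` attribute.
print(compiled._orig_mod is model)  # True
```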
@@ -50,6 +50,7 @@ from ..utils import (
     get_class_from_dynamic_module,
     is_accelerate_available,
     is_accelerate_version,
+    is_compiled_module,
     is_safetensors_available,
     is_torch_version,
     is_transformers_available,
@@ -255,7 +256,14 @@ def maybe_raise_or_warn(
             if class_candidate is not None and issubclass(class_obj, class_candidate):
                 expected_class_obj = class_candidate
 
-        if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
+        # Dynamo wraps the original model in a private class.
+        # I didn't find a public API to get the original class.
+        sub_model = passed_class_obj[name]
+        model_cls = sub_model.__class__
+        if is_compiled_module(sub_model):
+            model_cls = sub_model._orig_mod.__class__
+
+        if not issubclass(model_cls, expected_class_obj):
             raise ValueError(
                 f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
                 f" {expected_class_obj}"
@@ -419,6 +427,10 @@ class DiffusionPipeline(ConfigMixin):
             if module is None:
                 register_dict = {name: (None, None)}
             else:
+                # register the original module, not the dynamo compiled one
+                if is_compiled_module(module):
+                    module = module._orig_mod
+
                 library = module.__module__.split(".")[0]
 
                 # check if the module is a pipeline module
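Without this unwrapping, the library name written to the pipeline config would be derived from the wrapper's `__module__` and point at `torch` instead of `diffusers`. A sketch (the exact module paths are assumptions about this diffusers version):

```python
import torch
from diffusers import UNet2DModel

compiled = torch.compile(UNet2DModel())

print(compiled.__module__)            # e.g. 'torch._dynamo.eval_frame'
print(compiled._orig_mod.__module__)  # e.g. 'diffusers.models.unet_2d'

# register_modules() derives the library name exactly as in the hunk above:
print(compiled._orig_mod.__module__.split(".")[0])  # 'diffusers'
```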
@@ -484,6 +496,12 @@ class DiffusionPipeline(ConfigMixin):
             sub_model = getattr(self, pipeline_component_name)
             model_cls = sub_model.__class__
 
+            # Dynamo wraps the original model in a private class.
+            # I didn't find a public API to get the original class.
+            if is_compiled_module(sub_model):
+                sub_model = sub_model._orig_mod
+                model_cls = sub_model.__class__
+
             save_method_name = None
             # search for the model's base class in LOADABLE_CLASSES
             for library_name, library_classes in LOADABLE_CLASSES.items():
...
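With the two changes above, saving a pipeline whose components were compiled round-trips cleanly. A usage sketch (the model id is just an example, and PyTorch >= 2.0 is assumed):

```python
import tempfile

import torch
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256")
pipe.unet = torch.compile(pipe.unet)

with tempfile.TemporaryDirectory() as tmpdir:
    # The original UNet (not the dynamo wrapper) is saved, so the
    # pipeline reloads with the expected component classes.
    pipe.save_pretrained(tmpdir)
    reloaded = DDPMPipeline.from_pretrained(tmpdir)
```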
@@ -74,7 +74,7 @@ from .import_utils import (
 from .logging import get_logger
 from .outputs import BaseOutput
 from .pil_utils import PIL_INTERPOLATION
-from .torch_utils import randn_tensor
+from .torch_utils import is_compiled_module, randn_tensor
 
 if is_torch_available():
@@ -86,6 +86,7 @@ if is_torch_available():
         nightly,
         parse_flag_from_env,
         print_tensor_test,
+        require_torch_2,
         require_torch_gpu,
         skip_mps,
         slow,
...
@@ -25,6 +25,7 @@ from .import_utils import (
     is_onnx_available,
     is_opencv_available,
     is_torch_available,
+    is_torch_version,
 )
 from .logging import get_logger
@@ -165,6 +166,15 @@ def require_torch(test_case):
     return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
 
 
+def require_torch_2(test_case):
+    """
+    Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed.
+    """
+    return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")(
+        test_case
+    )
+
+
 def require_torch_gpu(test_case):
     """Decorator marking a test that requires CUDA and PyTorch."""
     return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
...
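A usage sketch for the new decorator (the class and test names here are hypothetical):

```python
import unittest

from diffusers.utils.testing_utils import require_torch_2


class CompileTests(unittest.TestCase):
    @require_torch_2
    def test_needs_pytorch_2(self):
        import torch

        # Only runs on PyTorch >= 2.0, where torch.compile is available.
        self.assertTrue(hasattr(torch, "compile"))
```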
@@ -17,7 +17,7 @@ PyTorch utilities: Utilities related to PyTorch
 from typing import List, Optional, Tuple, Union
 
 from . import logging
-from .import_utils import is_torch_available
+from .import_utils import is_torch_available, is_torch_version
 
 if is_torch_available():
@@ -68,3 +68,10 @@ def randn_tensor(
     latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
     return latents
 
 
+def is_compiled_module(module):
+    """Check whether the module was compiled with torch.compile()"""
+    if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
+        return False
+    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
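A quick usage example for the helper (the compiled case assumes PyTorch >= 2.0):

```python
import torch
from diffusers.utils import is_compiled_module

model = torch.nn.Linear(4, 4)
compiled = torch.compile(model)

print(is_compiled_module(model))     # False
print(is_compiled_module(compiled))  # True
# On PyTorch < 2.0, or builds without torch._dynamo, it simply returns False.
```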
@@ -27,6 +27,7 @@ from requests.exceptions import HTTPError
 from diffusers.models import UNet2DConditionModel
 from diffusers.training_utils import EMAModel
 from diffusers.utils import torch_device
+from diffusers.utils.testing_utils import require_torch_gpu
 
 
 class ModelUtilsTest(unittest.TestCase):
@@ -167,6 +168,21 @@ class ModelTesterMixin:
         max_diff = (image - new_image).abs().sum().item()
         self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
 
+    @require_torch_gpu
+    def test_from_save_pretrained_dynamo(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model = torch.compile(model)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname)
+            new_model = self.model_class.from_pretrained(tmpdirname)
+            new_model.to(torch_device)
+
+        assert new_model.__class__ == self.model_class
+
     def test_from_save_pretrained_dtype(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
...
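Note that `save_pretrained` can be called directly on the compiled model here because, from what I can tell, `OptimizedModule` delegates unknown attribute lookups to the wrapped module; a sketch relying on that private behavior:

```python
import torch
from diffusers import UNet2DModel

model = torch.compile(UNet2DModel())

# 'save_pretrained' is not defined on OptimizedModule, so the lookup falls
# through to the original UNet2DModel instance.
assert model.save_pretrained.__self__ is model._orig_mod
```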
@@ -54,7 +54,16 @@ from diffusers import (
     logging,
 )
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
-from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, is_flax_available, nightly, slow, torch_device
+from diffusers.utils import (
+    CONFIG_NAME,
+    WEIGHTS_NAME,
+    floats_tensor,
+    is_flax_available,
+    nightly,
+    require_torch_2,
+    slow,
+    torch_device,
+)
 from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, load_numpy, require_compel, require_torch_gpu
@@ -966,9 +975,41 @@ class PipelineSlowTests(unittest.TestCase):
             down_block_types=("DownBlock2D", "AttnDownBlock2D"),
             up_block_types=("AttnUpBlock2D", "UpBlock2D"),
         )
-        schedular = DDPMScheduler(num_train_timesteps=10)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        ddpm = DDPMPipeline(model, scheduler)
+        ddpm.to(torch_device)
+        ddpm.set_progress_bar_config(disable=None)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            ddpm.save_pretrained(tmpdirname)
+            new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
+            new_ddpm.to(torch_device)
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
+
+        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
+
+    @require_torch_2
+    def test_from_save_pretrained_dynamo(self):
+        # 1. Load models
+        model = UNet2DModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=3,
+            out_channels=3,
+            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
+        )
+        model = torch.compile(model)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
-        ddpm = DDPMPipeline(model, schedular)
+        ddpm = DDPMPipeline(model, scheduler)
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)
...