Unverified Commit 81125d84 authored by Pedro Cuenca, committed by GitHub

Make dynamo-wrapped modules work with save_pretrained (#2726)



* Workaround for saving dynamo-wrapped models.

* Accept suggestion from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Apply workaround when overriding pipeline components.

* Ensure the correct config.json is saved to disk, instead of the dynamo class's.

* Save the correct module (not the compiled one)

* Add test

* style

* fix docstrings

* Go back to using string comparisons.

PyTorch CPU builds do not have _dynamo.

* Simple test for save_pretrained of compiled models.

* Helper function to test whether module is compiled.

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent d4f846fa
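
For context: torch.compile() does not return the original nn.Module. It returns a private torch._dynamo.eval_frame.OptimizedModule wrapper that keeps the original module in its _orig_mod attribute, and every change below unwraps exactly that. A minimal sketch of the behavior, assuming PyTorch >= 2.0 (the Linear stand-in is illustrative, not from this commit):

import torch

# A trivial module stands in for a diffusers model.
model = torch.nn.Linear(4, 4)
compiled = torch.compile(model)

# The wrapper is not an instance of the original class, so
# type-based checks against it stop matching after compilation.
print(type(compiled))                          # OptimizedModule
print(isinstance(compiled, torch.nn.Linear))   # False

# The original module stays reachable through the private attribute.
assert compiled._orig_mod is model
assert type(compiled._orig_mod) is torch.nn.Linear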
@@ -50,6 +50,7 @@ from ..utils import (
     get_class_from_dynamic_module,
     is_accelerate_available,
     is_accelerate_version,
+    is_compiled_module,
     is_safetensors_available,
     is_torch_version,
     is_transformers_available,
@@ -255,7 +256,14 @@ def maybe_raise_or_warn(
             if class_candidate is not None and issubclass(class_obj, class_candidate):
                 expected_class_obj = class_candidate
-        if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
+        # Dynamo wraps the original model in a private class.
+        # I didn't find a public API to get the original class.
+        sub_model = passed_class_obj[name]
+        model_cls = sub_model.__class__
+        if is_compiled_module(sub_model):
+            model_cls = sub_model._orig_mod.__class__
+
+        if not issubclass(model_cls, expected_class_obj):
             raise ValueError(
                 f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
                 f" {expected_class_obj}"
@@ -419,6 +427,10 @@ class DiffusionPipeline(ConfigMixin):
             if module is None:
                 register_dict = {name: (None, None)}
             else:
+                # register the original module, not the dynamo compiled one
+                if is_compiled_module(module):
+                    module = module._orig_mod
+
                 library = module.__module__.split(".")[0]

                 # check if the module is a pipeline module
@@ -484,6 +496,12 @@ class DiffusionPipeline(ConfigMixin):
             sub_model = getattr(self, pipeline_component_name)
             model_cls = sub_model.__class__

+            # Dynamo wraps the original model in a private class.
+            # I didn't find a public API to get the original class.
+            if is_compiled_module(sub_model):
+                sub_model = sub_model._orig_mod
+                model_cls = sub_model.__class__
+
             save_method_name = None
             # search for the model's base class in LOADABLE_CLASSES
             for library_name, library_classes in LOADABLE_CLASSES.items():
......
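
Taken together, the changes above let a compiled component pass the class check in maybe_raise_or_warn, register under its original class, and save the correct config. A sketch of the user-facing flow this enables, assuming PyTorch >= 2.0; the checkpoint id is illustrative, not taken from the commit:

import torch
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")
pipe.unet = torch.compile(pipe.unet)

# Before this change, save_pretrained recorded the dynamo wrapper
# class instead of UNet2DModel, so the pipeline could not be reloaded.
pipe.save_pretrained("./ddpm-saved")
reloaded = DDPMPipeline.from_pretrained("./ddpm-saved")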
@@ -74,7 +74,7 @@ from .import_utils import (
 from .logging import get_logger
 from .outputs import BaseOutput
 from .pil_utils import PIL_INTERPOLATION
-from .torch_utils import randn_tensor
+from .torch_utils import is_compiled_module, randn_tensor


 if is_torch_available():
@@ -86,6 +86,7 @@ if is_torch_available():
         nightly,
         parse_flag_from_env,
         print_tensor_test,
+        require_torch_2,
         require_torch_gpu,
         skip_mps,
         slow,
......
@@ -25,6 +25,7 @@ from .import_utils import (
     is_onnx_available,
     is_opencv_available,
     is_torch_available,
+    is_torch_version,
 )

 from .logging import get_logger
@@ -165,6 +166,15 @@ def require_torch(test_case):
     return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)


+def require_torch_2(test_case):
+    """
+    Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed.
+    """
+    return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")(
+        test_case
+    )
+
+
 def require_torch_gpu(test_case):
     """Decorator marking a test that requires CUDA and PyTorch."""
     return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
......
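
For illustration, the new decorator is used exactly like the existing require_torch_* helpers; a hypothetical standalone test (class and test names are not from the commit):

import unittest

import torch

from diffusers.utils.testing_utils import require_torch_2


class CompileSmokeTest(unittest.TestCase):
    @require_torch_2
    def test_compile_runs(self):
        # Skipped automatically on PyTorch < 2.0, where torch.compile
        # does not exist.
        model = torch.compile(torch.nn.Linear(2, 2))
        out = model(torch.randn(1, 2))
        self.assertEqual(tuple(out.shape), (1, 2))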
@@ -17,7 +17,7 @@ PyTorch utilities: Utilities related to PyTorch
 from typing import List, Optional, Tuple, Union

 from . import logging
-from .import_utils import is_torch_available
+from .import_utils import is_torch_available, is_torch_version


 if is_torch_available():
@@ -68,3 +68,10 @@ def randn_tensor(
         latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

     return latents
+
+
+def is_compiled_module(module):
+    """Check whether the module was compiled with torch.compile()"""
+    if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
+        return False
+    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
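
The helper hides both the version guard and the private OptimizedModule type behind one call. A minimal sketch of the unwrap pattern the rest of the commit applies, assuming PyTorch >= 2.0:

import torch

from diffusers.utils import is_compiled_module

module = torch.compile(torch.nn.Linear(8, 8))

# Recover the plain nn.Module before anything class-sensitive:
# registration, config lookup, or choosing a save method.
if is_compiled_module(module):
    module = module._orig_mod

assert type(module) is torch.nn.Linear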
@@ -27,6 +27,7 @@ from requests.exceptions import HTTPError
 from diffusers.models import UNet2DConditionModel
 from diffusers.training_utils import EMAModel
 from diffusers.utils import torch_device
+from diffusers.utils.testing_utils import require_torch_gpu


 class ModelUtilsTest(unittest.TestCase):
@@ -167,6 +168,21 @@ class ModelTesterMixin:
         max_diff = (image - new_image).abs().sum().item()
         self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")

+    @require_torch_gpu
+    def test_from_save_pretrained_dynamo(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model = torch.compile(model)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname)
+            new_model = self.model_class.from_pretrained(tmpdirname)
+            new_model.to(torch_device)
+
+        assert new_model.__class__ == self.model_class
+
     def test_from_save_pretrained_dtype(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
......
@@ -54,7 +54,16 @@ from diffusers import (
     logging,
 )
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
-from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, is_flax_available, nightly, slow, torch_device
+from diffusers.utils import (
+    CONFIG_NAME,
+    WEIGHTS_NAME,
+    floats_tensor,
+    is_flax_available,
+    nightly,
+    require_torch_2,
+    slow,
+    torch_device,
+)
 from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, load_numpy, require_compel, require_torch_gpu
@@ -966,9 +975,41 @@ class PipelineSlowTests(unittest.TestCase):
             down_block_types=("DownBlock2D", "AttnDownBlock2D"),
             up_block_types=("AttnUpBlock2D", "UpBlock2D"),
         )
-        schedular = DDPMScheduler(num_train_timesteps=10)
-        ddpm = DDPMPipeline(model, schedular)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        ddpm = DDPMPipeline(model, scheduler)
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)

         with tempfile.TemporaryDirectory() as tmpdirname:
             ddpm.save_pretrained(tmpdirname)
             new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
             new_ddpm.to(torch_device)

         generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images

         generator = torch.Generator(device=torch_device).manual_seed(0)
         new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images

         assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"

+    @require_torch_2
+    def test_from_save_pretrained_dynamo(self):
+        # 1. Load models
+        model = UNet2DModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=3,
+            out_channels=3,
+            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
+        )
+        model = torch.compile(model)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        ddpm = DDPMPipeline(model, scheduler)
+        ddpm.to(torch_device)
+        ddpm.set_progress_bar_config(disable=None)
......
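
End to end, the behavior the new tests pin down can be reproduced outside the test harness; a sketch using the same toy configuration as the test above, requiring PyTorch 2 and diffusers at this commit:

import tempfile

import torch

from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel

model = UNet2DModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=3,
    out_channels=3,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
model = torch.compile(model)
ddpm = DDPMPipeline(model, DDPMScheduler(num_train_timesteps=10))

with tempfile.TemporaryDirectory() as tmpdirname:
    ddpm.save_pretrained(tmpdirname)
    new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)

# The reloaded unet carries the original class, not the dynamo wrapper.
assert new_ddpm.unet.__class__ is UNet2DModel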