Unverified Commit b7af9461 authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

set config from original module but set compiled module on class (#3650)

* set config from original module but set compiled module on class

* add test
parent d3717e63
......@@ -485,17 +485,19 @@ class DiffusionPipeline(ConfigMixin):
if module is None:
register_dict = {name: (None, None)}
else:
# register the original module, not the dynamo compiled one
# register the config from the original module, not the dynamo compiled one
if is_compiled_module(module):
module = module._orig_mod
not_compiled_module = module._orig_mod
else:
not_compiled_module = module
library = module.__module__.split(".")[0]
library = not_compiled_module.__module__.split(".")[0]
# check if the module is a pipeline module
module_path_items = module.__module__.split(".")
module_path_items = not_compiled_module.__module__.split(".")
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
path = module.__module__.split(".")
path = not_compiled_module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module.
......@@ -504,10 +506,10 @@ class DiffusionPipeline(ConfigMixin):
if is_pipeline_module:
library = pipeline_dir
elif library not in LOADABLE_CLASSES:
library = module.__module__
library = not_compiled_module.__module__
# retrieve class_name
class_name = module.__class__.__name__
class_name = not_compiled_module.__class__.__name__
register_dict = {name: (library, class_name)}
......
......@@ -61,6 +61,7 @@ from diffusers.utils import (
CONFIG_NAME,
WEIGHTS_NAME,
floats_tensor,
is_compiled_module,
nightly,
require_torch_2,
slow,
......@@ -99,6 +100,11 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
scheduler = DDPMScheduler(num_train_timesteps=10)
ddpm = DDPMPipeline(model, scheduler)
# previous diffusers versions stripped compilation off
# compiled modules
assert is_compiled_module(ddpm.unet)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment