Unverified Commit b7af9461 authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

set config from original module but set compiled module on class (#3650)

* set config from original module but set compiled module on class

* add test
parent d3717e63
...@@ -485,17 +485,19 @@ class DiffusionPipeline(ConfigMixin): ...@@ -485,17 +485,19 @@ class DiffusionPipeline(ConfigMixin):
if module is None: if module is None:
register_dict = {name: (None, None)} register_dict = {name: (None, None)}
else: else:
# register the original module, not the dynamo compiled one # register the config from the original module, not the dynamo compiled one
if is_compiled_module(module): if is_compiled_module(module):
module = module._orig_mod not_compiled_module = module._orig_mod
else:
not_compiled_module = module
library = module.__module__.split(".")[0] library = not_compiled_module.__module__.split(".")[0]
# check if the module is a pipeline module # check if the module is a pipeline module
module_path_items = module.__module__.split(".") module_path_items = not_compiled_module.__module__.split(".")
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
path = module.__module__.split(".") path = not_compiled_module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module. # if library is not in LOADABLE_CLASSES, then it is a custom module.
...@@ -504,10 +506,10 @@ class DiffusionPipeline(ConfigMixin): ...@@ -504,10 +506,10 @@ class DiffusionPipeline(ConfigMixin):
if is_pipeline_module: if is_pipeline_module:
library = pipeline_dir library = pipeline_dir
elif library not in LOADABLE_CLASSES: elif library not in LOADABLE_CLASSES:
library = module.__module__ library = not_compiled_module.__module__
# retrieve class_name # retrieve class_name
class_name = module.__class__.__name__ class_name = not_compiled_module.__class__.__name__
register_dict = {name: (library, class_name)} register_dict = {name: (library, class_name)}
......
...@@ -61,6 +61,7 @@ from diffusers.utils import ( ...@@ -61,6 +61,7 @@ from diffusers.utils import (
CONFIG_NAME, CONFIG_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
floats_tensor, floats_tensor,
is_compiled_module,
nightly, nightly,
require_torch_2, require_torch_2,
slow, slow,
...@@ -99,6 +100,11 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout): ...@@ -99,6 +100,11 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
scheduler = DDPMScheduler(num_train_timesteps=10) scheduler = DDPMScheduler(num_train_timesteps=10)
ddpm = DDPMPipeline(model, scheduler) ddpm = DDPMPipeline(model, scheduler)
# previous diffusers versions stripped compilation off
# compiled modules
assert is_compiled_module(ddpm.unet)
ddpm.to(torch_device) ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None) ddpm.set_progress_bar_config(disable=None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment