"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "7d0a47f387e7c76ffa4fee5e7365228cef25801d"
Unverified Commit b7af9461 authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

set config from original module but set compiled module on class (#3650)

* set config from original module but set compiled module on class

* add test
parent d3717e63
...@@ -485,17 +485,19 @@ class DiffusionPipeline(ConfigMixin): ...@@ -485,17 +485,19 @@ class DiffusionPipeline(ConfigMixin):
if module is None: if module is None:
register_dict = {name: (None, None)} register_dict = {name: (None, None)}
else: else:
# register the original module, not the dynamo compiled one # register the config from the original module, not the dynamo compiled one
if is_compiled_module(module): if is_compiled_module(module):
module = module._orig_mod not_compiled_module = module._orig_mod
else:
not_compiled_module = module
library = module.__module__.split(".")[0] library = not_compiled_module.__module__.split(".")[0]
# check if the module is a pipeline module # check if the module is a pipeline module
module_path_items = module.__module__.split(".") module_path_items = not_compiled_module.__module__.split(".")
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
path = module.__module__.split(".") path = not_compiled_module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module. # if library is not in LOADABLE_CLASSES, then it is a custom module.
...@@ -504,10 +506,10 @@ class DiffusionPipeline(ConfigMixin): ...@@ -504,10 +506,10 @@ class DiffusionPipeline(ConfigMixin):
if is_pipeline_module: if is_pipeline_module:
library = pipeline_dir library = pipeline_dir
elif library not in LOADABLE_CLASSES: elif library not in LOADABLE_CLASSES:
library = module.__module__ library = not_compiled_module.__module__
# retrieve class_name # retrieve class_name
class_name = module.__class__.__name__ class_name = not_compiled_module.__class__.__name__
register_dict = {name: (library, class_name)} register_dict = {name: (library, class_name)}
......
...@@ -61,6 +61,7 @@ from diffusers.utils import ( ...@@ -61,6 +61,7 @@ from diffusers.utils import (
CONFIG_NAME, CONFIG_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
floats_tensor, floats_tensor,
is_compiled_module,
nightly, nightly,
require_torch_2, require_torch_2,
slow, slow,
...@@ -99,6 +100,11 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout): ...@@ -99,6 +100,11 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
scheduler = DDPMScheduler(num_train_timesteps=10) scheduler = DDPMScheduler(num_train_timesteps=10)
ddpm = DDPMPipeline(model, scheduler) ddpm = DDPMPipeline(model, scheduler)
# previous diffusers versions stripped compilation off
# compiled modules
assert is_compiled_module(ddpm.unet)
ddpm.to(torch_device) ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None) ddpm.set_progress_bar_config(disable=None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment