"vscode:/vscode.git/clone" did not exist on "6db2ad1c2def699eec6f66a1928f825f079b588f"
Unverified Commit b7af9461 authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

set config from original module but set compiled module on class (#3650)

* set config from original module but set compiled module on class

* add test
parent d3717e63
......@@ -485,17 +485,19 @@ class DiffusionPipeline(ConfigMixin):
if module is None:
register_dict = {name: (None, None)}
else:
# register the original module, not the dynamo compiled one
# register the config from the original module, not the dynamo compiled one
if is_compiled_module(module):
module = module._orig_mod
not_compiled_module = module._orig_mod
else:
not_compiled_module = module
library = module.__module__.split(".")[0]
library = not_compiled_module.__module__.split(".")[0]
# check if the module is a pipeline module
module_path_items = module.__module__.split(".")
module_path_items = not_compiled_module.__module__.split(".")
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
path = module.__module__.split(".")
path = not_compiled_module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module.
......@@ -504,10 +506,10 @@ class DiffusionPipeline(ConfigMixin):
if is_pipeline_module:
library = pipeline_dir
elif library not in LOADABLE_CLASSES:
library = module.__module__
library = not_compiled_module.__module__
# retrieve class_name
class_name = module.__class__.__name__
class_name = not_compiled_module.__class__.__name__
register_dict = {name: (library, class_name)}
......
......@@ -61,6 +61,7 @@ from diffusers.utils import (
CONFIG_NAME,
WEIGHTS_NAME,
floats_tensor,
is_compiled_module,
nightly,
require_torch_2,
slow,
......@@ -99,6 +100,11 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
scheduler = DDPMScheduler(num_train_timesteps=10)
ddpm = DDPMPipeline(model, scheduler)
# previous diffusers versions stripped compilation off
# compiled modules
assert is_compiled_module(ddpm.unet)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment