Unverified Commit 86a26761 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Correctly handle creating model index json files when setting compiled modules...

Correctly handle creating model index json files when setting compiled modules in pipelines.  (#6436)

update
parent 6ef2b8a9
...@@ -530,6 +530,36 @@ def load_sub_model( ...@@ -530,6 +530,36 @@ def load_sub_model(
return loaded_sub_model return loaded_sub_model
def _fetch_class_library_tuple(module):
# import it here to avoid circular import
diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines")
# register the config from the original module, not the dynamo compiled one
not_compiled_module = _unwrap_model(module)
library = not_compiled_module.__module__.split(".")[0]
# check if the module is a pipeline module
module_path_items = not_compiled_module.__module__.split(".")
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
path = not_compiled_module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if is_pipeline_module:
library = pipeline_dir
elif library not in LOADABLE_CLASSES:
library = not_compiled_module.__module__
# retrieve class_name
class_name = not_compiled_module.__class__.__name__
return (library, class_name)
class DiffusionPipeline(ConfigMixin, PushToHubMixin): class DiffusionPipeline(ConfigMixin, PushToHubMixin):
r""" r"""
Base class for all pipelines. Base class for all pipelines.
...@@ -556,38 +586,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -556,38 +586,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
_is_onnx = False _is_onnx = False
def register_modules(self, **kwargs): def register_modules(self, **kwargs):
# import it here to avoid circular import
diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines")
for name, module in kwargs.items(): for name, module in kwargs.items():
# retrieve library # retrieve library
if module is None or isinstance(module, (tuple, list)) and module[0] is None: if module is None or isinstance(module, (tuple, list)) and module[0] is None:
register_dict = {name: (None, None)} register_dict = {name: (None, None)}
else: else:
# register the config from the original module, not the dynamo compiled one library, class_name = _fetch_class_library_tuple(module)
not_compiled_module = _unwrap_model(module)
library = not_compiled_module.__module__.split(".")[0]
# check if the module is a pipeline module
module_path_items = not_compiled_module.__module__.split(".")
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
path = not_compiled_module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if is_pipeline_module:
library = pipeline_dir
elif library not in LOADABLE_CLASSES:
library = not_compiled_module.__module__
# retrieve class_name
class_name = not_compiled_module.__class__.__name__
register_dict = {name: (library, class_name)} register_dict = {name: (library, class_name)}
# save model index config # save model index config
...@@ -601,7 +605,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -601,7 +605,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# We need to overwrite the config if name exists in config # We need to overwrite the config if name exists in config
if isinstance(getattr(self.config, name), (tuple, list)): if isinstance(getattr(self.config, name), (tuple, list)):
if value is not None and self.config[name][0] is not None: if value is not None and self.config[name][0] is not None:
class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__) class_library_tuple = _fetch_class_library_tuple(value)
else: else:
class_library_tuple = (None, None) class_library_tuple = (None, None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment