Commit 758f9d22 authored by patil-suraj's avatar patil-suraj
Browse files

add some comments

parent 2234877e
...@@ -113,7 +113,8 @@ class DiffusionPipeline(ConfigMixin): ...@@ -113,7 +113,8 @@ class DiffusionPipeline(ConfigMixin):
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None) use_auth_token = kwargs.pop("use_auth_token", None)
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained # use snapshot download here to get it working from from_pretrained
if not os.path.isdir(pretrained_model_name_or_path): if not os.path.isdir(pretrained_model_name_or_path):
cached_folder = snapshot_download( cached_folder = snapshot_download(
...@@ -129,11 +130,12 @@ class DiffusionPipeline(ConfigMixin): ...@@ -129,11 +130,12 @@ class DiffusionPipeline(ConfigMixin):
config_dict = cls.get_config_dict(cached_folder) config_dict = cls.get_config_dict(cached_folder)
module = config_dict["_module"] # 2. Get class name and module candidates to load custom models
class_name_ = config_dict["_class_name"] class_name_ = config_dict["_class_name"]
module_candidate = config_dict["_module"] module_candidate = config_dict["_module"]
module_candidate_name = module_candidate.replace(".py", "") module_candidate_name = module_candidate.replace(".py", "")
# 3. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it # if we load from explicit class, let's use it
if cls != DiffusionPipeline: if cls != DiffusionPipeline:
pipeline_class = cls pipeline_class = cls
...@@ -147,6 +149,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -147,6 +149,7 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs = {} init_kwargs = {}
# 4. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items(): for name, (library_name, class_name) in init_dict.items():
# if the model is not in diffusers or transformers, we need to load it from the hub # if the model is not in diffusers or transformers, we need to load it from the hub
# assumes that it's a subclass of ModelMixin # assumes that it's a subclass of ModelMixin
...@@ -156,6 +159,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -156,6 +159,7 @@ class DiffusionPipeline(ConfigMixin):
importable_classes = ALL_IMPORTABLE_CLASSES importable_classes = ALL_IMPORTABLE_CLASSES
class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()} class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()}
else: else:
# else we just import it from the library.
library = importlib.import_module(library_name) library = importlib.import_module(library_name)
class_obj = getattr(library, class_name) class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name] importable_classes = LOADABLE_CLASSES[library_name]
...@@ -168,12 +172,15 @@ class DiffusionPipeline(ConfigMixin): ...@@ -168,12 +172,15 @@ class DiffusionPipeline(ConfigMixin):
load_method = getattr(class_obj, load_method_name) load_method = getattr(class_obj, load_method_name)
# check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)): if os.path.isdir(os.path.join(cached_folder, name)):
loaded_sub_model = load_method(os.path.join(cached_folder, name)) loaded_sub_model = load_method(os.path.join(cached_folder, name))
else: else:
# else load from the root directory
loaded_sub_model = load_method(cached_folder) loaded_sub_model = load_method(cached_folder)
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
# 5. Instantiate the pipeline
model = pipeline_class(**init_kwargs) model = pipeline_class(**init_kwargs)
return model return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment