Unverified Commit b4be4228 authored by Teriks's avatar Teriks Committed by GitHub
Browse files

Kolors additional pipelines, community contrib (#11372)



* Kolors additional pipelines, community contrib

---------
Co-authored-by: default avatarTeriks <Teriks@users.noreply.github.com>
Co-authored-by: default avatarLinoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
parent a4f9c3cb
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -433,7 +433,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
if not is_sparse:
# down_weight is copied to each split
ait_sd.update({k: down_weight for k in ait_down_keys})
ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
# up_weight is split to each split
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
......@@ -923,7 +923,7 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
# down_weight is copied to each split
ait_sd.update({k: down_weight for k in ait_down_keys})
ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
# up_weight is split to each split
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
......
......@@ -469,7 +469,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
class_obj = import_flax_or_no_model(pipeline_module, class_name)
importable_classes = ALL_IMPORTABLE_CLASSES
class_candidates = {c: class_obj for c in importable_classes.keys()}
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
......
......@@ -341,13 +341,13 @@ def get_class_obj_and_candidates(
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
class_candidates = {c: class_obj for c in importable_classes.keys()}
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
# load custom component
class_obj = get_class_from_dynamic_module(
component_folder, module_file=library_name + ".py", class_name=class_name
)
class_candidates = {c: class_obj for c in importable_classes.keys()}
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
......
......@@ -205,7 +205,7 @@ class StableDiffusionDiffEditPipelineFastTests(
# set all optional components to None and update pipeline config accordingly
for optional_component in pipe._optional_components:
setattr(pipe, optional_component, None)
pipe.register_modules(**{optional_component: None for optional_component in pipe._optional_components})
pipe.register_modules(**dict.fromkeys(pipe._optional_components))
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment