Unverified Commit d8310a8f authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[lora] factor out the overlaps in `save_lora_weights()`. (#12027)

* factor out the overlaps in save_lora_weights().

* remove comment.

* remove comment.

* up

* fix-copies
parent 78031c29
...@@ -1064,6 +1064,41 @@ class LoraBaseMixin: ...@@ -1064,6 +1064,41 @@ class LoraBaseMixin:
save_function(state_dict, save_path) save_function(state_dict, save_path)
logger.info(f"Model weights saved in {save_path}") logger.info(f"Model weights saved in {save_path}")
@classmethod
def _save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
lora_layers: Dict[str, Dict[str, Union[torch.nn.Module, torch.Tensor]]],
lora_metadata: Dict[str, Optional[dict]],
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
):
"""
Helper method to pack and save LoRA weights and metadata. This method centralizes the saving logic for all
pipeline types.
"""
state_dict = {}
final_lora_adapter_metadata = {}
for prefix, layers in lora_layers.items():
state_dict.update(cls.pack_weights(layers, prefix))
for prefix, metadata in lora_metadata.items():
if metadata:
final_lora_adapter_metadata.update(_pack_dict_with_prefix(metadata, prefix))
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=final_lora_adapter_metadata if final_lora_adapter_metadata else None,
)
@classmethod @classmethod
def _optionally_disable_offloading(cls, _pipeline): def _optionally_disable_offloading(cls, _pipeline):
return _func_optionally_disable_offloading(_pipeline=_pipeline) return _func_optionally_disable_offloading(_pipeline=_pipeline)
...@@ -510,35 +510,28 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin): ...@@ -510,35 +510,28 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
text_encoder_lora_adapter_metadata: text_encoder_lora_adapter_metadata:
LoRA adapter metadata associated with the text encoder to be serialized with the state dict. LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
""" """
state_dict = {} lora_layers = {}
lora_adapter_metadata = {} lora_metadata = {}
if not (unet_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")
if unet_lora_layers: if unet_lora_layers:
state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name)) lora_layers[cls.unet_name] = unet_lora_layers
lora_metadata[cls.unet_name] = unet_lora_adapter_metadata
if text_encoder_lora_layers: if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) lora_layers[cls.text_encoder_name] = text_encoder_lora_layers
lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata
if unet_lora_adapter_metadata: if not lora_layers:
lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name)) raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.")
if text_encoder_lora_adapter_metadata:
lora_adapter_metadata.update(
_pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
)
# Save the model cls._save_lora_weights(
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory, save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process, is_main_process=is_main_process,
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
def fuse_lora( def fuse_lora(
...@@ -1004,44 +997,34 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin): ...@@ -1004,44 +997,34 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
text_encoder_2_lora_adapter_metadata: text_encoder_2_lora_adapter_metadata:
LoRA adapter metadata associated with the second text encoder to be serialized with the state dict. LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
""" """
state_dict = {} lora_layers = {}
lora_adapter_metadata = {} lora_metadata = {}
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
raise ValueError(
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
)
if unet_lora_layers: if unet_lora_layers:
state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name)) lora_layers[cls.unet_name] = unet_lora_layers
lora_metadata[cls.unet_name] = unet_lora_adapter_metadata
if text_encoder_lora_layers: if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder")) lora_layers["text_encoder"] = text_encoder_lora_layers
lora_metadata["text_encoder"] = text_encoder_lora_adapter_metadata
if text_encoder_2_lora_layers: if text_encoder_2_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) lora_layers["text_encoder_2"] = text_encoder_2_lora_layers
lora_metadata["text_encoder_2"] = text_encoder_2_lora_adapter_metadata
if unet_lora_adapter_metadata is not None: if not lora_layers:
lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name)) raise ValueError(
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`."
if text_encoder_lora_adapter_metadata:
lora_adapter_metadata.update(
_pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
)
if text_encoder_2_lora_adapter_metadata:
lora_adapter_metadata.update(
_pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
) )
cls.write_lora_layers( cls._save_lora_weights(
state_dict=state_dict,
save_directory=save_directory, save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process, is_main_process=is_main_process,
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
def fuse_lora( def fuse_lora(
...@@ -1467,46 +1450,34 @@ class SD3LoraLoaderMixin(LoraBaseMixin): ...@@ -1467,46 +1450,34 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
text_encoder_2_lora_adapter_metadata: text_encoder_2_lora_adapter_metadata:
LoRA adapter metadata associated with the second text encoder to be serialized with the state dict. LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
""" """
state_dict = {} lora_layers = {}
lora_adapter_metadata = {} lora_metadata = {}
if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
raise ValueError(
"You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
)
if transformer_lora_layers: if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if text_encoder_lora_layers: if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder")) lora_layers["text_encoder"] = text_encoder_lora_layers
lora_metadata["text_encoder"] = text_encoder_lora_adapter_metadata
if text_encoder_2_lora_layers: if text_encoder_2_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) lora_layers["text_encoder_2"] = text_encoder_2_lora_layers
lora_metadata["text_encoder_2"] = text_encoder_2_lora_adapter_metadata
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
if text_encoder_lora_adapter_metadata:
lora_adapter_metadata.update(
_pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
)
if text_encoder_2_lora_adapter_metadata: if not lora_layers:
lora_adapter_metadata.update( raise ValueError(
_pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2") "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`."
) )
cls.write_lora_layers( cls._save_lora_weights(
state_dict=state_dict,
save_directory=save_directory, save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process, is_main_process=is_main_process,
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
...@@ -1830,28 +1801,24 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): ...@@ -1830,28 +1801,24 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict. LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" """
state_dict = {} lora_layers = {}
lora_adapter_metadata = {} lora_metadata = {}
if not transformer_lora_layers: if transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.") lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None: if not lora_layers:
lora_adapter_metadata.update( raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model cls._save_lora_weights(
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory, save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process, is_main_process=is_main_process,
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
...@@ -2435,37 +2402,28 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -2435,37 +2402,28 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
text_encoder_lora_adapter_metadata: text_encoder_lora_adapter_metadata:
LoRA adapter metadata associated with the text encoder to be serialized with the state dict. LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
""" """
state_dict = {} lora_layers = {}
lora_adapter_metadata = {} lora_metadata = {}
if not (transformer_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
if transformer_lora_layers: if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if text_encoder_lora_layers: if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) lora_layers[cls.text_encoder_name] = text_encoder_lora_layers
lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata
if transformer_lora_adapter_metadata: if not lora_layers:
lora_adapter_metadata.update( raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
if text_encoder_lora_adapter_metadata:
lora_adapter_metadata.update(
_pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
)
# Save the model cls._save_lora_weights(
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory, save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process, is_main_process=is_main_process,
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
def fuse_lora( def fuse_lora(
...@@ -3254,28 +3212,24 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin): ...@@ -3254,28 +3212,24 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict. LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" """
state_dict = {} lora_layers = {}
lora_adapter_metadata = {} lora_metadata = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if transformer_lora_adapter_metadata is not None: if not lora_layers:
lora_adapter_metadata.update( raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model cls._save_lora_weights(
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory, save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process, is_main_process=is_main_process,
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
def fuse_lora( def fuse_lora(
...@@ -3594,28 +3548,24 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin): ...@@ -3594,28 +3548,24 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict. LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" """
state_dict = {} lora_layers = {}
lora_adapter_metadata = {} lora_metadata = {}
if not transformer_lora_layers: if transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.") lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None: if not lora_layers:
lora_adapter_metadata.update( raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model cls._save_lora_weights(
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory, save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process, is_main_process=is_main_process,
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
...@@ -3938,28 +3888,24 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -3938,28 +3888,24 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict. LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" """
state_dict = {} lora_layers = {}
lora_adapter_metadata = {} lora_metadata = {}
if not transformer_lora_layers: if transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.") lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None: if not lora_layers:
lora_adapter_metadata.update( raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model cls._save_lora_weights(
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory, save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process, is_main_process=is_main_process,
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
...@@ -4280,28 +4226,24 @@ class SanaLoraLoaderMixin(LoraBaseMixin): ...@@ -4280,28 +4226,24 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict. LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" """
state_dict = {} lora_layers = {}
lora_adapter_metadata = {} lora_metadata = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if transformer_lora_adapter_metadata is not None: if not lora_layers:
lora_adapter_metadata.update( raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model cls._save_lora_weights(
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory, save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process, is_main_process=is_main_process,
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
...@@ -4624,28 +4566,24 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -4624,28 +4566,24 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict. LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" """
state_dict = {} lora_layers = {}
lora_adapter_metadata = {} lora_metadata = {}
if not transformer_lora_layers: if transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.") lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None: if not lora_layers:
lora_adapter_metadata.update( raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model cls._save_lora_weights(
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory, save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process, is_main_process=is_main_process,
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
...@@ -4969,28 +4907,24 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin): ...@@ -4969,28 +4907,24 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict. LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" """
state_dict = {} lora_layers = {}
lora_adapter_metadata = {} lora_metadata = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if transformer_lora_adapter_metadata is not None: if not lora_layers:
lora_adapter_metadata.update( raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model cls._save_lora_weights(
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory, save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process, is_main_process=is_main_process,
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
...@@ -5384,28 +5318,24 @@ class WanLoraLoaderMixin(LoraBaseMixin): ...@@ -5384,28 +5318,24 @@ class WanLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict. LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" """
state_dict = {} lora_layers = {}
lora_adapter_metadata = {} lora_metadata = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if transformer_lora_adapter_metadata is not None: if not lora_layers:
lora_adapter_metadata.update( raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model cls._save_lora_weights(
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory, save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process, is_main_process=is_main_process,
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
...@@ -5802,28 +5732,24 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin): ...@@ -5802,28 +5732,24 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict. LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" """
state_dict = {} lora_layers = {}
lora_adapter_metadata = {} lora_metadata = {}
if not transformer_lora_layers: if transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.") lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None: if not lora_layers:
lora_adapter_metadata.update( raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model cls._save_lora_weights(
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory, save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process, is_main_process=is_main_process,
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
...@@ -6144,28 +6070,24 @@ class CogView4LoraLoaderMixin(LoraBaseMixin): ...@@ -6144,28 +6070,24 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict. LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" """
state_dict = {} lora_layers = {}
lora_adapter_metadata = {} lora_metadata = {}
if not transformer_lora_layers: if transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.") lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None: if not lora_layers:
lora_adapter_metadata.update( raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model cls._save_lora_weights(
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory, save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process, is_main_process=is_main_process,
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
...@@ -6488,28 +6410,24 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin): ...@@ -6488,28 +6410,24 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict. LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" """
state_dict = {} lora_layers = {}
lora_adapter_metadata = {} lora_metadata = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if transformer_lora_adapter_metadata is not None: if not lora_layers:
lora_adapter_metadata.update( raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model cls._save_lora_weights(
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory, save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process, is_main_process=is_main_process,
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
...@@ -6835,28 +6753,24 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin): ...@@ -6835,28 +6753,24 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict. LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" """
state_dict = {} lora_layers = {}
lora_adapter_metadata = {} lora_metadata = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if transformer_lora_adapter_metadata is not None: if not lora_layers:
lora_adapter_metadata.update( raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model cls._save_lora_weights(
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory, save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process, is_main_process=is_main_process,
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment