Unverified Commit 8c4856cd authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[LoRA] fix: animate diff lora stuff. (#8995)

* fix: animate diff lora stuff.

* fix scaling function for UNetMotionModel

* emoty
parent f240a936
......@@ -30,6 +30,7 @@ from .unet_loader_utils import _maybe_expand_lora_scales
_SET_ADAPTER_SCALE_FN_MAPPING = {
"UNet2DConditionModel": _maybe_expand_lora_scales,
"UNetMotionModel": _maybe_expand_lora_scales,
"SD3Transformer2DModel": lambda model_cls, weights: weights,
}
......
......@@ -19,7 +19,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, UNet2DConditionLoadersMixin
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import logging
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
......@@ -231,7 +231,7 @@ class MotionAdapter(ModelMixin, ConfigMixin, FromOriginalModelMixin):
pass
class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
r"""
A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a
sample shaped output.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment