Unverified Commit 97d004b9 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

[ip-adapter] make sure length of `scale` is same as number of ip-adapters when...


[ip-adapter] make sure length of `scale` is same as number of ip-adapters when using  `set_ip_adapter_scale` (#6884)

add
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
parent 76696dca
...@@ -181,11 +181,16 @@ class IPAdapterMixin: ...@@ -181,11 +181,16 @@ class IPAdapterMixin:
unet._load_ip_adapter_weights(state_dicts) unet._load_ip_adapter_weights(state_dicts)
def set_ip_adapter_scale(self, scale): def set_ip_adapter_scale(self, scale):
if not isinstance(scale, list):
scale = [scale]
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
for attn_processor in unet.attn_processors.values(): for attn_processor in unet.attn_processors.values():
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)): if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
if not isinstance(scale, list):
scale = [scale] * len(attn_processor.scale)
if len(attn_processor.scale) != len(scale):
raise ValueError(
f"`scale` should be a list of same length as the number if ip-adapters "
f"Expected {len(attn_processor.scale)} but got {len(scale)}."
)
attn_processor.scale = scale attn_processor.scale = scale
def unload_ip_adapter(self): def unload_ip_adapter(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment