Unverified Commit f28a8c25 authored by captainzz's avatar captainzz Committed by GitHub
Browse files

fix from_transformer() with extra conditioning channels (#9364)



* fix from_transformer() with extra conditioning channels

* style fix

---------
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
Co-authored-by: default avatarÁlvaro Somoza <somoza.alvaro@gmail.com>
parent 2c6a6c97
...@@ -242,9 +242,12 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal ...@@ -242,9 +242,12 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
module.gradient_checkpointing = value module.gradient_checkpointing = value
@classmethod @classmethod
def from_transformer(cls, transformer, num_layers=12, load_weights_from_transformer=True): def from_transformer(
cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True
):
config = transformer.config config = transformer.config
config["num_layers"] = num_layers or config.num_layers config["num_layers"] = num_layers or config.num_layers
config["extra_conditioning_channels"] = num_extra_conditioning_channels
controlnet = cls(**config) controlnet = cls(**config)
if load_weights_from_transformer: if load_weights_from_transformer:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment