"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "88637044bee2e95a96c3f0a3a1e9022ace330ea9"
Unverified Commit 0bae6e44 authored by Haofan Wang's avatar Haofan Wang Committed by GitHub
Browse files

Allow from_transformer in SD3ControlNetModel (#8749)



* Update controlnet_sd3.py

---------
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 0368483b
......@@ -239,16 +239,16 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
module.gradient_checkpointing = value
@classmethod
def from_transformer(cls, transformer, num_layers=None, load_weights_from_transformer=True):
def from_transformer(cls, transformer, num_layers=12, load_weights_from_transformer=True):
config = transformer.config
config["num_layers"] = num_layers or config.num_layers
controlnet = cls(**config)
if load_weights_from_transformer:
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=False)
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict(), strict=False)
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict(), strict=False)
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict())
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment