Unverified commit 64909f17, authored by YiYi Xu and committed by GitHub
Browse files

update IP-adapter code in UNetMotionModel (#6828)



fix
Co-authored-by: yiyixuxu <yixu310@gmail.com>
parent f09ca909
...@@ -792,6 +792,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): ...@@ -792,6 +792,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
emb = self.time_embedding(t_emb, timestep_cond) emb = self.time_embedding(t_emb, timestep_cond)
emb = emb.repeat_interleave(repeats=num_frames, dim=0) emb = emb.repeat_interleave(repeats=num_frames, dim=0)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
if "image_embeds" not in added_cond_kwargs: if "image_embeds" not in added_cond_kwargs:
...@@ -799,10 +800,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): ...@@ -799,10 +800,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
) )
image_embeds = added_cond_kwargs.get("image_embeds") image_embeds = added_cond_kwargs.get("image_embeds")
image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype) image_embeds = self.encoder_hid_proj(image_embeds)
encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1) image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds]
encoder_hidden_states = (encoder_hidden_states, image_embeds)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
# 2. pre-process # 2. pre-process
sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment