Unverified Commit ab428207 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

Refactor CogVideoX transformer forward (#10789)

update
parent 8d081de8
...@@ -503,14 +503,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac ...@@ -503,14 +503,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
attention_kwargs=attention_kwargs, attention_kwargs=attention_kwargs,
) )
if not self.config.use_rotary_positional_embeddings: hidden_states = self.norm_final(hidden_states)
# CogVideoX-2B
hidden_states = self.norm_final(hidden_states)
else:
# CogVideoX-5B
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, text_seq_length:]
# 4. Final block # 4. Final block
hidden_states = self.norm_out(hidden_states, temb=emb) hidden_states = self.norm_out(hidden_states, temb=emb)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment