Unverified Commit ab428207 authored by Aryan's avatar Aryan Committed by GitHub

Refactor CogVideoX transformer forward (#10789)

update
parent 8d081de8
@@ -503,14 +503,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
             attention_kwargs=attention_kwargs,
         )
 
-        if not self.config.use_rotary_positional_embeddings:
-            # CogVideoX-2B
-            hidden_states = self.norm_final(hidden_states)
-        else:
-            # CogVideoX-5B
-            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-            hidden_states = self.norm_final(hidden_states)
-            hidden_states = hidden_states[:, text_seq_length:]
+        hidden_states = self.norm_final(hidden_states)
 
         # 4. Final block
         hidden_states = self.norm_out(hidden_states, temb=emb)
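The two branches collapse into one because, assuming `norm_final` is a `torch.nn.LayerNorm` over the channel dimension (as in the diffusers CogVideoX implementation), normalization is applied per token: concatenating the text tokens along the sequence dimension, normalizing, and slicing them back off produces the same result as normalizing the video tokens directly. A minimal sketch of that equivalence, with made-up shapes for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
norm_final = nn.LayerNorm(16)  # normalizes over the last (channel) dim only

text_seq_length, video_seq_length = 4, 8
encoder_hidden_states = torch.randn(2, text_seq_length, 16)  # text tokens
hidden_states = torch.randn(2, video_seq_length, 16)         # video tokens

# Old 5B path: concatenate text tokens, normalize, then slice them back off.
old = norm_final(torch.cat([encoder_hidden_states, hidden_states], dim=1))
old = old[:, text_seq_length:]

# New unified path: normalize the video tokens directly.
new = norm_final(hidden_states)

# LayerNorm statistics are computed per token, so both paths agree.
assert torch.allclose(old, new, atol=1e-6)
```

Note this would not hold for a normalization that mixes statistics across the sequence dimension; the refactor relies on LayerNorm being token-wise.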