Unverified commit 75d7e5cc authored by hlky, committed by GitHub

Fix LatteTransformer3DModel dtype mismatch with enable_temporal_attentions (#11139)

parent 617c208b
@@ -273,7 +273,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
                 hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
                 if i == 0 and num_frame > 1:
-                    hidden_states = hidden_states + self.temp_pos_embed
+                    hidden_states = hidden_states + self.temp_pos_embed.to(hidden_states.dtype)
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = self._gradient_checkpointing_func(
......
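A minimal sketch of why the cast matters (tensor shapes here are illustrative, not taken from the model): when the model runs in half precision, `hidden_states` is fp16 while a buffer such as `temp_pos_embed` can remain fp32, and PyTorch's type promotion silently upcasts the sum to fp32, which then mismatches downstream fp16 layers.

```python
import torch

# hidden_states in fp16, as when the pipeline runs with torch_dtype=torch.float16
hidden_states = torch.randn(2, 16, 8, dtype=torch.float16)

# a positional-embedding buffer left in fp32 (illustrative stand-in for temp_pos_embed)
temp_pos_embed = torch.randn(16, 8, dtype=torch.float32)

# fp16 + fp32 promotes the result to fp32 -> dtype mismatch later in the model
promoted = hidden_states + temp_pos_embed
assert promoted.dtype == torch.float32

# the patched line casts the buffer to the activations' dtype first
fixed = hidden_states + temp_pos_embed.to(hidden_states.dtype)
assert fixed.dtype == torch.float16
```

Casting the buffer at the point of use (rather than converting it once at load time) keeps the fix local and works regardless of which dtype the caller selected.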