Unverified Commit c4c99c39 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

[tests] Fix broken cuda, nightly and lora tests on main for CogVideoX (#10270)

fix joint pos embedding device
parent 862a7d50
...@@ -691,7 +691,7 @@ class CogVideoXPatchEmbed(nn.Module): ...@@ -691,7 +691,7 @@ class CogVideoXPatchEmbed(nn.Module):
output_type="pt", output_type="pt",
) )
pos_embedding = pos_embedding.flatten(0, 1) pos_embedding = pos_embedding.flatten(0, 1)
joint_pos_embedding = torch.zeros( joint_pos_embedding = pos_embedding.new_zeros(
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
) )
joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding) joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment