Unverified Commit c4c99c39 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

[tests] Fix broken cuda, nightly and lora tests on main for CogVideoX (#10270)

fix joint pos embedding device
parent 862a7d50
......@@ -691,7 +691,7 @@ class CogVideoXPatchEmbed(nn.Module):
output_type="pt",
)
pos_embedding = pos_embedding.flatten(0, 1)
joint_pos_embedding = torch.zeros(
joint_pos_embedding = pos_embedding.new_zeros(
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
)
joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment