Commit c69053e5 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

added batch_size dimension

parent 15105078
......@@ -92,9 +92,15 @@ class TestPermutation(unittest.TestCase):
add_recycling_dims = lambda t: (
t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
)
add_batch_size_dimension = lambda t: (
t.unsqueeze(0)
)
batch = tensor_tree_map(add_recycling_dims, batch)
batch = tensor_tree_map(add_batch_size_dimension, batch)
for k,v in batch.items():
print(f"{k}:{v.shape}")
with torch.no_grad():
out = model(batch)
permutated_labels = multimer_loss(out,(batch,example_label))
print(f"permuated_labels is {type(permutated_labels)} and keys are:\n {permutated_labels.keys()}")
\ No newline at end of file
print(f"finished foward on batch with batch_size dim")
# permutated_labels = multimer_loss(out,(batch,example_label))
# print(f"permuated_labels is {type(permutated_labels)} and keys are:\n {permutated_labels.keys()}")
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment