"vscode:/vscode.git/clone" did not exist on "bcc6d97b69775ffa14cb37cb0cacd7b7e7a0b850"
Commit c69053e5 authored by Geoffrey Yu's avatar Geoffrey Yu
Browse files

added batch_size dimension

parent 15105078
...@@ -92,9 +92,15 @@ class TestPermutation(unittest.TestCase): ...@@ -92,9 +92,15 @@ class TestPermutation(unittest.TestCase):
add_recycling_dims = lambda t: ( add_recycling_dims = lambda t: (
t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters) t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
) )
add_batch_size_dimension = lambda t: (
t.unsqueeze(0)
)
batch = tensor_tree_map(add_recycling_dims, batch) batch = tensor_tree_map(add_recycling_dims, batch)
batch = tensor_tree_map(add_batch_size_dimension, batch)
for k,v in batch.items():
print(f"{k}:{v.shape}")
with torch.no_grad(): with torch.no_grad():
out = model(batch) out = model(batch)
permutated_labels = multimer_loss(out,(batch,example_label)) print(f"finished foward on batch with batch_size dim")
print(f"permuated_labels is {type(permutated_labels)} and keys are:\n {permutated_labels.keys()}") # permutated_labels = multimer_loss(out,(batch,example_label))
\ No newline at end of file # print(f"permuated_labels is {type(permutated_labels)} and keys are:\n {permutated_labels.keys()}")
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment