Commit b3aab393 authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

squeeze tensors before pad_and_concat

parent 513352ae
...@@ -434,6 +434,7 @@ def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="r ...@@ -434,6 +434,7 @@ def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="r
), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'" ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
for i, tensor in enumerate(tensors): for i, tensor in enumerate(tensors):
tensor = tensor.squeeze(0) # squeeze, in case passed [1, seq] size
tensor_len = tensor.shape[0] tensor_len = tensor.shape[0]
if tensor_len < max_length: if tensor_len < max_length:
if padding_side == "right": if padding_side == "right":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment