Commit 46ca66de authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

fix squeeze() in pad_and_concat

parent 0b896bd2
......@@ -468,7 +468,8 @@ def pad_and_concat(
), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
for i, tensor in enumerate(tensors):
tensor = tensor.squeeze(0) # squeeze, in case passed [1, seq] size
if len(tensor.shape) == 2:
tensor = tensor.squeeze(0) # squeeze, in case passed [1, seq] size
tensor_len = tensor.shape[0]
if tensor_len < max_length:
if padding_side == "right":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment