".github/vscode:/vscode.git/clone" did not exist on "627d44465ea30f66bb9c62d91230043607fdb811"
Commit 46ca66de authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

fix squeeze() in pad_and_concat

parent 0b896bd2
...@@ -468,7 +468,8 @@ def pad_and_concat( ...@@ -468,7 +468,8 @@ def pad_and_concat(
), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'" ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
for i, tensor in enumerate(tensors): for i, tensor in enumerate(tensors):
tensor = tensor.squeeze(0) # squeeze, in case passed [1, seq] size if len(tensor.shape) == 2:
tensor = tensor.squeeze(0) # squeeze, in case passed [1, seq] size
tensor_len = tensor.shape[0] tensor_len = tensor.shape[0]
if tensor_len < max_length: if tensor_len < max_length:
if padding_side == "right": if padding_side == "right":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment