Commit 9398058e authored by patrickvonplaten's avatar patrickvonplaten
Browse files

add easy tensor shape match test

parent 90cda45e
...@@ -923,7 +923,10 @@ class PreTrainedModel(nn.Module): ...@@ -923,7 +923,10 @@ class PreTrainedModel(nn.Module):
for layer_past in past: for layer_past in past:
# copy the relevant beam idx past to past # copy the relevant beam idx past to past
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx] reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
reordered_past.append(torch.cat(reordered_layer_past, dim=1)) reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
# check that shape matches
assert reordered_layer_past.shape == layer_past.shape
reordered_past.append(reordered_layer_past)
past = tuple(reordered_past) past = tuple(reordered_past)
# update current length # update current length
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment