Commit 2529b2d3 authored by patrickvonplaten's avatar patrickvonplaten Committed by Patrick von Platen
Browse files

set reorder past sort dimension to its default

parent 61fef6e9
......@@ -941,9 +941,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
for layer_past in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[i], 0)) for i in beam_idx]
reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx]
# TODO: check whether it is an error that TF past.shape != Torch past.shape
reordered_layer_past = tf.concat(reordered_layer_past, axis=0)
reordered_layer_past = tf.concat(reordered_layer_past, axis=1)
# check that shape matches
assert shape_list(reordered_layer_past) == shape_list(layer_past)
reordered_past.append(reordered_layer_past)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment