Unverified Commit 07708793 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

fix encoder outputs (#8368)

parent bc0d26d1
......@@ -348,8 +348,7 @@ class TFGenerationMixin:
shape=(-1,),
)
# expand encoder_outputs
encoder_outputs = (tf.gather(encoder_outputs[0], expanded_batch_idxs, axis=0), *encoder_outputs[1:])
encoder_outputs = (tf.gather(encoder_outputs[0], expanded_batch_idxs, axis=0),)
else:
encoder_outputs = None
cur_len = shape_list(input_ids)[-1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment