Unverified Commit 1ab8dc44 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #1876 from huggingface/mean-fix

Mean does not exist in TF2
parents f0d22b63 3de31f8d
...@@ -460,7 +460,7 @@ class TFSequenceSummary(tf.keras.layers.Layer): ...@@ -460,7 +460,7 @@ class TFSequenceSummary(tf.keras.layers.Layer):
elif self.summary_type == 'first': elif self.summary_type == 'first':
output = hidden_states[:, 0] output = hidden_states[:, 0]
elif self.summary_type == 'mean': elif self.summary_type == 'mean':
output = tf.mean(hidden_states, axis=1) output = tf.reduce_mean(hidden_states, axis=1)
elif self.summary_type == 'cls_index': elif self.summary_type == 'cls_index':
hidden_shape = shape_list(hidden_states) # e.g. [batch, num choices, seq length, hidden dims] hidden_shape = shape_list(hidden_states) # e.g. [batch, num choices, seq length, hidden dims]
if cls_index is None: if cls_index is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment