Commit c9e8c519 authored by thomwolf

fixing SequenceSummary head in TF 2.0

parent da26bae6
@@ -394,8 +394,8 @@ class TFSequenceSummary(tf.keras.layers.Layer):
             # We can probably just use the multi-head attention module of PyTorch >=1.1.0
             raise NotImplementedError

-        self.summary = None
-        if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
+        self.has_summary = hasattr(config, 'summary_use_proj') and config.summary_use_proj
+        if self.has_summary:
             if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
                 num_classes = config.num_labels
             else:
@@ -404,16 +404,16 @@ class TFSequenceSummary(tf.keras.layers.Layer):
                                                  kernel_initializer=get_initializer(initializer_range),
                                                  name='summary')

-        self.activation = None
-        if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
+        self.has_activation = hasattr(config, 'summary_activation') and config.summary_activation == 'tanh'
+        if self.has_activation:
             self.activation = tf.keras.activations.tanh

-        self.first_dropout = None
-        if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
+        self.has_first_dropout = hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0
+        if self.has_first_dropout:
             self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)

-        self.last_dropout = None
-        if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
+        self.has_last_dropout = hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0
+        if self.has_last_dropout:
             self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)

     def call(self, inputs, training=False):
@@ -456,17 +456,17 @@ class TFSequenceSummary(tf.keras.layers.Layer):
         elif self.summary_type == 'attn':
             raise NotImplementedError

-        if training and self.first_dropout is not None:
-            output = self.first_dropout(output)
+        if self.has_first_dropout:
+            output = self.first_dropout(output, training=training)

-        if self.summary is not None:
+        if self.has_summary:
             output = self.summary(output)

-        if self.activation is not None:
+        if self.has_activation:
             output = self.activation(output)

-        if training and self.last_dropout is not None:
-            output = self.last_dropout(output)
+        if self.has_last_dropout:
+            output = self.last_dropout(output, training=training)

         return output
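For context, here is a minimal standalone sketch of the pattern this diff moves to: optional sub-layers are guarded by boolean has_* flags set in __init__, and the Keras training flag is forwarded to Dropout instead of wrapping the call in an `if training:` guard. The SimpleSummary class, its constructor arguments, and the toy shapes below are illustrative only and are not part of the transformers API.

import tensorflow as tf

class SimpleSummary(tf.keras.layers.Layer):
    # Illustrative layer (not from transformers): mirrors the has_* flag
    # pattern used above for optional sub-layers.
    def __init__(self, hidden_size, use_proj=True, dropout_rate=0.1, **kwargs):
        super(SimpleSummary, self).__init__(**kwargs)
        self.has_summary = use_proj
        if self.has_summary:
            self.summary = tf.keras.layers.Dense(hidden_size, name='summary')
        self.has_dropout = dropout_rate > 0
        if self.has_dropout:
            self.dropout = tf.keras.layers.Dropout(dropout_rate)

    def call(self, hidden_states, training=False):
        output = hidden_states[:, -1]  # 'last' token summary
        if self.has_dropout:
            # Forward `training` so the Dropout layer itself decides whether
            # to drop units; no `if training:` guard around the call.
            output = self.dropout(output, training=training)
        if self.has_summary:
            output = self.summary(output)
        return output

layer = SimpleSummary(hidden_size=8)
x = tf.random.normal((2, 5, 8))
print(layer(x, training=True).shape)  # -> (2, 8)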