Commit c9e8c519 authored by thomwolf

fixing SequenceSummary head in TF 2.0

parent da26bae6
@@ -394,26 +394,26 @@ class TFSequenceSummary(tf.keras.layers.Layer):
             # We can probably just use the multi-head attention module of PyTorch >=1.1.0
             raise NotImplementedError

-        self.summary = None
-        if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
+        self.has_summary = hasattr(config, 'summary_use_proj') and config.summary_use_proj
+        if self.has_summary:
             if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
                 num_classes = config.num_labels
             else:
                 num_classes = config.hidden_size
             self.summary = tf.keras.layers.Dense(num_classes,
                                                  kernel_initializer=get_initializer(initializer_range),
                                                  name='summary')

-        self.activation = None
-        if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
+        self.has_activation = hasattr(config, 'summary_activation') and config.summary_activation == 'tanh'
+        if self.has_activation:
             self.activation = tf.keras.activations.tanh

-        self.first_dropout = None
-        if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
+        self.has_first_dropout = hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0
+        if self.has_first_dropout:
             self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)

-        self.last_dropout = None
-        if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
+        self.has_last_dropout = hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0
+        if self.has_last_dropout:
             self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)

     def call(self, inputs, training=False):
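The construction half of the fix, above, replaces the None-sentinel pattern (assign self.summary = None, later test is not None) with boolean has_* flags computed once in __init__, so call() can branch on plain Python booleans. A minimal sketch of the same pattern in isolation; ToyConfig and ToySummary are hypothetical names for illustration, not part of the commit:

import tensorflow as tf

class ToyConfig:
    # Attribute name mirrors the diff; the value is arbitrary.
    summary_first_dropout = 0.1

class ToySummary(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        # Compute the flag once; build the sub-layer only when it is needed.
        self.has_first_dropout = hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0
        if self.has_first_dropout:
            self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)

    def call(self, inputs, training=False):
        output = inputs
        if self.has_first_dropout:
            # Forward training= so the Dropout layer handles train/eval itself.
            output = self.first_dropout(output, training=training)
        return output

layer = ToySummary(ToyConfig())
print(layer(tf.ones((2, 3)), training=True))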
@@ -456,17 +456,17 @@ class TFSequenceSummary(tf.keras.layers.Layer):
         elif self.summary_type == 'attn':
             raise NotImplementedError

-        if training and self.first_dropout is not None:
-            output = self.first_dropout(output)
+        if self.has_first_dropout:
+            output = self.first_dropout(output, training=training)

-        if self.summary is not None:
+        if self.has_summary:
             output = self.summary(output)

-        if self.activation is not None:
+        if self.has_activation:
             output = self.activation(output)

-        if training and self.last_dropout is not None:
-            output = self.last_dropout(output)
+        if self.has_last_dropout:
+            output = self.last_dropout(output, training=training)

         return output
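The call() half, above, is the behavioral fix named in the commit title: the old code gated dropout with a Python if training ..., which skips dropout whenever training is falsy and can mis-handle a training argument that arrives as None or as a symbolic tensor under graph tracing; the new code always calls the layer and forwards training=training, letting tf.keras.layers.Dropout decide. A small demonstration of that Keras behavior (the zeroed positions in training mode are random; printed values are illustrative):

import tensorflow as tf

dropout = tf.keras.layers.Dropout(rate=0.5)
x = tf.ones((1, 4))

# Inference mode: Dropout is the identity.
print(dropout(x, training=False).numpy())  # [[1. 1. 1. 1.]]

# Training mode: random units are zeroed, survivors scaled by 1/(1 - rate).
print(dropout(x, training=True).numpy())   # e.g. [[2. 0. 2. 0.]]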