Commit 0537139b authored by thomwolf

removing tf.function

parent 33cb00f4
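
For context on the change: tf.function traces a new concrete graph for each distinct Python-level input signature it sees, and the call methods below take nested lists/tuples/dicts plus a plain Python training flag, so decorating them invites repeated retracing while adding little, since Keras can compile call itself when a model runs in graph mode. A minimal sketch of the retracing behavior (a standalone illustration with a hypothetical ToyLayer, not code from this commit):

import tensorflow as tf

class ToyLayer(tf.keras.layers.Layer):
    """Hypothetical layer, used only to show tf.function retracing."""
    def __init__(self, **kwargs):
        super(ToyLayer, self).__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(4)

    @tf.function
    def call(self, inputs, training=False):
        # A bare Python print runs only while tf.function is tracing,
        # so each printed line below marks one (re)trace of the graph.
        print("tracing call, training =", training)
        return self.dense(inputs)

layer = ToyLayer()
x = tf.ones((2, 8))
layer(x, training=False)  # prints: traces for training=False
layer(x, training=True)   # prints again: retraces for training=True
layer(x, training=True)   # silent: reuses the cached graph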
@@ -164,7 +164,6 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
                     mean=0., stddev=self.hidden_size**-0.5))
         super(TFBertEmbeddings, self).build(input_shape)
 
-    # @tf.function
     def call(self, inputs, mode="embedding", training=False):
         """Get token embeddings of inputs.
         Args:
@@ -248,7 +247,6 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
         x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
         return tf.transpose(x, perm=[0, 2, 1, 3])
 
-    # @tf.function
     def call(self, inputs, training=False):
         hidden_states, attention_mask, head_mask = inputs
@@ -297,7 +295,6 @@ class TFBertSelfOutput(tf.keras.layers.Layer):
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
 
-    # @tf.function
     def call(self, inputs, training=False):
         hidden_states, input_tensor = inputs
@@ -317,7 +314,6 @@ class TFBertAttention(tf.keras.layers.Layer):
     def prune_heads(self, heads):
         raise NotImplementedError
 
-    # @tf.function
     def call(self, inputs, training=False):
         input_tensor, attention_mask, head_mask = inputs
@@ -336,7 +332,6 @@ class TFBertIntermediate(tf.keras.layers.Layer):
         else:
             self.intermediate_act_fn = config.hidden_act
 
-    # @tf.function
     def call(self, hidden_states):
         hidden_states = self.dense(hidden_states)
         hidden_states = self.intermediate_act_fn(hidden_states)
@@ -350,7 +345,6 @@ class TFBertOutput(tf.keras.layers.Layer):
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
 
-    # @tf.function
     def call(self, inputs, training=False):
         hidden_states, input_tensor = inputs
@@ -368,7 +362,6 @@ class TFBertLayer(tf.keras.layers.Layer):
         self.intermediate = TFBertIntermediate(config, name='intermediate')
         self.bert_output = TFBertOutput(config, name='output')
 
-    # @tf.function
     def call(self, inputs, training=False):
         hidden_states, attention_mask, head_mask = inputs
@@ -387,7 +380,6 @@ class TFBertEncoder(tf.keras.layers.Layer):
         self.output_hidden_states = config.output_hidden_states
         self.layer = [TFBertLayer(config, name='layer_{}'.format(i)) for i in range(config.num_hidden_layers)]
 
-    # @tf.function
     def call(self, inputs, training=False):
         hidden_states, attention_mask, head_mask = inputs
@@ -420,7 +412,6 @@ class TFBertPooler(tf.keras.layers.Layer):
         super(TFBertPooler, self).__init__(**kwargs)
         self.dense = tf.keras.layers.Dense(config.hidden_size, activation='tanh', name='dense')
 
-    # @tf.function
     def call(self, hidden_states):
         # We "pool" the model by simply taking the hidden state corresponding
         # to the first token.
@@ -439,7 +430,6 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
             self.transform_act_fn = config.hidden_act
         self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm')
 
-    # @tf.function
     def call(self, hidden_states):
         hidden_states = self.dense(hidden_states)
         hidden_states = self.transform_act_fn(hidden_states)
@@ -463,7 +453,6 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
                                      trainable=True,
                                      name='bias')
 
-    # @tf.function
     def call(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states) + self.bias
@@ -475,7 +464,6 @@ class TFBertMLMHead(tf.keras.layers.Layer):
         super(TFBertMLMHead, self).__init__(**kwargs)
         self.predictions = TFBertLMPredictionHead(config, name='predictions')
 
-    # @tf.function
     def call(self, sequence_output):
         prediction_scores = self.predictions(sequence_output)
         return prediction_scores
@@ -486,7 +474,6 @@ class TFBertNSPHead(tf.keras.layers.Layer):
         super(TFBertNSPHead, self).__init__(**kwargs)
         self.seq_relationship = tf.keras.layers.Dense(2, name='seq_relationship')
 
-    # @tf.function
     def call(self, pooled_output):
         seq_relationship_score = self.seq_relationship(pooled_output)
         return seq_relationship_score
@@ -511,7 +498,6 @@ class TFBertMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError
 
-    # @tf.function
     def call(self, inputs, training=False):
         if not isinstance(inputs, (dict, tuple, list)):
             input_ids = inputs
@@ -693,7 +679,6 @@ class TFBertModel(TFBertPreTrainedModel):
         super(TFBertModel, self).__init__(config, *inputs, **kwargs)
         self.bert = TFBertMainLayer(config, name='bert')
 
-    # @tf.function
     def call(self, inputs, training=False):
         outputs = self.bert(inputs, training=training)
         return outputs
@@ -732,7 +717,6 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
         self.bert = TFBertMainLayer(config, name='bert')
         self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
 
-    # @tf.function
     def call(self, inputs, training=False):
         outputs = self.bert(inputs, training=training)
@@ -774,7 +758,6 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
         self.bert = TFBertMainLayer(config, name='bert')
 
-    # @tf.function
     def call(self, inputs, training=False):
         outputs = self.bert(inputs, training=training)
@@ -818,7 +801,6 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
         self.bert = TFBertMainLayer(config, name='bert')
         self.cls_nsp = TFBertNSPHead(config, name='cls_nsp')
 
-    # @tf.function
     def call(self, inputs, training=False):
         outputs = self.bert(inputs, training=training)
@@ -863,7 +845,6 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel):
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
         self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier')
 
-    # @tf.function
     def call(self, inputs, training=False):
         outputs = self.bert(inputs, training=training)
@@ -912,7 +893,6 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
         self.classifier = tf.keras.layers.Dense(1, name='classifier')
 
-    # @tf.function
     def call(self, inputs, training=False):
         if not isinstance(inputs, (dict, tuple, list)):
             input_ids = inputs
@@ -989,7 +969,6 @@ class TFBertForTokenClassification(TFBertPreTrainedModel):
         self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
         self.classifier = tf.keras.layers.Dense(config.num_labels, name='classifier')
 
-    # @tf.function
     def call(self, inputs, training=False):
         outputs = self.bert(inputs, training=training)
@@ -1040,7 +1019,6 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel):
         self.bert = TFBertMainLayer(config, name='bert')
         self.qa_outputs = tf.keras.layers.Dense(config.num_labels, name='qa_outputs')
 
-    # @tf.function
     def call(self, inputs, training=False):
         outputs = self.bert(inputs, training=training)
......
@@ -143,7 +143,6 @@ class TFAttention(tf.keras.layers.Layer):
         pass
 
     @staticmethod
-    # @tf.function
     def causal_attention_mask(nd, ns, dtype):
         """1's in the lower triangle, counting from the lower right corner.
         Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
@@ -153,7 +152,6 @@ class TFAttention(tf.keras.layers.Layer):
         m = i >= j - ns + nd
         return tf.cast(m, dtype)
 
-    # @tf.function
     def _attn(self, inputs, training=False):
         q, k, v, attention_mask, head_mask = inputs
         # q, k, v have shape [batch, heads, sequence, features]
@@ -185,21 +183,18 @@ class TFAttention(tf.keras.layers.Layer):
             outputs.append(w)
         return outputs
 
-    # @tf.function
     def merge_heads(self, x):
         x = tf.transpose(x, [0, 2, 1, 3])
         x_shape = shape_list(x)
         new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]]
         return tf.reshape(x, new_x_shape)
 
-    # @tf.function
     def split_heads(self, x):
         x_shape = shape_list(x)
         new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head]
         x = tf.reshape(x, new_x_shape)
         return tf.transpose(x, (0, 2, 1, 3))  # (batch, head, seq_length, head_features)
 
-    # @tf.function
     def call(self, inputs, training=False):
         x, layer_past, attention_mask, head_mask = inputs
@@ -235,7 +230,6 @@ class TFMLP(tf.keras.layers.Layer):
         self.act = gelu
         self.dropout = tf.keras.layers.Dropout(config.resid_pdrop)
 
-    # @tf.function
     def call(self, x, training=False):
         h = self.act(self.c_fc(x))
         h2 = self.c_proj(h)
@@ -253,7 +247,6 @@ class TFBlock(tf.keras.layers.Layer):
         self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_2')
         self.mlp = TFMLP(4 * nx, config, name='mlp')
 
-    # @tf.function
     def call(self, inputs, training=False):
         x, layer_past, attention_mask, head_mask = inputs
@@ -289,7 +282,6 @@ class TFGPT2Embeddings(tf.keras.layers.Layer):
                     mean=0., stddev=self.hidden_size**-0.5))
         super(TFGPT2Embeddings, self).build(input_shape)
 
-    # @tf.function
     def call(self, inputs, mode="embedding"):
         """Get token embeddings of inputs.
         Args:
@@ -354,7 +346,6 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError
 
-    # @tf.function
     def call(self, inputs, training=False):
         if not isinstance(inputs, (dict, tuple, list)):
             input_ids = inputs
@@ -568,7 +559,6 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
         super(TFGPT2Model, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFGPT2MainLayer(config, name='transformer')
 
-    # @tf.function
     def call(self, inputs, training=False):
         outputs = self.transformer(inputs, training=training)
         return outputs
@@ -610,7 +600,6 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
         super(TFGPT2LMHeadModel, self).__init__(config, *inputs, **kwargs)
         self.transformer = TFGPT2MainLayer(config, name='transformer')
 
-    # @tf.function
     def call(self, inputs, training=False):
         transformer_outputs = self.transformer(inputs, training=training)
         hidden_states = transformer_outputs[0]
@@ -680,7 +669,6 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
         self.multiple_choice_head = TFSequenceSummary(config, name='multiple_choice_head')
 
-    # @tf.function
     def call(self, inputs, training=False):
         if not isinstance(inputs, (dict, tuple, list)):
             input_ids = inputs
......
@@ -277,7 +277,6 @@ class TFConv1D(tf.keras.layers.Layer):
                                   shape=[1, self.nf],
                                   initializer=tf.zeros_initializer())
 
-    @tf.function
     def call(self, x):
         bz, sl = shape_list(x)[:2]
@@ -334,7 +333,6 @@ class TFSequenceSummary(tf.keras.layers.Layer):
         if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
             self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)
 
-    @tf.function
     def call(self, inputs, training=False):
         """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
             cls_index: [optional] position of the classification token if summary_type == 'cls_index',
......
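
The causal_attention_mask helper in the TFAttention hunk above builds a lower-triangular mask from an index comparison instead of tf.matrix_band_part; as its docstring says, the two are equivalent but the comparison form doesn't produce garbage on TPUs. A minimal standalone check of that equivalence (the two tf.range lines are reconstructed from the original GPT-2 code, since the diff shows only the tail of the function; tf.linalg.band_part is the TF 2.x name for tf.matrix_band_part):

import tensorflow as tf

def causal_attention_mask(nd, ns, dtype):
    # 1's in the lower triangle, counting from the lower-right corner.
    i = tf.range(nd)[:, None]  # reconstructed line, not shown in the diff
    j = tf.range(ns)           # reconstructed line, not shown in the diff
    m = i >= j - ns + nd
    return tf.cast(m, dtype)

mask = causal_attention_mask(4, 6, tf.float32)
band = tf.linalg.band_part(tf.ones([4, 6]), -1, 6 - 4)
assert bool(tf.reduce_all(tf.equal(mask, band)))  # identical masks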