Unverified Commit 6b586a91 authored by Hongkun Yu, committed by GitHub

Merged commit includes the following changes: (#7370)

261393597 by hongkuny <hongkuny@google.com>:

    add an encoder mode for BertModel which returns all layers.

--

PiperOrigin-RevId: 261393597
parent 23c0017f
@@ -193,8 +193,17 @@ class BertModel(tf.keras.layers.Layer):
     inputs = pack_inputs([input_word_ids, input_mask, input_type_ids])
     return super(BertModel, self).__call__(inputs, **kwargs)
 
-  def call(self, inputs):
-    """Implements call() for the layer."""
+  def call(self, inputs, mode="bert"):
+    """Implements call() for the layer.
+
+    Args:
+      inputs: packed input tensors.
+      mode: string, `bert` or `encoder`.
+    Returns:
+      Output tensor of the last layer for BERT training (mode=`bert`) which
+      is a float Tensor of shape [batch_size, seq_length, hidden_size] or
+      a list of output tensors for encoder usage (mode=`encoder`).
+    """
     unpacked_inputs = unpack_inputs(inputs)
     input_word_ids = unpacked_inputs[0]
     input_mask = unpacked_inputs[1]
@@ -209,10 +218,13 @@ class BertModel(tf.keras.layers.Layer):
     if input_mask is not None:
       attention_mask = create_attention_mask_from_input_mask(
           input_word_ids, input_mask)
-    sequence_output = self.encoder(embedding_tensor, attention_mask)
-    first_token_tensor = tf.squeeze(sequence_output[:, 0:1, :], axis=1)
+    if mode == "encoder":
+      return self.encoder(
+          embedding_tensor, attention_mask, return_all_layers=True)
+
+    sequence_output = self.encoder(embedding_tensor, attention_mask)
+    first_token_tensor = tf.squeeze(sequence_output[:, 0:1, :], axis=1)
     pooled_output = self.pooler_transform(first_token_tensor)
     return (pooled_output, sequence_output)
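
A minimal usage sketch of the new mode (hypothetical: `bert_layer` and the three input tensors are assumed to be built elsewhere, not part of this commit). Since BertModel.__call__ forwards extra keyword arguments to call(), passing mode="encoder" switches the return value from the (pooled_output, sequence_output) pair to the list of per-layer encoder outputs:

    # Hypothetical names; assumes a constructed BertModel layer and int32
    # input tensors of shape [batch_size, seq_length].
    pooled_output, sequence_output = bert_layer(
        input_word_ids, input_mask, input_type_ids)  # default mode="bert"
    all_encoder_outputs = bert_layer(
        input_word_ids, input_mask, input_type_ids, mode="encoder")
    # all_encoder_outputs is a list with one tensor of shape
    # [batch_size, seq_length, hidden_size] per transformer layer; its
    # last element matches sequence_output above.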
@@ -803,16 +815,30 @@ class Transformer(tf.keras.layers.Layer):
     inputs = pack_inputs([input_tensor, attention_mask])
     return super(Transformer, self).__call__(inputs=inputs, **kwargs)
 
-  def call(self, inputs):
-    """Implements call() for the layer."""
+  def call(self, inputs, return_all_layers=False):
+    """Implements call() for the layer.
+
+    Args:
+      inputs: packed inputs.
+      return_all_layers: bool, whether to return outputs of all layers inside
+        encoders.
+    Returns:
+      Output tensor of the last layer or a list of output tensors.
+    """
     unpacked_inputs = unpack_inputs(inputs)
     input_tensor = unpacked_inputs[0]
     attention_mask = unpacked_inputs[1]
     output_tensor = input_tensor
+    all_layer_outputs = []
     for layer in self.layers:
       output_tensor = layer(output_tensor, attention_mask)
-    return output_tensor
+      all_layer_outputs.append(output_tensor)
+
+    if return_all_layers:
+      return all_layer_outputs
+
+    return all_layer_outputs[-1]
 
 
 def pack_inputs(inputs):
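
For reference, the all_layer_outputs pattern added to Transformer.call() can be sketched in a self-contained form (illustrative only: a small stack of Dense layers stands in for the real transformer blocks, and the class name is invented for this sketch):

    import tensorflow as tf

    class TinyStack(tf.keras.layers.Layer):
      """Toy stand-in for Transformer: a stack of Dense layers."""

      def __init__(self, num_layers=3, width=8):
        super(TinyStack, self).__init__()
        self.stack = [tf.keras.layers.Dense(width) for _ in range(num_layers)]

      def call(self, inputs, return_all_layers=False):
        output_tensor = inputs
        all_layer_outputs = []
        for layer in self.stack:
          output_tensor = layer(output_tensor)
          all_layer_outputs.append(output_tensor)
        if return_all_layers:
          return all_layer_outputs    # one tensor per layer, for encoder use
        return all_layer_outputs[-1]  # last layer only, as before this change

    layer = TinyStack()
    x = tf.ones([2, 4, 8])
    assert len(layer(x, return_all_layers=True)) == 3
    assert layer(x).shape == (2, 4, 8)

Accumulating every intermediate output and indexing [-1] for the default path keeps the single-output behavior identical to the pre-change code while exposing all layers to callers that ask for them.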