"tests/vscode:/vscode.git/clone" did not exist on "30cef6bff344708734bb8173e19646c6a2d979b4"
Commit 4a086ad5 authored by Hongkun Yu, committed by A. Unique TensorFlower
Browse files

move bert_models.py into the bert folder.

PiperOrigin-RevId: 293415385
parent 6d9256ce
...@@ -30,38 +30,6 @@ from official.nlp.modeling.networks import bert_pretrainer ...@@ -30,38 +30,6 @@ from official.nlp.modeling.networks import bert_pretrainer
from official.nlp.modeling.networks import bert_span_labeler from official.nlp.modeling.networks import bert_span_labeler
def gather_indexes(sequence_tensor, positions):
    """Picks out the hidden vectors located at the requested token positions.

    Args:
      sequence_tensor: Sequence output of the `BertModel` layer, shaped
        (`batch_size`, `seq_length`, num_hidden), where num_hidden is the
        number of hidden units of the `BertModel` layer.
      positions: Position ids of the tokens masked for pretraining, shaped
        (batch_size, max_predictions_per_seq), where
        `max_predictions_per_seq` is the maximum number of tokens to mask out
        and predict per sequence.

    Returns:
      Masked out sequence tensor of shape
      (batch_size * max_predictions_per_seq, num_hidden).
    """
    dims = tf_utils.get_shape_list(
        sequence_tensor, name='sequence_output_tensor')
    batch_size, seq_length, width = dims[0], dims[1], dims[2]

    # Each example's rows begin at `example_index * seq_length` once the batch
    # and sequence axes are merged; keep the offsets as a column so they
    # broadcast across that example's positions.
    row_starts = tf.range(0, batch_size, dtype=tf.int32) * seq_length
    flat_offsets = tf.keras.backend.reshape(row_starts, [-1, 1])

    # Per-example positions become indices into the flattened sequence.
    absolute_positions = positions + flat_offsets
    flat_positions = tf.keras.backend.reshape(absolute_positions, [-1])

    # Merge the batch and sequence axes so a single gather selects every row.
    flat_sequence_tensor = tf.keras.backend.reshape(
        sequence_tensor, [batch_size * seq_length, width])
    return tf.gather(flat_sequence_tensor, flat_positions)
class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer): class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
"""Returns layer that computes custom loss and metrics for pretraining.""" """Returns layer that computes custom loss and metrics for pretraining."""
......
...@@ -24,7 +24,7 @@ import tensorflow as tf ...@@ -24,7 +24,7 @@ import tensorflow as tf
from typing import Optional, Text from typing import Optional, Text
from official.nlp import bert_modeling from official.nlp import bert_modeling
from official.nlp import bert_models from official.nlp.bert import bert_models
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -30,8 +30,8 @@ import tensorflow as tf ...@@ -30,8 +30,8 @@ import tensorflow as tf
# pylint: disable=g-import-not-at-top,redefined-outer-name,reimported # pylint: disable=g-import-not-at-top,redefined-outer-name,reimported
from official.modeling import model_training_utils from official.modeling import model_training_utils
from official.nlp import bert_modeling as modeling from official.nlp import bert_modeling as modeling
from official.nlp import bert_models
from official.nlp import optimization from official.nlp import optimization
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags from official.nlp.bert import common_flags
from official.nlp.bert import input_pipeline from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils from official.nlp.bert import model_saving_utils
......
...@@ -25,8 +25,8 @@ import tensorflow as tf ...@@ -25,8 +25,8 @@ import tensorflow as tf
# pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported # pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
from official.modeling import model_training_utils from official.modeling import model_training_utils
from official.nlp import bert_modeling as modeling from official.nlp import bert_modeling as modeling
from official.nlp import bert_models
from official.nlp import optimization from official.nlp import optimization
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags from official.nlp.bert import common_flags
from official.nlp.bert import input_pipeline from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils from official.nlp.bert import model_saving_utils
......
...@@ -29,8 +29,8 @@ import tensorflow as tf ...@@ -29,8 +29,8 @@ import tensorflow as tf
# pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported # pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
from official.modeling import model_training_utils from official.modeling import model_training_utils
from official.nlp import bert_modeling as modeling from official.nlp import bert_modeling as modeling
from official.nlp import bert_models
from official.nlp import optimization from official.nlp import optimization
from official.nlp.bert import bert_models
from official.nlp.bert import common_flags from official.nlp.bert import common_flags
from official.nlp.bert import input_pipeline from official.nlp.bert import input_pipeline
from official.nlp.bert import model_saving_utils from official.nlp.bert import model_saving_utils
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment