Commit 0066ae22 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Creates a models folder to host all keras.Model end points.

PiperOrigin-RevId: 298905114
parent d9ae2108
......@@ -25,10 +25,8 @@ from official.modeling import tf_utils
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import configs
from official.nlp.modeling import losses
from official.nlp.modeling import models
from official.nlp.modeling import networks
from official.nlp.modeling.networks import bert_classifier
from official.nlp.modeling.networks import bert_pretrainer
from official.nlp.modeling.networks import bert_span_labeler
class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
......@@ -159,7 +157,7 @@ def pretrain_model(bert_config,
if initializer is None:
initializer = tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range)
pretrainer_model = bert_pretrainer.BertPretrainer(
pretrainer_model = models.BertPretrainer(
network=transformer_encoder,
num_classes=2, # The next sentence prediction label has two classes.
num_token_predictions=max_predictions_per_seq,
......@@ -211,7 +209,7 @@ def squad_model(bert_config,
stddev=bert_config.initializer_range)
if not hub_module_url:
bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
return bert_span_labeler.BertSpanLabeler(
return models.BertSpanLabeler(
network=bert_encoder, initializer=initializer), bert_encoder
input_word_ids = tf.keras.layers.Input(
......@@ -231,7 +229,7 @@ def squad_model(bert_config,
},
outputs=[sequence_output, pooled_output],
name='core_model')
return bert_span_labeler.BertSpanLabeler(
return models.BertSpanLabeler(
network=bert_encoder, initializer=initializer), bert_encoder
......@@ -268,7 +266,7 @@ def classifier_model(bert_config,
if not hub_module_url:
bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
return bert_classifier.BertClassifier(
return models.BertClassifier(
bert_encoder,
num_classes=num_labels,
dropout_rate=bert_config.hidden_dropout_prob,
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Models package definition."""
from official.nlp.modeling.models.bert_classifier import BertClassifier
from official.nlp.modeling.models.bert_pretrainer import BertPretrainer
from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler
......@@ -22,7 +22,7 @@ import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import networks
from official.nlp.modeling.networks import bert_classifier
from official.nlp.modeling.models import bert_classifier
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
......
......@@ -22,7 +22,7 @@ import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import networks
from official.nlp.modeling.networks import bert_pretrainer
from official.nlp.modeling.models import bert_pretrainer
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
......
......@@ -22,7 +22,7 @@ import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import networks
from official.nlp.modeling.networks import bert_span_labeler
from official.nlp.modeling.models import bert_span_labeler
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment