"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "6ca65b7b9ffbeb2d9c56b201e6d9019940a1ed2c"
Commit 0066ae22 authored by Hongkun Yu, committed by A. Unique TensorFlower
Browse files

Creates a models folder to host all keras.Model end points.

PiperOrigin-RevId: 298905114
parent d9ae2108
@@ -25,10 +25,8 @@ from official.modeling import tf_utils
 from official.nlp.albert import configs as albert_configs
 from official.nlp.bert import configs
 from official.nlp.modeling import losses
+from official.nlp.modeling import models
 from official.nlp.modeling import networks
-from official.nlp.modeling.networks import bert_classifier
-from official.nlp.modeling.networks import bert_pretrainer
-from official.nlp.modeling.networks import bert_span_labeler
 class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
@@ -159,7 +157,7 @@ def pretrain_model(bert_config,
   if initializer is None:
     initializer = tf.keras.initializers.TruncatedNormal(
         stddev=bert_config.initializer_range)
-  pretrainer_model = bert_pretrainer.BertPretrainer(
+  pretrainer_model = models.BertPretrainer(
       network=transformer_encoder,
       num_classes=2,  # The next sentence prediction label has two classes.
       num_token_predictions=max_predictions_per_seq,
@@ -211,7 +209,7 @@ def squad_model(bert_config,
         stddev=bert_config.initializer_range)
   if not hub_module_url:
     bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
-    return bert_span_labeler.BertSpanLabeler(
+    return models.BertSpanLabeler(
         network=bert_encoder, initializer=initializer), bert_encoder
   input_word_ids = tf.keras.layers.Input(
@@ -231,7 +229,7 @@ def squad_model(bert_config,
       },
       outputs=[sequence_output, pooled_output],
       name='core_model')
-  return bert_span_labeler.BertSpanLabeler(
+  return models.BertSpanLabeler(
       network=bert_encoder, initializer=initializer), bert_encoder
@@ -268,7 +266,7 @@ def classifier_model(bert_config,
   if not hub_module_url:
     bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
-    return bert_classifier.BertClassifier(
+    return models.BertClassifier(
        bert_encoder,
        num_classes=num_labels,
        dropout_rate=bert_config.hidden_dropout_prob,
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Models package definition."""
from official.nlp.modeling.models.bert_classifier import BertClassifier
from official.nlp.modeling.models.bert_pretrainer import BertPretrainer
from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler
@@ -22,7 +22,7 @@ import tensorflow as tf
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
 from official.nlp.modeling import networks
-from official.nlp.modeling.networks import bert_classifier
+from official.nlp.modeling.models import bert_classifier
 # This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
......
@@ -22,7 +22,7 @@ import tensorflow as tf
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
 from official.nlp.modeling import networks
-from official.nlp.modeling.networks import bert_pretrainer
+from official.nlp.modeling.models import bert_pretrainer
 # This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
......
@@ -22,7 +22,7 @@ import tensorflow as tf
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
 from official.nlp.modeling import networks
-from official.nlp.modeling.networks import bert_span_labeler
+from official.nlp.modeling.models import bert_span_labeler
 # This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment