Commit 984be23d authored by Hongkun Yu, committed by A. Unique TensorFlower

Move AlbertConfig to albert folder.

PiperOrigin-RevId: 294350927
parent 347f4044
official/nlp/albert/configs.py (new file):

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The ALBERT configurations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from official.nlp import bert_modeling
class AlbertConfig(bert_modeling.BertConfig):
"""Configuration for `ALBERT`."""
def __init__(self,
embedding_size,
num_hidden_groups=1,
inner_group_num=1,
**kwargs):
"""Constructs AlbertConfig.
Args:
embedding_size: Size of the factorized word embeddings.
num_hidden_groups: Number of group for the hidden layers, parameters in
the same group are shared. Note that this value and also the following
'inner_group_num' has to be 1 for now, because all released ALBERT
models set them to 1. We may support arbitary valid values in future.
inner_group_num: Number of inner repetition of attention and ffn.
**kwargs: The remaining arguments are the same as above 'BertConfig'.
"""
super(AlbertConfig, self).__init__(**kwargs)
self.embedding_size = embedding_size
# TODO(chendouble): 'inner_group_num' and 'num_hidden_groups' are always 1
# in the released ALBERT. Support other values in AlbertTransformerEncoder
# if needed.
if inner_group_num != 1 or num_hidden_groups != 1:
raise ValueError("We only support 'inner_group_num' and "
"'num_hidden_groups' as 1.")
@classmethod
def from_dict(cls, json_object):
"""Constructs a `AlbertConfig` from a Python dictionary of parameters."""
config = AlbertConfig(embedding_size=None, vocab_size=None)
for (key, value) in six.iteritems(json_object):
config.__dict__[key] = value
return config
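As a quick illustration of the relocated class, here is a minimal usage sketch; the hyperparameter values are invented for the example and do not correspond to any released ALBERT checkpoint:

# Minimal sketch; hyperparameter values are illustrative only.
from official.nlp.albert import configs

albert_config = configs.AlbertConfig(
    embedding_size=128,   # factorized word-embedding size
    vocab_size=30000,     # this and the rest are inherited BertConfig kwargs
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12)

# from_dict copies every key, so a config survives a dict round trip.
restored = configs.AlbertConfig.from_dict(albert_config.__dict__)
assert restored.embedding_size == albert_config.embedding_size

Passing num_hidden_groups or inner_group_num other than 1 raises the ValueError above, since all released ALBERT models use 1 for both.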
official/nlp/albert/export_albert_tfhub.py:

@@ -23,7 +23,7 @@ from absl import flags
 import tensorflow as tf
 from typing import Text

-from official.nlp import bert_modeling
+from official.nlp.albert import configs
 from official.nlp.bert import bert_models

 FLAGS = flags.FLAGS

@@ -39,7 +39,7 @@ flags.DEFINE_string(

 def create_albert_model(
-    albert_config: bert_modeling.AlbertConfig) -> tf.keras.Model:
+    albert_config: configs.AlbertConfig) -> tf.keras.Model:
   """Creates an ALBERT keras core model from ALBERT configuration.

   Args:

@@ -66,7 +66,7 @@ def create_albert_model(
       outputs=[pooled_output, sequence_output]), transformer_encoder


-def export_albert_tfhub(albert_config: bert_modeling.AlbertConfig,
+def export_albert_tfhub(albert_config: configs.AlbertConfig,
                         model_checkpoint_path: Text, hub_destination: Text,
                         sp_model_file: Text):
   """Restores a tf.keras.Model and saves for TF-Hub."""

@@ -79,7 +79,7 @@ def export_albert_tfhub(albert_config: bert_modeling.AlbertConfig,

 def main(_):
   assert tf.version.VERSION.startswith('2.')
-  albert_config = bert_modeling.AlbertConfig.from_json_file(
+  albert_config = configs.AlbertConfig.from_json_file(
       FLAGS.albert_config_file)
   export_albert_tfhub(albert_config, FLAGS.model_checkpoint_path,
                       FLAGS.export_path, FLAGS.sp_model_file)
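For reference, a hedged sketch of the programmatic path main() now follows; the file paths are hypothetical placeholders:

# Mirrors main() above; paths are placeholders, not real files.
from official.nlp.albert import configs
from official.nlp.albert import export_albert_tfhub

albert_config = configs.AlbertConfig.from_json_file('/tmp/albert_config.json')
export_albert_tfhub.export_albert_tfhub(
    albert_config,
    model_checkpoint_path='/tmp/albert/ckpt',
    hub_destination='/tmp/albert_hub',
    sp_model_file='/tmp/30k-clean.model')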
official/nlp/albert/export_albert_tfhub_test.py:

@@ -24,7 +24,7 @@ import numpy as np
 import tensorflow as tf
 import tensorflow_hub as hub

-from official.nlp import bert_modeling
+from official.nlp.albert import configs
 from official.nlp.albert import export_albert_tfhub

@@ -32,7 +32,7 @@ class ExportAlbertTfhubTest(tf.test.TestCase):

   def test_export_albert_tfhub(self):
     # Exports a savedmodel for TF-Hub
-    albert_config = bert_modeling.AlbertConfig(
+    albert_config = configs.AlbertConfig(
         vocab_size=100,
         embedding_size=8,
         hidden_size=16,
In the ALBERT TF2 encoder checkpoint converter:

@@ -28,7 +28,7 @@ from absl import flags
 import tensorflow as tf

 from official.modeling import activations
-from official.nlp import bert_modeling as modeling
+from official.nlp.albert import configs
 from official.nlp.bert import tf1_checkpoint_converter_lib
 from official.nlp.modeling import networks

@@ -125,7 +125,7 @@ def main(_):
   assert tf.version.VERSION.startswith('2.')
   output_path = FLAGS.converted_checkpoint_path
   v1_checkpoint = FLAGS.checkpoint_to_convert
-  albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file)
+  albert_config = configs.AlbertConfig.from_json_file(FLAGS.albert_config_file)
   convert_checkpoint(albert_config, output_path, v1_checkpoint)
official/nlp/bert/bert_models.py:

@@ -23,6 +23,7 @@ import tensorflow_hub as hub
 from official.modeling import tf_utils
 from official.nlp import bert_modeling
+from official.nlp.albert import configs as albert_configs
 from official.nlp.modeling import losses
 from official.nlp.modeling import networks
 from official.nlp.modeling.networks import bert_classifier

@@ -109,7 +110,7 @@ def get_transformer_encoder(bert_config,
       type_vocab_size=bert_config.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
           stddev=bert_config.initializer_range))
-  if isinstance(bert_config, bert_modeling.AlbertConfig):
+  if isinstance(bert_config, albert_configs.AlbertConfig):
     kwargs['embedding_width'] = bert_config.embedding_size
     return networks.AlbertTransformerEncoder(**kwargs)
   else:
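The branch above is where ALBERT's embedding factorization enters: only AlbertConfig carries embedding_size, which the encoder consumes as embedding_width. A simplified sketch of the dispatch, with kwargs trimmed to the relevant fields and the non-ALBERT branch elided as it is in the hunk:

# Simplified sketch of get_transformer_encoder's dispatch; kwargs trimmed.
from official.nlp.albert import configs as albert_configs
from official.nlp.modeling import networks

def build_encoder(bert_config):
  kwargs = dict(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size)
  if isinstance(bert_config, albert_configs.AlbertConfig):
    # ALBERT's embedding width may be smaller than hidden_size, so it is
    # forwarded to the encoder separately.
    kwargs['embedding_width'] = bert_config.embedding_size
    return networks.AlbertTransformerEncoder(**kwargs)
  raise NotImplementedError('Plain-BERT branch elided; see the diff above.')

Because AlbertConfig subclasses BertConfig (see configs.py above), the isinstance test must check the ALBERT type before any plain-BERT handling.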
In the classifier runner (run_bert):

@@ -31,6 +31,7 @@ import tensorflow as tf
 from official.modeling import model_training_utils
 from official.nlp import bert_modeling as modeling
 from official.nlp import optimization
+from official.nlp.albert import configs as albert_configs
 from official.nlp.bert import bert_models
 from official.nlp.bert import common_flags
 from official.nlp.bert import input_pipeline

@@ -292,7 +293,8 @@ def run_bert(strategy,
     bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
   else:
     assert FLAGS.model_type == 'albert'
-    bert_config = modeling.AlbertConfig.from_json_file(FLAGS.bert_config_file)
+    bert_config = albert_configs.AlbertConfig.from_json_file(
+        FLAGS.bert_config_file)
   if FLAGS.mode == 'export_only':
     # As Keras ModelCheckpoint callback used with Keras compile/fit() API
     # internally uses model.save_weights() to save checkpoints, we must
official/nlp/bert/run_squad.py:

@@ -30,6 +30,7 @@ import tensorflow as tf
 from official.modeling import model_training_utils
 from official.nlp import bert_modeling as modeling
 from official.nlp import optimization
+from official.nlp.albert import configs as albert_configs
 from official.nlp.bert import bert_models
 from official.nlp.bert import common_flags
 from official.nlp.bert import input_pipeline

@@ -99,7 +100,7 @@ FLAGS = flags.FLAGS
 MODEL_CLASSES = {
     'bert': (modeling.BertConfig, squad_lib_wp, tokenization.FullTokenizer),
-    'albert': (modeling.AlbertConfig, squad_lib_sp,
+    'albert': (albert_configs.AlbertConfig, squad_lib_sp,
                tokenization.FullSentencePieceTokenizer),
 }
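MODEL_CLASSES maps each model type to a (config class, SQuAD preprocessing library, tokenizer class) triple, so downstream code unpacks one entry instead of branching. A hedged sketch of that lookup; importing run_squad this way is for illustration only:

# Illustrative lookup of the registry entry for ALBERT.
from official.nlp.albert import configs as albert_configs
from official.nlp.bert import run_squad
from official.nlp.bert import tokenization

config_cls, squad_lib, tokenizer_cls = run_squad.MODEL_CLASSES['albert']
assert config_cls is albert_configs.AlbertConfig
assert tokenizer_cls is tokenization.FullSentencePieceTokenizer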