Commit 00488c79 authored by Chen Chen, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 328197125
parent 54832af8
@@ -43,6 +43,10 @@ flags.DEFINE_string(
     "BertModel, with no task heads.)")
 flags.DEFINE_string("converted_checkpoint_path", None,
                     "Name for the created object-based V2 checkpoint.")
+flags.DEFINE_string("checkpoint_model_name", "encoder",
+                    "The name of the model when saving the checkpoint, i.e., "
+                    "the checkpoint will be saved using: "
+                    "tf.train.Checkpoint(FLAGS.checkpoint_model_name=model).")
 
 ALBERT_NAME_REPLACEMENTS = (
@@ -94,7 +98,8 @@ def _create_albert_model(cfg):
   return albert_encoder
 
-def convert_checkpoint(bert_config, output_path, v1_checkpoint):
+def convert_checkpoint(bert_config, output_path, v1_checkpoint,
+                       checkpoint_model_name):
   """Converts a V1 checkpoint into an OO V2 checkpoint."""
   output_dir, _ = os.path.split(output_path)
@@ -112,7 +117,8 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint):
   # Create a V2 checkpoint from the temporary checkpoint.
   model = _create_albert_model(bert_config)
   tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
-                                                    output_path)
+                                                    output_path,
+                                                    checkpoint_model_name)
 
   # Clean up the temporary checkpoint, if it exists.
   try:
@@ -125,8 +131,10 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint):
 def main(_):
   output_path = FLAGS.converted_checkpoint_path
   v1_checkpoint = FLAGS.checkpoint_to_convert
+  checkpoint_model_name = FLAGS.checkpoint_model_name
   albert_config = configs.AlbertConfig.from_json_file(FLAGS.albert_config_file)
-  convert_checkpoint(albert_config, output_path, v1_checkpoint)
+  convert_checkpoint(albert_config, output_path, v1_checkpoint,
+                     checkpoint_model_name)
 
 if __name__ == "__main__":
...
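The new checkpoint_model_name flag chooses the attribute name under which the converted model hangs off the tf.train.Checkpoint object. Note that the help string's "tf.train.Checkpoint(FLAGS.checkpoint_model_name=model)" is shorthand, not literal Python: a dynamic attribute name has to be splatted in as keyword arguments. A minimal sketch of that pattern (the helper name is hypothetical; create_v2_checkpoint is assumed to do the equivalent internally):

import tensorflow as tf

def save_named_checkpoint(model, output_path, checkpoint_model_name="encoder"):
  # Attach `model` to the checkpoint under a caller-chosen attribute name.
  checkpoint = tf.train.Checkpoint(**{checkpoint_model_name: model})
  return checkpoint.save(output_path)

# Restoring must use the same attribute name the checkpoint was saved with,
# e.g. tf.train.Checkpoint(encoder=model).restore(saved_path).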
@@ -102,10 +102,28 @@ class MobileBertEncoderConfig(hyperparams.Config):
   return_attention_score: bool = False
 
 
+@dataclasses.dataclass
+class AlbertEncoderConfig(hyperparams.Config):
+  """ALBERT encoder configuration."""
+  vocab_size: int = 30000
+  embedding_width: int = 128
+  hidden_size: int = 768
+  num_layers: int = 12
+  num_attention_heads: int = 12
+  hidden_activation: str = "gelu"
+  intermediate_size: int = 3072
+  dropout_rate: float = 0.0
+  attention_dropout_rate: float = 0.0
+  max_position_embeddings: int = 512
+  type_vocab_size: int = 2
+  initializer_range: float = 0.02
+
+
 @dataclasses.dataclass
 class EncoderConfig(hyperparams.OneOfConfig):
   """Encoder configuration."""
   type: Optional[str] = "bert"
+  albert: AlbertEncoderConfig = AlbertEncoderConfig()
   bert: BertEncoderConfig = BertEncoderConfig()
   mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
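EncoderConfig is a hyperparams.OneOfConfig, so the new albert field makes ALBERT selectable simply by setting type. A hedged sketch of the selection semantics, assuming the model garden's OneOfConfig.get() behavior and this file's likely import path:

from official.nlp.configs import encoders  # assumed path for this file

config = encoders.EncoderConfig(
    type="albert",
    albert=encoders.AlbertEncoderConfig(hidden_size=1024))
encoder_cfg = config.get()  # returns the active sub-config, AlbertEncoderConfig
assert encoder_cfg.hidden_size == 1024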
@@ -113,6 +131,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
 ENCODER_CLS = {
     "bert": networks.TransformerEncoder,
     "mobilebert": networks.MobileBERTEncoder,
+    "albert": networks.AlbertTransformerEncoder,
 }
@@ -191,6 +210,22 @@ def build_encoder(config: EncoderConfig,
         return_all_layers=encoder_cfg.return_all_layers,
         return_attention_score=encoder_cfg.return_attention_score)
 
+  if encoder_type == "albert":
+    return encoder_cls(
+        vocab_size=encoder_cfg.vocab_size,
+        embedding_width=encoder_cfg.embedding_width,
+        hidden_size=encoder_cfg.hidden_size,
+        num_layers=encoder_cfg.num_layers,
+        num_attention_heads=encoder_cfg.num_attention_heads,
+        max_sequence_length=encoder_cfg.max_position_embeddings,
+        type_vocab_size=encoder_cfg.type_vocab_size,
+        intermediate_size=encoder_cfg.intermediate_size,
+        activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
+        dropout_rate=encoder_cfg.dropout_rate,
+        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+        initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=encoder_cfg.initializer_range))
+
   # Uses the default BERTEncoder configuration schema to create the encoder.
   # If it does not match, please add a switch branch by the encoder type.
   return encoder_cls(
...
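With the registry entry and the new switch branch, an ALBERT encoder can be built from a config alone. A usage sketch under the assumption that build_encoder(config) is this module's public entry point and optional arguments can be left at their defaults:

from official.nlp.configs import encoders  # assumed path for this file

encoder = encoders.build_encoder(encoders.EncoderConfig(type="albert"))
# The result is a Keras network (networks.AlbertTransformerEncoder); thanks to
# ALBERT's cross-layer parameter sharing, its variable count stays small even
# with num_layers=12.
print(len(encoder.trainable_variables))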