"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "087cbff1512f56a3d2f84477e14416edd62dabac"
Commit 9ec6f6e4 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 320641255
parent 45ab8e72
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
Includes configurations and instantiation methods. Includes configurations and instantiation methods.
""" """
import dataclasses import dataclasses
import gin
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
...@@ -42,10 +42,43 @@ class TransformerEncoderConfig(base_config.Config): ...@@ -42,10 +42,43 @@ class TransformerEncoderConfig(base_config.Config):
initializer_range: float = 0.02 initializer_range: float = 0.02
def instantiate_encoder_from_cfg( @gin.configurable
config: TransformerEncoderConfig) -> networks.TransformerEncoder: def instantiate_encoder_from_cfg(config: TransformerEncoderConfig,
encoder_cls=networks.TransformerEncoder):
"""Instantiate a Transformer encoder network from TransformerEncoderConfig.""" """Instantiate a Transformer encoder network from TransformerEncoderConfig."""
encoder_network = networks.TransformerEncoder( if encoder_cls.__name__ == "EncoderScaffold":
embedding_cfg = dict(
vocab_size=config.vocab_size,
type_vocab_size=config.type_vocab_size,
hidden_size=config.hidden_size,
seq_length=None,
max_seq_length=config.max_position_embeddings,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range),
dropout_rate=config.dropout_rate,
)
hidden_cfg = dict(
num_attention_heads=config.num_attention_heads,
intermediate_size=config.intermediate_size,
intermediate_activation=tf_utils.get_activation(
config.hidden_activation),
dropout_rate=config.dropout_rate,
attention_dropout_rate=config.attention_dropout_rate,
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range),
)
kwargs = dict(
embedding_cfg=embedding_cfg,
hidden_cfg=hidden_cfg,
num_hidden_instances=config.num_layers,
pooled_output_dim=config.hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range))
return encoder_cls(**kwargs)
if encoder_cls.__name__ != "TransformerEncoder":
raise ValueError("Unknown encoder network class. %s" % str(encoder_cls))
encoder_network = encoder_cls(
vocab_size=config.vocab_size, vocab_size=config.vocab_size,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
num_layers=config.num_layers, num_layers=config.num_layers,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment