Commit cd7cda8c authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 327892795
parent 1c89b792
@@ -28,6 +28,7 @@ from official.modeling import hyperparams
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling import networks
from official.nlp.projects.mobilebert import modeling
@dataclasses.dataclass
@@ -47,15 +48,72 @@ class BertEncoderConfig(hyperparams.Config):
  embedding_size: Optional[int] = None
@dataclasses.dataclass
class MobileBertEncoderConfig(hyperparams.Config):
  """MobileBERT encoder configuration.

  Attributes:
    word_vocab_size: number of words in the vocabulary.
    word_embed_size: word embedding size.
    type_vocab_size: number of word types.
    max_sequence_length: maximum length of the input sequence.
    num_blocks: number of transformer blocks in the encoder model.
    hidden_size: the hidden size for the transformer block.
    num_attention_heads: number of attention heads in the transformer block.
    intermediate_size: the size of the "intermediate" (a.k.a., feed
      forward) layer.
    intermediate_act_fn: the non-linear activation function to apply
      to the output of the intermediate/feed-forward layer.
    hidden_dropout_prob: dropout probability for the hidden layers.
    attention_probs_dropout_prob: dropout probability of the attention
      probabilities.
    intra_bottleneck_size: the size of the bottleneck.
    initializer_range: the stddev of the truncated_normal_initializer for
      initializing all weight matrices.
    key_query_shared_bottleneck: whether to share the linear transformation
      for keys and queries.
    num_feedforward_networks: number of stacked feed-forward networks.
    normalization_type: the type of normalization; only 'no_norm' and
      'layer_norm' are supported. 'no_norm' represents the element-wise linear
      transformation for the student model, as suggested by the original
      MobileBERT paper. 'layer_norm' is used for the teacher model.
    classifier_activation: whether to use the tanh activation for the final
      representation of the [CLS] token in fine-tuning.
    return_all_layers: whether to return all layer outputs.
    return_attention_score: whether to return attention scores for each layer.
  """
  word_vocab_size: int = 30522
  word_embed_size: int = 128
  type_vocab_size: int = 2
  max_sequence_length: int = 512
  num_blocks: int = 24
  hidden_size: int = 512
  num_attention_heads: int = 4
  intermediate_size: int = 4096
  intermediate_act_fn: str = "gelu"
  hidden_dropout_prob: float = 0.1
  attention_probs_dropout_prob: float = 0.1
  intra_bottleneck_size: int = 1024
  initializer_range: float = 0.02
  key_query_shared_bottleneck: bool = False
  num_feedforward_networks: int = 1
  normalization_type: str = "layer_norm"
  classifier_activation: bool = True
  return_all_layers: bool = False
  return_attention_score: bool = False
@dataclasses.dataclass
class EncoderConfig(hyperparams.OneOfConfig):
  """Encoder configuration."""
  type: Optional[str] = "bert"
  bert: BertEncoderConfig = BertEncoderConfig()
  mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()


ENCODER_CLS = {
    "bert": networks.TransformerEncoder,
    "mobilebert": modeling.MobileBERTEncoder,
}
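Not part of the commit itself, but a rough usage sketch of the new one-of branch (assuming this file is official/nlp/configs/encoders.py and that the hyperparams dataclasses accept field overrides as keyword arguments):

from official.nlp.configs import encoders

# Select the MobileBERT sub-config and override one of its defaults.
encoder_cfg = encoders.EncoderConfig(
    type="mobilebert",
    mobilebert=encoders.MobileBertEncoderConfig(num_blocks=12))
# The `type` field names the active sub-config; build_encoder later uses it
# to look up the encoder class in ENCODER_CLS.
print(encoder_cfg.type)                   # "mobilebert"
print(encoder_cfg.mobilebert.num_blocks)  # 12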
@@ -113,6 +171,27 @@ def build_encoder(config: EncoderConfig,
                stddev=encoder_cfg.initializer_range))
    return encoder_cls(**kwargs)
  if encoder_type == "mobilebert":
    return encoder_cls(
        word_vocab_size=encoder_cfg.word_vocab_size,
        word_embed_size=encoder_cfg.word_embed_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        max_sequence_length=encoder_cfg.max_sequence_length,
        num_blocks=encoder_cfg.num_blocks,
        hidden_size=encoder_cfg.hidden_size,
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_act_fn=encoder_cfg.intermediate_act_fn,
        hidden_dropout_prob=encoder_cfg.hidden_dropout_prob,
        attention_probs_dropout_prob=encoder_cfg.attention_probs_dropout_prob,
        intra_bottleneck_size=encoder_cfg.intra_bottleneck_size,
        key_query_shared_bottleneck=encoder_cfg.key_query_shared_bottleneck,
        num_feedforward_networks=encoder_cfg.num_feedforward_networks,
        normalization_type=encoder_cfg.normalization_type,
        classifier_activation=encoder_cfg.classifier_activation,
        return_all_layers=encoder_cfg.return_all_layers,
        return_attention_score=encoder_cfg.return_attention_score)
  # Uses the default BERTEncoder configuration schema to create the encoder.
  # If it does not match, please add a switch branch by the encoder type.
  return encoder_cls(
...
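A minimal sketch of building an encoder through the new branch (again assuming this file is official/nlp/configs/encoders.py and that build_encoder's remaining parameters all have defaults):

from official.nlp.configs import encoders

encoder_cfg = encoders.EncoderConfig(type="mobilebert")
# build_encoder resolves ENCODER_CLS["mobilebert"] and forwards every
# MobileBertEncoderConfig field to modeling.MobileBERTEncoder.
mobilebert_encoder = encoders.build_encoder(encoder_cfg)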