Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
cd7cda8c
Commit
cd7cda8c
authored
Aug 21, 2020
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 327892795
parent
1c89b792
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
79 additions
and
0 deletions
+79
-0
official/nlp/configs/encoders.py
official/nlp/configs/encoders.py
+79
-0
No files found.
official/nlp/configs/encoders.py
View file @
cd7cda8c
...
@@ -28,6 +28,7 @@ from official.modeling import hyperparams
...
@@ -28,6 +28,7 @@ from official.modeling import hyperparams
from
official.modeling
import
tf_utils
from
official.modeling
import
tf_utils
from
official.nlp.modeling
import
layers
from
official.nlp.modeling
import
layers
from
official.nlp.modeling
import
networks
from
official.nlp.modeling
import
networks
from
official.nlp.projects.mobilebert
import
modeling
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
@@ -47,15 +48,72 @@ class BertEncoderConfig(hyperparams.Config):
...
@@ -47,15 +48,72 @@ class BertEncoderConfig(hyperparams.Config):
embedding_size
:
Optional
[
int
]
=
None
embedding_size
:
Optional
[
int
]
=
None
@dataclasses.dataclass
class MobileBertEncoderConfig(hyperparams.Config):
  """MobileBERT encoder configuration.

  Attributes:
    word_vocab_size: number of words in the vocabulary.
    word_embed_size: word embedding size.
    type_vocab_size: number of word types.
    max_sequence_length: maximum length of input sequence.
    num_blocks: number of transformer blocks in the encoder model.
    hidden_size: the hidden size for the transformer block.
    num_attention_heads: number of attention heads in the transformer block.
    intermediate_size: the size of the "intermediate" (a.k.a., feed
      forward) layer.
    intermediate_act_fn: the non-linear activation function to apply
      to the output of the intermediate/feed-forward layer.
    hidden_dropout_prob: dropout probability for the hidden layers.
    attention_probs_dropout_prob: dropout probability of the attention
      probabilities.
    intra_bottleneck_size: the size of bottleneck.
    initializer_range: The stddev of the truncated_normal_initializer for
      initializing all weight matrices.
    key_query_shared_bottleneck: whether to share linear transformation for
      keys and queries.
    num_feedforward_networks: number of stacked feed-forward networks.
    normalization_type: the type of normalization_type, only 'no_norm' and
      'layer_norm' are supported. 'no_norm' represents the element-wise linear
      transformation for the student model, as suggested by the original
      MobileBERT paper. 'layer_norm' is used for the teacher model.
    classifier_activation: if using the tanh activation for the final
      representation of the [CLS] token in fine-tuning.
    return_all_layers: if return all layer outputs.
    return_attention_score: if return attention scores for each layer.
  """
  # Vocabulary / embedding geometry.
  word_vocab_size: int = 30522
  word_embed_size: int = 128
  type_vocab_size: int = 2
  max_sequence_length: int = 512
  # Transformer stack geometry.
  num_blocks: int = 24
  hidden_size: int = 512
  num_attention_heads: int = 4
  intermediate_size: int = 4096
  intermediate_act_fn: str = "gelu"
  # Regularization.
  hidden_dropout_prob: float = 0.1
  attention_probs_dropout_prob: float = 0.1
  # MobileBERT-specific bottleneck / sharing options.
  intra_bottleneck_size: int = 1024
  initializer_range: float = 0.02
  key_query_shared_bottleneck: bool = False
  num_feedforward_networks: int = 1
  normalization_type: str = "layer_norm"
  # Output head / return options.
  classifier_activation: bool = True
  return_all_layers: bool = False
  return_attention_score: bool = False
@dataclasses.dataclass
class EncoderConfig(hyperparams.OneOfConfig):
  """Encoder configuration.

  A one-of config selecting between the supported encoder families;
  `type` names which member config is active.
  """
  type: Optional[str] = "bert"
  bert: BertEncoderConfig = BertEncoderConfig()
  mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
# Maps each `EncoderConfig.type` value to the network class that builds it.
ENCODER_CLS = {
    "bert": networks.TransformerEncoder,
    "mobilebert": modeling.MobileBERTEncoder,
}
...
@@ -113,6 +171,27 @@ def build_encoder(config: EncoderConfig,
...
@@ -113,6 +171,27 @@ def build_encoder(config: EncoderConfig,
stddev
=
encoder_cfg
.
initializer_range
))
stddev
=
encoder_cfg
.
initializer_range
))
return
encoder_cls
(
**
kwargs
)
return
encoder_cls
(
**
kwargs
)
if
encoder_type
==
"mobilebert"
:
return
encoder_cls
(
word_vocab_size
=
encoder_cfg
.
word_vocab_size
,
word_embed_size
=
encoder_cfg
.
word_embed_size
,
type_vocab_size
=
encoder_cfg
.
type_vocab_size
,
max_sequence_length
=
encoder_cfg
.
max_sequence_length
,
num_blocks
=
encoder_cfg
.
num_blocks
,
hidden_size
=
encoder_cfg
.
hidden_size
,
num_attention_heads
=
encoder_cfg
.
num_attention_heads
,
intermediate_size
=
encoder_cfg
.
intermediate_size
,
intermediate_act_fn
=
encoder_cfg
.
intermediate_act_fn
,
hidden_dropout_prob
=
encoder_cfg
.
hidden_dropout_prob
,
attention_probs_dropout_prob
=
encoder_cfg
.
attention_probs_dropout_prob
,
intra_bottleneck_size
=
encoder_cfg
.
intra_bottleneck_size
,
key_query_shared_bottleneck
=
encoder_cfg
.
key_query_shared_bottleneck
,
num_feedforward_networks
=
encoder_cfg
.
num_feedforward_networks
,
normalization_type
=
encoder_cfg
.
normalization_type
,
classifier_activation
=
encoder_cfg
.
classifier_activation
,
return_all_layers
=
encoder_cfg
.
return_all_layers
,
return_attention_score
=
encoder_cfg
.
return_attention_score
)
# Uses the default BERTEncoder configuration schema to create the encoder.
# Uses the default BERTEncoder configuration schema to create the encoder.
# If it does not match, please add a switch branch by the encoder type.
# If it does not match, please add a switch branch by the encoder type.
return
encoder_cls
(
return
encoder_cls
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment