Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9ec6f6e4
Commit
9ec6f6e4
authored
Jul 10, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Jul 10, 2020
Browse files
Internal change
PiperOrigin-RevId: 320641255
parent
45ab8e72
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
4 deletions
+37
-4
official/nlp/configs/encoders.py
official/nlp/configs/encoders.py
+37
-4
No files found.
official/nlp/configs/encoders.py
View file @
9ec6f6e4
...
...
@@ -17,8 +17,8 @@
Includes configurations and instantiation methods.
"""
import
dataclasses
import
gin
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
...
...
@@ -42,10 +42,43 @@ class TransformerEncoderConfig(base_config.Config):
initializer_range
:
float
=
0.02
def
instantiate_encoder_from_cfg
(
config
:
TransformerEncoderConfig
)
->
networks
.
TransformerEncoder
:
@
gin
.
configurable
def
instantiate_encoder_from_cfg
(
config
:
TransformerEncoderConfig
,
encoder_cls
=
networks
.
TransformerEncoder
):
"""Instantiate a Transformer encoder network from TransformerEncoderConfig."""
encoder_network
=
networks
.
TransformerEncoder
(
if
encoder_cls
.
__name__
==
"EncoderScaffold"
:
embedding_cfg
=
dict
(
vocab_size
=
config
.
vocab_size
,
type_vocab_size
=
config
.
type_vocab_size
,
hidden_size
=
config
.
hidden_size
,
seq_length
=
None
,
max_seq_length
=
config
.
max_position_embeddings
,
initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
config
.
initializer_range
),
dropout_rate
=
config
.
dropout_rate
,
)
hidden_cfg
=
dict
(
num_attention_heads
=
config
.
num_attention_heads
,
intermediate_size
=
config
.
intermediate_size
,
intermediate_activation
=
tf_utils
.
get_activation
(
config
.
hidden_activation
),
dropout_rate
=
config
.
dropout_rate
,
attention_dropout_rate
=
config
.
attention_dropout_rate
,
kernel_initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
config
.
initializer_range
),
)
kwargs
=
dict
(
embedding_cfg
=
embedding_cfg
,
hidden_cfg
=
hidden_cfg
,
num_hidden_instances
=
config
.
num_layers
,
pooled_output_dim
=
config
.
hidden_size
,
pooler_layer_initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
config
.
initializer_range
))
return
encoder_cls
(
**
kwargs
)
if
encoder_cls
.
__name__
!=
"TransformerEncoder"
:
raise
ValueError
(
"Unknown encoder network class. %s"
%
str
(
encoder_cls
))
encoder_network
=
encoder_cls
(
vocab_size
=
config
.
vocab_size
,
hidden_size
=
config
.
hidden_size
,
num_layers
=
config
.
num_layers
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment