Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
dd0126f9
Commit
dd0126f9
authored
Aug 27, 2020
by
Chen Chen
Committed by
A. Unique TensorFlower
Aug 27, 2020
Browse files
Internal change
PiperOrigin-RevId: 328798031
parent
20e2cb97
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
2 deletions
+5
-2
official/nlp/configs/encoders.py
official/nlp/configs/encoders.py
+5
-2
No files found.
official/nlp/configs/encoders.py
View file @
dd0126f9
...
@@ -45,6 +45,7 @@ class BertEncoderConfig(hyperparams.Config):
...
@@ -45,6 +45,7 @@ class BertEncoderConfig(hyperparams.Config):
type_vocab_size
:
int
=
2
type_vocab_size
:
int
=
2
initializer_range
:
float
=
0.02
initializer_range
:
float
=
0.02
embedding_size
:
Optional
[
int
]
=
None
embedding_size
:
Optional
[
int
]
=
None
return_all_encoder_outputs
:
bool
=
False
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
@@ -186,7 +187,8 @@ def build_encoder(config: EncoderConfig,
...
@@ -186,7 +187,8 @@ def build_encoder(config: EncoderConfig,
num_hidden_instances
=
encoder_cfg
.
num_layers
,
num_hidden_instances
=
encoder_cfg
.
num_layers
,
pooled_output_dim
=
encoder_cfg
.
hidden_size
,
pooled_output_dim
=
encoder_cfg
.
hidden_size
,
pooler_layer_initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
pooler_layer_initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
encoder_cfg
.
initializer_range
))
stddev
=
encoder_cfg
.
initializer_range
),
return_all_layer_outputs
=
encoder_cfg
.
return_all_encoder_outputs
)
return
encoder_cls
(
**
kwargs
)
return
encoder_cls
(
**
kwargs
)
if
encoder_type
==
"mobilebert"
:
if
encoder_type
==
"mobilebert"
:
...
@@ -242,4 +244,5 @@ def build_encoder(config: EncoderConfig,
...
@@ -242,4 +244,5 @@ def build_encoder(config: EncoderConfig,
initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
encoder_cfg
.
initializer_range
),
stddev
=
encoder_cfg
.
initializer_range
),
embedding_width
=
encoder_cfg
.
embedding_size
,
embedding_width
=
encoder_cfg
.
embedding_size
,
embedding_layer
=
embedding_layer
)
embedding_layer
=
embedding_layer
,
return_all_encoder_outputs
=
encoder_cfg
.
return_all_encoder_outputs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment