Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
3dccfae1
Commit
3dccfae1
authored
May 14, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
May 14, 2020
Browse files
Internal change
PiperOrigin-RevId: 311602262
parent
8c408bbe
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
5 deletions
+9
-5
official/nlp/bert/input_pipeline.py
official/nlp/bert/input_pipeline.py
+7
-4
official/nlp/optimization.py
official/nlp/optimization.py
+2
-1
No files found.
official/nlp/bert/input_pipeline.py
View file @
3dccfae1
...
@@ -63,7 +63,8 @@ def create_pretrain_dataset(input_patterns,
...
@@ -63,7 +63,8 @@ def create_pretrain_dataset(input_patterns,
is_training
=
True
,
is_training
=
True
,
input_pipeline_context
=
None
,
input_pipeline_context
=
None
,
use_next_sentence_label
=
True
,
use_next_sentence_label
=
True
,
use_position_id
=
False
):
use_position_id
=
False
,
output_fake_labels
=
True
):
"""Creates input dataset from (tf)records files for pretraining."""
"""Creates input dataset from (tf)records files for pretraining."""
name_to_features
=
{
name_to_features
=
{
'input_ids'
:
'input_ids'
:
...
@@ -135,9 +136,11 @@ def create_pretrain_dataset(input_patterns,
...
@@ -135,9 +136,11 @@ def create_pretrain_dataset(input_patterns,
if
use_position_id
:
if
use_position_id
:
x
[
'position_ids'
]
=
record
[
'position_ids'
]
x
[
'position_ids'
]
=
record
[
'position_ids'
]
y
=
record
[
'masked_lm_weights'
]
# TODO(hongkuny): Remove the fake labels after migrating bert pretraining.
if
output_fake_labels
:
return
(
x
,
y
)
return
(
x
,
record
[
'masked_lm_weights'
])
else
:
return
x
dataset
=
dataset
.
map
(
dataset
=
dataset
.
map
(
_select_data_from_record
,
_select_data_from_record
,
...
...
official/nlp/optimization.py
View file @
3dccfae1
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""Functions and classes related to optimization (weight updates)."""
"""Functions and classes related to optimization (weight updates)."""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
...
@@ -21,6 +20,7 @@ from __future__ import print_function
...
@@ -21,6 +20,7 @@ from __future__ import print_function
import
re
import
re
from
absl
import
logging
from
absl
import
logging
import
gin
import
tensorflow
as
tf
import
tensorflow
as
tf
import
tensorflow_addons.optimizers
as
tfa_optimizers
import
tensorflow_addons.optimizers
as
tfa_optimizers
...
@@ -67,6 +67,7 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
...
@@ -67,6 +67,7 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
}
}
@
gin
.
configurable
def
create_optimizer
(
init_lr
,
def
create_optimizer
(
init_lr
,
num_train_steps
,
num_train_steps
,
num_warmup_steps
,
num_warmup_steps
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment