Commit c68dbef0 (ModelZoo / ResNet50_tensorflow)

Authored Oct 02, 2021 by Jialu Liu; committed by A. Unique TensorFlower on Oct 02, 2021.

Internal change

PiperOrigin-RevId: 400408816

Parent: 9a1b54cd
Showing 2 changed files with 42 additions and 46 deletions:

- official/nlp/projects/teams/experiments/base/wiki_books_pretrain.yaml (+33, -39)
- official/nlp/projects/teams/teams_pretrainer.py (+9, -7)

official/nlp/projects/teams/experiments/base/wiki_books_pretrain.yaml
```diff
 task:
   model:
     cls_heads: [{activation: tanh, cls_token_idx: 0, dropout_rate: 0.1, inner_dim: 768,
       name: next_sentence, num_classes: 2}]
-    generator_encoder:
-      bert:
+    candidate_size: 5
+    num_shared_generator_hidden_layers: 3
+    num_discriminator_task_agnostic_layers: 11
+    tie_embeddings: true
+    generator:
       attention_dropout_rate: 0.1
       dropout_rate: 0.1
       embedding_size: 768
       hidden_activation: gelu
-      hidden_size: 256
+      hidden_size: 768
       initializer_range: 0.02
-      intermediate_size: 1024
+      intermediate_size: 3072
       max_position_embeddings: 512
-      num_attention_heads: 4
-      num_layers: 12
+      num_attention_heads: 12
+      num_layers: 6
       type_vocab_size: 2
       vocab_size: 30522
-    num_masked_tokens: 76
-    sequence_length: 512
-    num_classes: 2
-    discriminator_encoder:
-      bert:
+    discriminator:
       attention_dropout_rate: 0.1
       dropout_rate: 0.1
       embedding_size: 768
```
```diff
@@ -30,12 +27,9 @@ task:
       intermediate_size: 3072
       max_position_embeddings: 512
       num_attention_heads: 12
-      num_layers: 12
+      num_layers: 6
       type_vocab_size: 2
       vocab_size: 30522
-    discriminator_loss_weight: 50.0
-    disallow_correct: false
-    tie_embeddings: true
   train_data:
     drop_remainder: true
     global_batch_size: 256
```
```diff
@@ -55,8 +49,8 @@ task:
     use_next_sentence_label: false
     use_position_id: false
 trainer:
-  checkpoint_interval: 6000
-  max_to_keep: 50
+  checkpoint_interval: 4000
+  max_to_keep: 5
   optimizer_config:
     learning_rate:
       polynomial:
```
```diff
@@ -73,8 +67,8 @@ trainer:
         power: 1
         warmup_steps: 10000
       type: polynomial
-  steps_per_loop: 1000
-  summary_interval: 1000
+  steps_per_loop: 4000
+  summary_interval: 4000
   train_steps: 1000000
   validation_interval: 100
   validation_steps: 64
```
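The two new model keys encode TEAMS-style layer sharing: num_shared_generator_hidden_layers is the number of bottom encoder layers the generator reuses from the discriminator, and num_discriminator_task_agnostic_layers is the number of bottom discriminator layers shared between the replaced-token-detection and multi-word-selection heads. A minimal sketch of that wiring, with plain Dense blocks standing in for transformer layers and all names invented; this only illustrates the sharing pattern, not the Model Garden implementation:

```python
import tensorflow as tf

NUM_GENERATOR_LAYERS = 6
NUM_DISCRIMINATOR_LAYERS = 12
NUM_SHARED_GENERATOR_HIDDEN_LAYERS = 3
NUM_DISCRIMINATOR_TASK_AGNOSTIC_LAYERS = 11

def block(name):
  # Stand-in for a transformer layer; hidden width matches the config.
  return tf.keras.layers.Dense(768, activation="gelu", name=name)

# Discriminator stack; the bottom 11 layers are task-agnostic and feed both
# the replaced-token-detection and multi-word-selection heads.
disc_layers = [block(f"disc_{i}") for i in range(NUM_DISCRIMINATOR_LAYERS)]
task_agnostic = disc_layers[:NUM_DISCRIMINATOR_TASK_AGNOSTIC_LAYERS]

# Generator stack; its bottom 3 layers are the discriminator's bottom 3,
# so those weights are trained through both models.
gen_layers = disc_layers[:NUM_SHARED_GENERATOR_HIDDEN_LAYERS] + [
    block(f"gen_{i}")
    for i in range(NUM_SHARED_GENERATOR_HIDDEN_LAYERS, NUM_GENERATOR_LAYERS)
]
```

Sharing bottom layers is also why the generator's hidden_size moves from 256 to 768 in this diff: shared layers force the two encoders to agree on width.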
official/nlp/projects/teams/teams_pretrainer.py
```diff
@@ -21,6 +21,8 @@ from official.modeling import tf_utils
 from official.nlp.modeling import layers
 from official.nlp.modeling import models
 
+_LOGIT_PENALTY_MULTIPLIER = 10000
+
 
 class ReplacedTokenDetectionHead(tf.keras.layers.Layer):
   """Replaced token detection discriminator head.
```
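The new module-level constant names the literal 10000.0 that previously appeared at three call sites below. The trick it supports: subtracting a value this large pushes a token's logit far below every other logit, so its softmax probability collapses to roughly zero. A minimal standalone check:

```python
import tensorflow as tf

_LOGIT_PENALTY_MULTIPLIER = 10000  # same value as the constant added above

logits = tf.constant([[2.0, 1.0, 0.5]])
disallow = tf.constant([[0.0, 1.0, 0.0]])  # 1.0 marks a banned token id

# A huge subtractive penalty drives the banned token's probability to ~0.
penalized = logits - _LOGIT_PENALTY_MULTIPLIER * disallow
print(tf.nn.softmax(penalized).numpy())  # second entry is effectively 0
```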
```diff
@@ -273,10 +275,9 @@ class TeamsPretrainer(tf.keras.Model):
     self.mlm_activation = mlm_activation
     self.mlm_initializer = mlm_initializer
     self.output_type = output_type
-    self.embedding_table = (
-        self.discriminator_mws_network.embedding_network.get_embedding_table())
     self.masked_lm = layers.MaskedLM(
-        embedding_table=self.embedding_table,
+        embedding_table=self.generator_network.embedding_network
+        .get_embedding_table(),
         activation=mlm_activation,
         initializer=mlm_initializer,
         output=output_type,
```
```diff
@@ -290,7 +291,8 @@ class TeamsPretrainer(tf.keras.Model):
         name='discriminator_rtd')
     hidden_cfg = discriminator_cfg['hidden_cfg']
     self.discriminator_mws_head = MultiWordSelectionHead(
-        embedding_table=self.embedding_table,
+        embedding_table=self.discriminator_mws_network.embedding_network
+        .get_embedding_table(),
         activation=hidden_cfg['intermediate_activation'],
         initializer=hidden_cfg['kernel_initializer'],
         output=output_type,
```
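Before this change both heads read one cached table, self.embedding_table, taken from the discriminator's embedding network; afterwards the masked-LM head projects through the generator's table and the multi-word-selection head through the discriminator's. Heads of this kind tie the output projection to an embedding table by multiplying hidden states against the transposed table; a minimal sketch of that tying, with shapes and names assumed rather than taken from layers.MaskedLM:

```python
import tensorflow as tf

VOCAB_SIZE, HIDDEN = 30522, 768  # values from the config above
embedding_table = tf.Variable(tf.random.normal([VOCAB_SIZE, HIDDEN]))

def tied_logits(hidden_states):
  # [batch, seq, hidden] @ [vocab, hidden]^T -> [batch, seq, vocab]; the
  # same weights serve as input embeddings and output projection.
  return tf.matmul(hidden_states, embedding_table, transpose_b=True)

print(tied_logits(tf.random.normal([2, 4, HIDDEN])).shape)  # (2, 4, 30522)
```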
```diff
@@ -436,7 +438,7 @@ def sample_k_from_softmax(logits, k, disallow=None, use_topk=False):
   """
   if use_topk:
     if disallow is not None:
-      logits -= 10000.0 * disallow
+      logits -= _LOGIT_PENALTY_MULTIPLIER * disallow
     uniform_noise = tf.random.uniform(
         tf_utils.get_shape_list(logits), minval=0, maxval=1)
     gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
```
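For context, the use_topk branch relies on the Gumbel-max trick: adding Gumbel noise -log(-log U) to the logits makes the argmax a sample from softmax(logits), and taking the top-k of the perturbed logits yields k distinct samples. A small self-contained check, where the 1e-9 terms only guard against log(0):

```python
import tensorflow as tf

logits = tf.math.log([[0.5, 0.3, 0.2]])
uniform_noise = tf.random.uniform(tf.shape(logits), minval=0, maxval=1)
gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)

# argmax(logits + gumbel) ~ softmax(logits); top_k returns k distinct ids.
_, token_ids = tf.math.top_k(logits + gumbel_noise, k=2)
print(token_ids.numpy())
```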
```diff
@@ -445,7 +447,7 @@ def sample_k_from_softmax(logits, k, disallow=None, use_topk=False):
     sampled_tokens_list = []
     vocab_size = tf_utils.get_shape_list(logits)[-1]
     if disallow is not None:
-      logits -= 10000.0 * disallow
+      logits -= _LOGIT_PENALTY_MULTIPLIER * disallow
     uniform_noise = tf.random.uniform(
         tf_utils.get_shape_list(logits), minval=0, maxval=1)
```
```diff
@@ -454,7 +456,7 @@ def sample_k_from_softmax(logits, k, disallow=None, use_topk=False):
     for _ in range(k):
       token_ids = tf.argmax(logits, -1, output_type=tf.int32)
       sampled_tokens_list.append(token_ids)
-      logits -= 10000.0 * tf.one_hot(
+      logits -= _LOGIT_PENALTY_MULTIPLIER * tf.one_hot(
           token_ids, depth=vocab_size, dtype=tf.float32)
     sampled_tokens = tf.stack(sampled_tokens_list, -1)
     return sampled_tokens
```
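In this non-top-k branch the same constant drives sampling without replacement: each chosen id receives a large one-hot penalty, so the next argmax must pick a different token. A condensed, standalone version of that loop, omitting the Gumbel perturbation the real function applies first:

```python
import tensorflow as tf

_LOGIT_PENALTY_MULTIPLIER = 10000

def sample_k_distinct(logits, k):
  """Repeated argmax; penalizing each pick keeps the k ids distinct."""
  vocab_size = logits.shape[-1]
  picks = []
  for _ in range(k):
    token_ids = tf.argmax(logits, -1, output_type=tf.int32)
    picks.append(token_ids)
    # Knock the chosen id far below every other logit before the next pick.
    logits -= _LOGIT_PENALTY_MULTIPLIER * tf.one_hot(
        token_ids, depth=vocab_size, dtype=tf.float32)
  return tf.stack(picks, -1)

print(sample_k_distinct(tf.random.normal([2, 5]), k=3).numpy())
```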