ModelZoo / ResNet50_tensorflow / Commits

Commit a283424a
Authored Mar 25, 2021 by Frederick Liu; committed by A. Unique TensorFlower, Mar 25, 2021

[nlp][translation] Opensource and register WMT transformer experiment

PiperOrigin-RevId: 365188568
Parent: 4a8765d5

Showing 2 changed files, with 175 additions and 0 deletions.
official/nlp/configs/experiment_configs.py (+1, -0)
official/nlp/configs/wmt_transformer_experiments.py (+174, -0)
official/nlp/configs/experiment_configs.py

@@ -16,3 +16,4 @@
 # pylint: disable=unused-import
 from official.nlp.configs import finetuning_experiments
 from official.nlp.configs import pretraining_experiments
+from official.nlp.configs import wmt_transformer_experiments
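
This one-line addition matters because registration happens at import time: importing wmt_transformer_experiments runs its @exp_factory.register_config_factory decorators, which makes the new experiments resolvable by name. A minimal sketch of the lookup, assuming the TF Model Garden (tf-models-official) is importable; the printed value comes from the config defined in the new file below:

from official.core import exp_factory
from official.nlp.configs import experiment_configs  # noqa: F401 -- the import registers the experiments

config = exp_factory.get_exp_config('wmt_transformer/large')
print(config.task.train_data.tfds_name)  # 'wmt14_translate/de-en'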
official/nlp/configs/wmt_transformer_experiments.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# pylint: disable=g-doc-return-or-yield,line-too-long
"""WMT translation configurations."""
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.modeling.progressive import trainer as prog_trainer_lib
from official.nlp.data import wmt_dataloader
from official.nlp.tasks import progressive_translation
from official.nlp.tasks import translation


@exp_factory.register_config_factory('wmt_transformer/large')
def wmt_transformer_large() -> cfg.ExperimentConfig:
  """WMT Transformer Large.

  Please refer to tensorflow_models/official/nlp/data/train_sentencepiece.py
  to generate sentencepiece_model, and pass
  --params_override=task.sentencepiece_model_path='YOUR_PATH'
  to the train script.
  """
  learning_rate = 2.0
  hidden_size = 1024
  learning_rate *= (hidden_size**-0.5)
  warmup_steps = 16000
  train_steps = 300000
  token_batch_size = 24576
  encdecoder = translation.EncDecoder(
      num_attention_heads=16, intermediate_size=hidden_size * 4)
  config = cfg.ExperimentConfig(
      task=translation.TranslationConfig(
          model=translation.ModelConfig(
              encoder=encdecoder,
              decoder=encdecoder,
              embedding_width=hidden_size,
              padded_decode=True,
              decode_max_length=100),
          train_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='train',
              src_lang='en',
              tgt_lang='de',
              is_training=True,
              global_batch_size=token_batch_size,
              static_batch=True,
              max_seq_length=64),
          validation_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='test',
              src_lang='en',
              tgt_lang='de',
              is_training=False,
              global_batch_size=32,
              static_batch=True,
              max_seq_length=100,
          ),
          sentencepiece_model_path=None,
      ),
      trainer=cfg.TrainerConfig(
          train_steps=train_steps,
          validation_steps=-1,
          steps_per_loop=1000,
          summary_interval=1000,
          checkpoint_interval=5000,
          validation_interval=5000,
          max_to_keep=1,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adam',
                  'adam': {
                      'beta_2': 0.997,
                      'epsilon': 1e-9,
                  },
              },
              'learning_rate': {
                  'type': 'power',
                  'power': {
                      'initial_learning_rate': learning_rate,
                      'power': -0.5,
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': warmup_steps,
                      'warmup_learning_rate': 0.0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.sentencepiece_model_path != None',
      ])
  return config


@exp_factory.register_config_factory('wmt_transformer/large_progressive')
def wmt_transformer_large_progressive() -> cfg.ExperimentConfig:
  """WMT Transformer Large with progressive training.

  Please refer to tensorflow_models/official/nlp/data/train_sentencepiece.py
  to generate sentencepiece_model, and pass
  --params_override=task.sentencepiece_model_path='YOUR_PATH'
  to the train script.
  """
  hidden_size = 1024
  train_steps = 300000
  token_batch_size = 24576
  encdecoder = translation.EncDecoder(
      num_attention_heads=16, intermediate_size=hidden_size * 4)
  config = cfg.ExperimentConfig(
      task=progressive_translation.ProgTranslationConfig(
          model=translation.ModelConfig(
              encoder=encdecoder,
              decoder=encdecoder,
              embedding_width=hidden_size,
              padded_decode=True,
              decode_max_length=100),
          train_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='train',
              src_lang='en',
              tgt_lang='de',
              is_training=True,
              global_batch_size=token_batch_size,
              static_batch=True,
              max_seq_length=64),
          validation_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='test',
              src_lang='en',
              tgt_lang='de',
              is_training=False,
              global_batch_size=32,
              static_batch=True,
              max_seq_length=100,
          ),
          sentencepiece_model_path=None,
      ),
      trainer=prog_trainer_lib.ProgressiveTrainerConfig(
          train_steps=train_steps,
          validation_steps=-1,
          steps_per_loop=1000,
          summary_interval=1000,
          checkpoint_interval=5000,
          validation_interval=5000,
          optimizer_config=None,
      ),
      restrictions=[
          'task.train_data.is_training != None',
          'task.sentencepiece_model_path != None',
      ])
  return config
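
As both docstrings note, task.sentencepiece_model_path must be set before training, and each restrictions list enforces this at validation time. Besides passing --params_override on the command line, the path can be set programmatically. A hedged sketch, assuming the Model Garden's Config objects expose the ParamsDict override()/validate() API; 'YOUR_PATH' is the docstrings' placeholder, not a real artifact:

from official.core import exp_factory
from official.nlp.configs import wmt_transformer_experiments  # noqa: F401

config = exp_factory.get_exp_config('wmt_transformer/large')
# Equivalent to --params_override=task.sentencepiece_model_path='YOUR_PATH'.
config.override({'task': {'sentencepiece_model_path': 'YOUR_PATH'}})
config.validate()  # checks each 'restrictions' entry, e.g. that the path is now set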