ModelZoo / ResNet50_tensorflow · Commit 7f596d87

Authored Dec 17, 2020 by Le Hou; committed by A. Unique TensorFlower, Dec 17, 2020
Parent: d9a3b7f0

Open source the progressive training library.

PiperOrigin-RevId: 348113609

Showing 11 changed files with 1811 additions and 0 deletions:
official/modeling/progressive/policies.py             +173  -0
official/modeling/progressive/train.py                 +68  -0
official/modeling/progressive/train_lib.py            +126  -0
official/modeling/progressive/train_lib_test.py       +121  -0
official/modeling/progressive/trainer.py              +291  -0
official/modeling/progressive/trainer_test.py         +241  -0
official/modeling/progressive/utils.py                 +37  -0
official/nlp/tasks/progressive_masked_lm.py           +249  -0
official/nlp/tasks/progressive_masked_lm_test.py      +112  -0
official/nlp/tasks/progressive_translation.py         +242  -0
official/nlp/tasks/progressive_translation_test.py    +151  -0
official/modeling/progressive/policies.py
0 → 100644
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base ProgressivePolicy definition for progressive training.
To write a progressive model, subclass ProgressivePolicy and implement its
abstract methods to handle each training stage.
"""
import abc
from typing import Any, Mapping

from absl import logging
import dataclasses
import six
import tensorflow as tf

from official.modeling.hyperparams import base_config
from official.modeling.progressive import utils


@dataclasses.dataclass
class ProgressiveConfig(base_config.Config):
  pass


@six.add_metaclass(abc.ABCMeta)
class ProgressivePolicy:
  """The APIs for handling progressive training stages.

  Attributes:
    cur_model: The model for the current progressive training stage.
    cur_train_dataset: The train dataset function for the current stage.
    cur_eval_dataset: The eval dataset function for the current stage.
    cur_optimizer: The optimizer for the current stage.
    cur_checkpoint_items: Items to be saved in and restored from checkpoints,
      for the progressive trainer.
    is_last_stage: Whether it is currently in the last stage.

  Interfaces:
    is_stage_advancing: Returns if progressive training is advancing to the
      next stage.
    update_pt_stage: Update progressive training stage.
  """

  def __init__(self):
    """Initialize stage policy."""
    self._cur_train_dataset = None
    self._cur_eval_dataset = None
    self._volatiles = utils.VolatileTrackable(optimizer=None, model=None)

    stage_id = 0
    self._stage_id = tf.Variable(
        stage_id,
        trainable=False,
        dtype=tf.int64,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        shape=[])
    self._volatiles.reassign_trackable(
        optimizer=self.get_optimizer(stage_id),
        model=self.get_model(stage_id, old_model=None))

  def compute_stage_id(self, global_step: int) -> int:
    for stage_id in range(self.num_stages()):
      global_step -= self.num_steps(stage_id)
      if global_step < 0:
        return stage_id
    logging.error('Global step %d found no matching progressive stages. '
                  'Default to the last stage.', global_step)
    return self.num_stages() - 1

  @abc.abstractmethod
  def num_stages(self) -> int:
    """Return the total number of progressive stages."""
    pass

  @abc.abstractmethod
  def num_steps(self, stage_id: int) -> int:
    """Return the total number of steps in this stage."""
    pass

  @abc.abstractmethod
  def get_model(self,
                stage_id: int,
                old_model: tf.keras.Model = None) -> tf.keras.Model:
    """Return model for this stage. For initialization, `old_model` = None."""
    pass

  @abc.abstractmethod
  def get_optimizer(self, stage_id: int) -> tf.keras.optimizers.Optimizer:
    """Return optimizer for this stage."""
    pass

  @abc.abstractmethod
  def get_train_dataset(self, stage_id: int) -> tf.data.Dataset:
    """Return training Dataset for this stage."""
    pass

  @abc.abstractmethod
  def get_eval_dataset(self, stage_id: int) -> tf.data.Dataset:
    """Return evaluation Dataset for this stage."""
    pass

  @property
  def cur_model(self) -> tf.keras.Model:
    return self._volatiles.model

  @property
  def cur_train_dataset(self) -> tf.data.Dataset:
    if self._cur_train_dataset is None:
      self._cur_train_dataset = self.get_train_dataset(self._stage_id.numpy())
    return self._cur_train_dataset

  @property
  def cur_eval_dataset(self) -> tf.data.Dataset:
    if self._cur_eval_dataset is None:
      self._cur_eval_dataset = self.get_eval_dataset(self._stage_id.numpy())
    return self._cur_eval_dataset

  @property
  def cur_optimizer(self) -> tf.keras.optimizers.Optimizer:
    return self._volatiles.optimizer

  @property
  def is_last_stage(self) -> bool:
    stage_id = self._stage_id.numpy()
    return stage_id >= self.num_stages() - 1

  @property
  def cur_checkpoint_items(self) -> Mapping[str, Any]:
    return dict(stage_id=self._stage_id, volatiles=self._volatiles)

  def is_stage_advancing(self, global_step: int) -> bool:
    old_stage_id = self._stage_id.numpy()
    new_stage_id = self.compute_stage_id(global_step)
    return old_stage_id != new_stage_id

  def update_pt_stage(self, global_step: int, pass_old_model=True) -> None:
    """Update progressive training internal status.

    Call this after a training loop ends.

    Args:
      global_step: an integer scalar of the current global step.
      pass_old_model: whether to pass the old_model to the get_model()
        function. This is set to False if the old_model is irrelevant (e.g.,
        just a default model from stage 0).
    """
    old_stage_id = self._stage_id.numpy()
    new_stage_id = self.compute_stage_id(global_step)
    logging.info('Switching stage from %d to %d', old_stage_id, new_stage_id)

    # Update stage id.
    self._stage_id.assign(new_stage_id)
    # Update dataset function.
    self._cur_train_dataset = None
    self._cur_eval_dataset = None
    # Update optimizer and model.
    new_optimizer = self.get_optimizer(new_stage_id)
    self._volatiles.reassign_trackable(optimizer=new_optimizer)
    new_model = self.get_model(
        new_stage_id, old_model=self.cur_model if pass_old_model else None)
    self._volatiles.reassign_trackable(model=new_model)
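A minimal subclass sketch (illustrative only, not part of this commit) can clarify the contract: the six abstract methods fully define a stage schedule. Assuming eager execution, the hypothetical TwoStagePolicy below grows a toy MLP; with these num_steps values, compute_stage_id maps global steps [0, 100) to stage 0 and [100, 500) to stage 1.

# Hypothetical two-stage policy sketch; the class name, layer sizes and
# learning rates are illustrative only.
import tensorflow as tf
from official.modeling.progressive import policies


class TwoStagePolicy(policies.ProgressivePolicy):
  """Grows a tiny MLP from 1 hidden layer (stage 0) to 2 (stage 1)."""

  def num_stages(self) -> int:
    return 2

  def num_steps(self, stage_id: int) -> int:
    # compute_stage_id subtracts these per stage: global steps [0, 100)
    # resolve to stage 0, and [100, 500) to stage 1.
    return 100 if stage_id == 0 else 400

  def get_model(self, stage_id, old_model=None) -> tf.keras.Model:
    hidden = [tf.keras.layers.Dense(16, activation='relu')
              for _ in range(stage_id + 1)]
    model = tf.keras.Sequential(hidden + [tf.keras.layers.Dense(1)])
    model.build((None, 8))
    # A real policy would copy weights from `old_model` here.
    return model

  def get_optimizer(self, stage_id) -> tf.keras.optimizers.Optimizer:
    return tf.keras.optimizers.SGD(
        learning_rate=0.1 if stage_id == 0 else 0.01)

  def get_train_dataset(self, stage_id) -> tf.data.Dataset:
    return tf.data.Dataset.from_tensors(
        (tf.zeros([2, 8]), tf.zeros([2, 1]))).repeat()

  def get_eval_dataset(self, stage_id) -> tf.data.Dataset:
    return self.get_train_dataset(stage_id)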
official/modeling/progressive/train.py
0 → 100644
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFM binary for the progressive trainer."""
from absl import app
from absl import flags
import gin

from official.common import distribute_utils
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_utils
from official.modeling import performance
from official.modeling.progressive import train_lib

FLAGS = flags.FLAGS


def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case
  # of GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only
  # when dtype is float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(
        params.runtime.mixed_precision_dtype, params.runtime.loss_scale)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu,
      **params.runtime.model_parallelism())
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)


if __name__ == '__main__':
  tfm_flags.define_flags()
  app.run(main)
official/modeling/progressive/train_lib.py
0 → 100644
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFM progressive training driver library.
Compared to the common training driver, the only difference is that we use
prog_trainer_lib.ProgressiveTrainer instead of the base trainer.
"""
# pytype: disable=attribute-error
import os
from typing import Any, Mapping, Tuple

# Import libraries
from absl import logging
import orbit
import tensorflow as tf

from official.core import base_task
from official.core import config_definitions
from official.core import train_lib as base_train_lib
from official.modeling.progressive import trainer as prog_trainer_lib


def run_experiment(distribution_strategy: tf.distribute.Strategy,
                   task: base_task.Task,
                   mode: str,
                   params: config_definitions.ExperimentConfig,
                   model_dir: str,
                   run_post_eval: bool = False,
                   save_summary: bool = True) \
-> Tuple[tf.keras.Model, Mapping[str, Any]]:
  """Runs train/eval configured by the experiment params.

  Args:
    distribution_strategy: A distribution strategy.
    task: A Task instance.
    mode: A 'str', specifying the mode. Can be 'train', 'eval',
      'train_and_eval' or 'continuous_eval'.
    params: ExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run post eval once after training; metrics logs
      are returned.
    save_summary: Whether to save train and validation summary.

  Returns:
    A 2-tuple of (model, eval_logs).
      model: `tf.keras.Model` instance.
      eval_logs: returns eval metrics logs when run_post_eval is set to True,
        otherwise, returns {}.
  """
  with distribution_strategy.scope():
    logging.info('Running progressive trainer.')
    trainer = prog_trainer_lib.ProgressiveTrainer(
        params, task, ckpt_dir=model_dir,
        train='train' in mode,
        evaluate=('eval' in mode) or run_post_eval,
        checkpoint_exporter=base_train_lib.maybe_create_best_ckpt_exporter(
            params, model_dir))

  if trainer.checkpoint:
    checkpoint_manager = tf.train.CheckpointManager(
        trainer.checkpoint,
        directory=model_dir,
        max_to_keep=params.trainer.max_to_keep,
        step_counter=trainer.global_step,
        checkpoint_interval=params.trainer.checkpoint_interval,
        init_fn=trainer.initialize)
  else:
    checkpoint_manager = None

  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer if 'train' in mode else None,
      evaluator=trainer,
      global_step=trainer.global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=os.path.join(model_dir, 'train') if (save_summary) else None,
      eval_summary_dir=os.path.join(model_dir, 'validation') if
      (save_summary) else None,
      summary_interval=params.trainer.summary_interval if
      (save_summary) else None)

  logging.info('Starts to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=params.trainer.validation_steps)
    elif mode == 'continuous_eval':

      def timeout_fn():
        if trainer.global_step.numpy() >= params.trainer.train_steps:
          return True
        return False

      controller.evaluate_continuously(
          steps=params.trainer.validation_steps,
          timeout=params.trainer.continuous_eval_timeout,
          timeout_fn=timeout_fn)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)

  if run_post_eval:
    with distribution_strategy.scope():
      return trainer.model, trainer.evaluate(
          tf.convert_to_tensor(params.trainer.validation_steps))
  else:
    return trainer.model, {}
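As a usage sketch, condensed from train_lib_test.py later in this commit (the 'dummy' input path and the /tmp model directory are placeholders, and trainer settings such as train_steps are elided), the library can be driven programmatically much as train.py does via flags:

# Driver sketch under the assumptions above; not a complete training setup.
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling.progressive import train_lib
from official.modeling.progressive import trainer as prog_trainer_lib
from official.nlp.data import pretrain_dataloader
from official.nlp.tasks import progressive_masked_lm

params = cfg.ExperimentConfig(
    trainer=prog_trainer_lib.ProgressiveTrainerConfig(),
    task=progressive_masked_lm.ProgMaskedLMConfig(
        train_data=pretrain_dataloader.BertPretrainDataConfig(
            input_path='dummy')))  # placeholder data path
strategy = tf.distribute.get_strategy()  # or a Mirrored/TPU strategy
with strategy.scope():
  task = task_factory.get_task(params.task, logging_dir='/tmp/prog')
model, eval_logs = train_lib.run_experiment(
    distribution_strategy=strategy,
    task=task,
    mode='train',
    params=params,
    model_dir='/tmp/prog')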
official/modeling/progressive/train_lib_test.py
0 → 100644
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the progressive train_lib."""
import os

from absl import flags
from absl.testing import parameterized
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.common import flags as tfm_flags
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling.hyperparams import params_dict
from official.modeling.progressive import train_lib
from official.modeling.progressive import trainer as prog_trainer_lib
from official.nlp.data import pretrain_dataloader
from official.nlp.tasks import progressive_masked_lm

FLAGS = flags.FLAGS

tfm_flags.define_flags()


class TrainTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(TrainTest, self).setUp()
    self._test_config = {
        'trainer': {
            'checkpoint_interval': 10,
            'steps_per_loop': 10,
            'summary_interval': 10,
            'train_steps': 10,
            'validation_steps': 5,
            'validation_interval': 10,
            'continuous_eval_timeout': 1,
            'optimizer_config': {
                'optimizer': {
                    'type': 'sgd',
                },
                'learning_rate': {
                    'type': 'constant'
                }
            }
        },
    }

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          mode='eager',
          flag_mode=['train', 'eval', 'train_and_eval'],
          run_post_eval=[True, False]))
  def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
    model_dir = self.get_temp_dir()
    experiment_config = cfg.ExperimentConfig(
        trainer=prog_trainer_lib.ProgressiveTrainerConfig(),
        task=progressive_masked_lm.ProgMaskedLMConfig(
            train_data=pretrain_dataloader.BertPretrainDataConfig(
                input_path='dummy'),
            validation_data=pretrain_dataloader.BertPretrainDataConfig(
                is_training=False, input_path='dummy')))
    experiment_config = params_dict.override_params_dict(
        experiment_config, self._test_config, is_strict=False)

    with distribution_strategy.scope():
      task = task_factory.get_task(experiment_config.task,
                                   logging_dir=model_dir)

    _, logs = train_lib.run_experiment(
        distribution_strategy=distribution_strategy,
        task=task,
        mode=flag_mode,
        params=experiment_config,
        model_dir=model_dir,
        run_post_eval=run_post_eval)

    if run_post_eval:
      self.assertNotEmpty(logs)
    else:
      self.assertEmpty(logs)

    if flag_mode == 'eval':
      return
    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
    # Tests continuous evaluation.
    _, logs = train_lib.run_experiment(
        distribution_strategy=distribution_strategy,
        task=task,
        mode='continuous_eval',
        params=experiment_config,
        model_dir=model_dir,
        run_post_eval=run_post_eval)
    print(logs)


if __name__ == '__main__':
  tf.test.main()
official/modeling/progressive/trainer.py
0 → 100644
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Progressive Trainer implementation.
The trainer implements the Orbit `StandardTrainable` and
`StandardEvaluable` interfaces. Trainers inside this project should be
interchangable and independent on model architectures and tasks.
"""
import
os
from
typing
import
Any
,
Optional
# Import libraries
from
absl
import
logging
import
dataclasses
import
gin
import
orbit
import
tensorflow
as
tf
from
official.core
import
base_task
from
official.core
import
base_trainer
as
trainer_lib
from
official.core
import
config_definitions
from
official.modeling.progressive
import
policies
ExperimentConfig
=
config_definitions
.
ExperimentConfig
@
dataclasses
.
dataclass
class
ProgressiveTrainerConfig
(
config_definitions
.
TrainerConfig
):
"""Configuration for progressive trainer.
Attributes:
progressive: A task-specific config. Users can subclass ProgressiveConfig
and define any task-specific settings in their subclass.
export_checkpoint: A bool. Whether to export checkpoints in non-progressive
manner (without the volatiles wrapper) such that your down-stream tasks
can load checkpoints from a progressive trainer as if it is a regular
checkpoint.
export_checkpoint_interval: A bool. The number of steps between exporting
checkpoints. If None (by default), will use the same value as
TrainerConfig.checkpoint_interval.
export_only_final_stage_ckpt: A bool. Whether to just export checkpoints
during the final progressive training stage. In other words, whether to
not export small, partial models. In many cases, it is not meaningful to
finetune a small, partial model in down-stream tasks.
"""
progressive
:
Optional
[
policies
.
ProgressiveConfig
]
=
None
export_checkpoint
:
bool
=
True
export_checkpoint_interval
:
Optional
[
int
]
=
None
export_only_final_stage_ckpt
:
bool
=
True
class
CheckpointWithHooks
(
tf
.
train
.
Checkpoint
):
"""Same as tf.train.Checkpoint but supports hooks.
When running continuous_eval jobs, when a new checkpoint arrives, we have to
update our model and optimizer etc. to match the stage_id of the checkpoint.
However, when orbit loads a checkpoint, it does not inform us. So we use this
class to update our model to the correct stage before checkpoint restore.
"""
def
__init__
(
self
,
before_load_hook
,
**
kwargs
):
self
.
_before_load_hook
=
before_load_hook
super
(
CheckpointWithHooks
,
self
).
__init__
(
**
kwargs
)
# override
def
read
(
self
,
save_path
,
options
=
None
):
self
.
_before_load_hook
(
save_path
)
logging
.
info
(
'Ran before_load_hook.'
)
super
(
CheckpointWithHooks
,
self
).
read
(
save_path
=
save_path
,
options
=
options
)
@
gin
.
configurable
class
ProgressiveTrainer
(
trainer_lib
.
Trainer
):
"""Implements the progressive trainer shared for TensorFlow models."""
def
__init__
(
self
,
config
:
ExperimentConfig
,
prog_task
:
base_task
.
Task
,
# also implemented ProgressivePolicy.
ckpt_dir
:
str
=
''
,
train
:
bool
=
True
,
evaluate
:
bool
=
True
,
checkpoint_exporter
:
Any
=
None
):
"""Initialize common trainer for TensorFlow models.
Args:
config: An `ExperimentConfig` instance specifying experiment config.
prog_task: An instance both implemented policies.ProgressivePolicy and
base_task.Task.
ckpt_dir: Checkpoint directory.
train: bool, whether or not this trainer will be used for training.
default to True.
evaluate: bool, whether or not this trainer will be used for evaluation.
default to True.
checkpoint_exporter: an object that has the `maybe_export_checkpoint`
interface.
"""
# Gets the current distribution strategy. If not inside any strategy scope,
# it gets a single-replica no-op strategy.
self
.
_strategy
=
tf
.
distribute
.
get_strategy
()
self
.
_config
=
config
self
.
_task
=
prog_task
# Directory for non-progressive checkpoint
self
.
_export_ckpt_dir
=
os
.
path
.
join
(
ckpt_dir
,
'exported_ckpts'
)
tf
.
io
.
gfile
.
makedirs
(
self
.
_export_ckpt_dir
)
# Receive other checkpoint export, e.g, best checkpoint exporter.
# TODO(lehou): unify the checkpoint exporting logic, although the default
# setting does not use checkpoint_exporter.
self
.
_checkpoint_exporter
=
checkpoint_exporter
self
.
_global_step
=
orbit
.
utils
.
create_global_step
()
self
.
_checkpoint
=
CheckpointWithHooks
(
before_load_hook
=
self
.
_update_pt_stage_from_ckpt
,
global_step
=
self
.
global_step
,
**
self
.
_task
.
cur_checkpoint_items
)
self
.
_train_loss
=
tf
.
keras
.
metrics
.
Mean
(
'training_loss'
,
dtype
=
tf
.
float32
)
self
.
_validation_loss
=
tf
.
keras
.
metrics
.
Mean
(
'validation_loss'
,
dtype
=
tf
.
float32
)
self
.
_train_metrics
=
self
.
task
.
build_metrics
(
training
=
True
)
+
self
.
model
.
metrics
self
.
_validation_metrics
=
self
.
task
.
build_metrics
(
training
=
False
)
+
self
.
model
.
metrics
if
train
:
orbit
.
StandardTrainer
.
__init__
(
self
,
None
,
# Manage train_dataset by ourselves, not by StandardTrainer.
options
=
orbit
.
StandardTrainerOptions
(
use_tf_while_loop
=
config
.
trainer
.
train_tf_while_loop
,
use_tf_function
=
config
.
trainer
.
train_tf_function
))
if
evaluate
:
orbit
.
StandardEvaluator
.
__init__
(
self
,
None
,
# Manage train_dataset by ourselves, not by StandardEvaluator.
options
=
orbit
.
StandardEvaluatorOptions
(
use_tf_function
=
config
.
trainer
.
eval_tf_function
))
@
property
def
model
(
self
):
return
self
.
_task
.
cur_model
@
property
def
optimizer
(
self
):
return
self
.
_task
.
cur_optimizer
# override
@
property
def
train_dataset
(
self
):
"""Overriding StandardTrainer.train_dataset."""
return
self
.
_task
.
cur_train_dataset
# override
@
train_dataset
.
setter
def
train_dataset
(
self
,
_
):
raise
SyntaxError
(
'Please do not set train_dataset. Progressive training '
'relies on progressive policy to manager train dataset.'
)
# override
@
property
def
eval_dataset
(
self
):
"""Overriding StandardEvaluator.eval_dataset."""
return
self
.
_task
.
cur_eval_dataset
# override
@
eval_dataset
.
setter
def
eval_dataset
(
self
,
_
):
raise
SyntaxError
(
'Please do not set eval_dataset. Progressive training '
'relies on progressive policy to manager eval dataset.'
)
def
train_loop_end
(
self
):
"""See base class."""
logs
=
{}
for
metric
in
self
.
train_metrics
+
[
self
.
train_loss
]:
logs
[
metric
.
name
]
=
metric
.
result
()
metric
.
reset_states
()
if
callable
(
self
.
optimizer
.
learning_rate
):
logs
[
'learning_rate'
]
=
self
.
optimizer
.
learning_rate
(
self
.
optimizer
.
iterations
)
else
:
logs
[
'learning_rate'
]
=
self
.
optimizer
.
learning_rate
self
.
_maybe_export_non_progressive_checkpoint
(
self
.
_export_ckpt_dir
)
if
self
.
_task
.
is_stage_advancing
(
self
.
global_step
.
numpy
()):
old_train_dataset
=
self
.
train_dataset
# Update progressive properties
self
.
_task
.
update_pt_stage
(
self
.
global_step
.
numpy
())
# Setting `self._train_loop_fn` and `self._eval_loop_fn` to None will
# rebuild the train and eval functions with the updated model.
self
.
_train_loop_fn
=
None
self
.
_eval_loop_fn
=
None
if
self
.
train_dataset
!=
old_train_dataset
:
# Setting `self._train_iter` to None will rebuild the dataset iterator.
self
.
_train_iter
=
None
return
logs
def
_update_pt_stage_from_ckpt
(
self
,
ckpt_file
):
"""Update stage properties based on the global_step variable in a ckpt file.
Before loading variables from a checkpoint file, we need to go to the
correct stage and build corresponding model and optimizer, to make sure that
we retore variables of the right model and optimizer.
Args:
ckpt_file: Checkpoint file that will be restored/read from.
"""
if
not
ckpt_file
:
return
ckpt
=
tf
.
train
.
Checkpoint
(
global_step
=
self
.
global_step
)
ckpt
.
read
(
ckpt_file
).
expect_partial
().
assert_existing_objects_matched
()
if
self
.
_task
.
is_stage_advancing
(
self
.
global_step
.
numpy
()):
old_train_dataset
=
self
.
train_dataset
# Update progressive properties
self
.
_task
.
update_pt_stage
(
self
.
global_step
.
numpy
(),
pass_old_model
=
False
)
# Setting `self._train_loop_fn` and `self._eval_loop_fn` to None will
# rebuild the train and eval functions with the updated model.
self
.
_train_loop_fn
=
None
self
.
_eval_loop_fn
=
None
if
self
.
train_dataset
!=
old_train_dataset
:
# Setting `self._train_iter` to None will rebuild the dataset iterator.
self
.
_train_iter
=
None
def
_maybe_export_non_progressive_checkpoint
(
self
,
export_ckpt_dir
):
"""Export checkpoints in non-progressive format.
This basically removes the wrapping of self._task.cur_checkpoint_items
-- just save the model, optimizer, etc., directly.
The purpose is to let your down-stream tasks to use these checkpoints.
Args:
export_ckpt_dir: A str. folder of exported checkpoints.
"""
if
not
self
.
config
.
trainer
.
export_checkpoint
:
logging
.
info
(
'Not exporting checkpoints.'
)
return
if
not
self
.
_task
.
is_last_stage
and
(
self
.
config
.
trainer
.
export_only_final_stage_ckpt
):
logging
.
info
(
'Not exporting checkpoints until the last stage.'
)
return
global_step_np
=
self
.
global_step
.
numpy
()
if
self
.
config
.
trainer
.
export_checkpoint_interval
is
None
:
step_interval
=
self
.
config
.
trainer
.
checkpoint_interval
else
:
step_interval
=
self
.
config
.
trainer
.
export_checkpoint_interval
if
global_step_np
%
step_interval
!=
0
:
logging
.
info
(
'Not exporting checkpoints in global step: %d.'
,
global_step_np
)
return
# Create a checkpoint object just now, to make sure we use
# progressive_policy.cur_model and progressive_policy.cur_optimizer of the
# current stage.
if
hasattr
(
self
.
model
,
'checkpoint_items'
):
checkpoint_items
=
self
.
model
.
checkpoint_items
else
:
checkpoint_items
=
{}
checkpoint
=
tf
.
train
.
Checkpoint
(
global_step
=
self
.
global_step
,
model
=
self
.
model
,
optimizer
=
self
.
optimizer
,
**
checkpoint_items
)
file_prefix
=
os
.
path
.
join
(
export_ckpt_dir
,
'ckpt-{}'
.
format
(
global_step_np
))
checkpoint
.
save
(
file_prefix
=
file_prefix
)
logging
.
info
(
'Checkpoints exported: %s.'
,
file_prefix
)
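Since _maybe_export_non_progressive_checkpoint saves plain global_step/model/optimizer keys instead of the stage_id/volatiles wrapper, a down-stream job can restore the exported files with an ordinary tf.train.Checkpoint. A rough sketch, where build_compatible_model() is a hypothetical stand-in for however the down-stream task constructs a model compatible with the final stage:

# Sketch: consuming an exported (non-progressive) checkpoint downstream.
# `build_compatible_model` is hypothetical; the path is a placeholder.
import tensorflow as tf

model = build_compatible_model()
ckpt = tf.train.Checkpoint(model=model)
latest = tf.train.latest_checkpoint('/path/to/model_dir/exported_ckpts')
ckpt.restore(latest).expect_partial()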
official/modeling/progressive/trainer_test.py
0 → 100644
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the progressive trainer."""
# pylint: disable=g-direct-tensorflow-import
import os

from absl.testing import parameterized
import orbit
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import config_definitions as cfg
from official.modeling import optimization
from official.modeling.progressive import policies
from official.modeling.progressive import trainer as trainer_lib
from official.nlp.configs import bert
from official.utils.testing import mock_task


def all_strategy_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.default_strategy,
          strategy_combinations.cloud_tpu_strategy,
          strategy_combinations.one_device_strategy_gpu,
      ],
      mode='eager',
  )


def get_exp_config():
  return cfg.ExperimentConfig(
      task=cfg.TaskConfig(model=bert.PretrainerConfig()),
      trainer=trainer_lib.ProgressiveTrainerConfig(
          export_checkpoint=True,
          export_checkpoint_interval=1,
          export_only_final_stage_ckpt=False))


class TestPolicy(policies.ProgressivePolicy, mock_task.MockTask):
  """Just for testing purposes."""

  def __init__(self, strategy, task_config, change_train_dataset=True):
    self._strategy = strategy
    self._change_train_dataset = change_train_dataset
    self._my_train_dataset = None
    mock_task.MockTask.__init__(self, params=task_config, logging_dir=None)
    policies.ProgressivePolicy.__init__(self)

  def num_stages(self) -> int:
    return 2

  def num_steps(self, stage_id: int) -> int:
    return 2 if stage_id == 0 else 4

  def get_model(self, stage_id: int,
                old_model: tf.keras.Model) -> tf.keras.Model:
    del stage_id, old_model
    return self.build_model()

  def get_optimizer(self, stage_id: int) -> tf.keras.optimizers.Optimizer:
    optimizer_type = 'sgd' if stage_id == 0 else 'adamw'
    optimizer_config = cfg.OptimizationConfig({
        'optimizer': {'type': optimizer_type},
        'learning_rate': {'type': 'constant'}})
    opt_factory = optimization.OptimizerFactory(optimizer_config)
    return opt_factory.build_optimizer(opt_factory.build_learning_rate())

  def get_train_dataset(self, stage_id: int) -> tf.data.Dataset:
    if not self._change_train_dataset and self._my_train_dataset:
      return self._my_train_dataset
    if self._strategy:
      self._my_train_dataset = orbit.utils.make_distributed_dataset(
          self._strategy,
          self._build_inputs,
          stage_id)
    else:
      self._my_train_dataset = self._build_inputs(stage_id)
    return self._my_train_dataset

  def get_eval_dataset(self, stage_id: int) -> tf.data.Dataset:
    if self._strategy:
      return orbit.utils.make_distributed_dataset(
          self._strategy,
          self._build_inputs,
          stage_id)
    return self._build_inputs(stage_id)

  def _build_inputs(self, stage_id):

    def dummy_data(_):
      batch_size = 2 if stage_id == 0 else 1
      x = tf.zeros(shape=(batch_size, 2), dtype=tf.float32)
      label = tf.zeros(shape=(batch_size, 1), dtype=tf.float32)
      return x, label

    dataset = tf.data.Dataset.range(1)
    dataset = dataset.repeat()
    return dataset.map(
        dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)


class TrainerTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(TrainerTest, self).setUp()
    self._config = get_exp_config()

  def create_test_trainer(self, distribution, model_dir,
                          change_train_dataset):
    trainer = trainer_lib.ProgressiveTrainer(
        self._config,
        prog_task=TestPolicy(
            distribution, self._config.task, change_train_dataset),
        ckpt_dir=model_dir)
    return trainer

  @combinations.generate(all_strategy_combinations())
  def test_checkpointing(self, distribution):
    model_dir = self.get_temp_dir()
    ckpt_file = os.path.join(model_dir, 'ckpt')
    with distribution.scope():
      trainer = self.create_test_trainer(distribution, model_dir, True)
      self.assertFalse(trainer._task.is_last_stage)
      trainer.train(tf.convert_to_tensor(4, dtype=tf.int32))
      self.assertTrue(trainer._task.is_last_stage)
      trainer.checkpoint.save(ckpt_file)

      trainer = self.create_test_trainer(distribution, model_dir, True)
      self.assertFalse(trainer._task.is_last_stage)
      trainer.checkpoint.restore(ckpt_file + '-1')
      self.assertTrue(trainer._task.is_last_stage)

  @combinations.generate(all_strategy_combinations())
  def test_train_dataset(self, distribution):
    model_dir = self.get_temp_dir()
    with distribution.scope():
      trainer = self.create_test_trainer(distribution, model_dir, True)
      # Using dataset of stage == 0
      train_iter = tf.nest.map_structure(iter, trainer.train_dataset)
      train_data = train_iter.next()[0]
      if distribution.num_replicas_in_sync > 1:
        train_data = train_data.values[0]
      self.assertEqual(train_data.shape[0], 2)

      trainer.train(tf.convert_to_tensor(4, dtype=tf.int32))
      # Using dataset of stage == 1
      train_iter = tf.nest.map_structure(iter, trainer.train_dataset)
      train_data = train_iter.next()[0]
      if distribution.num_replicas_in_sync > 1:
        train_data = train_data.values[0]
      self.assertEqual(train_data.shape[0], 1)

      with self.assertRaises(SyntaxError):
        trainer.train_dataset = None

  @combinations.generate(all_strategy_combinations())
  def test_train_dataset_no_switch(self, distribution):
    model_dir = self.get_temp_dir()
    with distribution.scope():
      trainer = self.create_test_trainer(distribution, model_dir, False)
      trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
      # _train_iter is not reset since the dataset is not changed.
      self.assertIsNotNone(trainer._train_iter)
    with distribution.scope():
      trainer = self.create_test_trainer(distribution, model_dir, True)
      trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
      # _train_iter is reset since the dataset changed.
      self.assertIsNone(trainer._train_iter)


class TrainerWithMaskedLMTaskTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(TrainerWithMaskedLMTaskTest, self).setUp()
    self._config = get_exp_config()

  def create_test_trainer(self, distribution):
    trainer = trainer_lib.ProgressiveTrainer(
        self._config,
        prog_task=TestPolicy(distribution, self._config.task),
        ckpt_dir=self.get_temp_dir())
    return trainer

  @combinations.generate(all_strategy_combinations())
  def test_trainer_train(self, distribution):
    with distribution.scope():
      trainer = self.create_test_trainer(distribution)
      logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertIn('training_loss', logs)
      self.assertIn('learning_rate', logs)

  @combinations.generate(all_strategy_combinations())
  def test_trainer_validate(self, distribution):
    with distribution.scope():
      trainer = self.create_test_trainer(distribution)
      logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertIn('validation_loss', logs)
      self.assertEqual(logs['counter'],
                       5. * distribution.num_replicas_in_sync)

  @combinations.generate(
      combinations.combine(
          mixed_precision_dtype=['float32', 'bfloat16', 'float16'],
          loss_scale=[None, 'dynamic', 128, 256],
      ))
  def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
    config = cfg.ExperimentConfig(
        task=cfg.TaskConfig(model=bert.PretrainerConfig()),
        runtime=cfg.RuntimeConfig(
            mixed_precision_dtype=mixed_precision_dtype,
            loss_scale=loss_scale),
        trainer=trainer_lib.ProgressiveTrainerConfig(
            export_checkpoint=True,
            export_checkpoint_interval=1,
            export_only_final_stage_ckpt=False))
    task = TestPolicy(None, config.task)
    trainer = trainer_lib.ProgressiveTrainer(config, task,
                                             self.get_temp_dir())
    if mixed_precision_dtype != 'float16':
      self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
    elif mixed_precision_dtype == 'float16' and loss_scale is None:
      self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)

    metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('training_loss', metrics)


if __name__ == '__main__':
  tf.test.main()
official/modeling/progressive/utils.py
0 → 100644
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Util classes and functions."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.training.tracking import tracking


class VolatileTrackable(tracking.AutoTrackable):
  """A util class to keep Trackables that might change instances."""

  def __init__(self, **kwargs):
    for k, v in kwargs.items():
      setattr(self, k, v)

  def reassign_trackable(self, **kwargs):
    for k, v in kwargs.items():
      delattr(self, k)  # untrack this object
      setattr(self, k, v)  # track the new object
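The point of reassign_trackable is that a tf.train.Checkpoint keeps tracking the attribute name while the object behind it is swapped between stages. A small sketch of that behavior (the variable shapes and the save path are arbitrary):

# Sketch: the checkpoint tracks `holder.model`; reassign_trackable lets a new
# object take over that slot between stages.
import tensorflow as tf
from official.modeling.progressive.utils import VolatileTrackable

holder = VolatileTrackable(model=tf.Variable([1.0]))
ckpt = tf.train.Checkpoint(volatiles=holder)

holder.reassign_trackable(model=tf.Variable([2.0, 3.0]))  # stage switch
ckpt.save('/tmp/volatile_demo/ckpt')  # saves the *new* object under 'model'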
official/nlp/tasks/progressive_masked_lm.py
0 → 100644
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Masked language task with progressive training."""
from typing import List
# Import libraries
from absl import logging
import dataclasses
import orbit
import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import optimization
from official.modeling.hyperparams import base_config
from official.modeling.progressive import policies
from official.nlp.tasks import masked_lm


@dataclasses.dataclass
class StackingStageConfig(base_config.Config):
  num_layers: int = 0
  num_steps: int = 0
  warmup_steps: int = 10000
  initial_learning_rate: float = 1e-4
  end_learning_rate: float = 0.0
  decay_steps: int = 1000000


@dataclasses.dataclass
class ProgMaskedLMConfig(masked_lm.MaskedLMConfig):
  """The progressive model config."""
  optimizer_config: optimization.OptimizationConfig = (
      optimization.OptimizationConfig(
          optimizer=optimization.OptimizerConfig(type='adamw'),
          learning_rate=optimization.LrConfig(type='polynomial'),
          warmup=optimization.WarmupConfig(type='polynomial'),
      ))
  stage_list: List[StackingStageConfig] = dataclasses.field(
      default_factory=lambda: [  # pylint: disable=g-long-lambda
          StackingStageConfig(
              num_layers=3,
              num_steps=112500,
              warmup_steps=10000,
              initial_learning_rate=1e-4,
              end_learning_rate=1e-4,
              decay_steps=112500),
          StackingStageConfig(
              num_layers=6,
              num_steps=112500,
              warmup_steps=10000,
              initial_learning_rate=1e-4,
              end_learning_rate=1e-4,
              decay_steps=112500),
          StackingStageConfig(
              num_layers=12,
              num_steps=450000,
              warmup_steps=10000,
              initial_learning_rate=1e-4,
              end_learning_rate=0.0,
              decay_steps=450000)])


@task_factory.register_task_cls(ProgMaskedLMConfig)
class ProgressiveMaskedLM(policies.ProgressivePolicy, masked_lm.MaskedLMTask):
  """Masked Language Model that supports progressive training.

  Inherits from the MaskedLMTask class to build models, datasets, etc.
  """

  def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
    masked_lm.MaskedLMTask.__init__(
        self, params=params, logging_dir=logging_dir)
    self._model_config = params.model
    self._optimizer_config = params.optimizer_config
    self._the_only_train_dataset = None
    self._the_only_eval_dataset = None
    policies.ProgressivePolicy.__init__(self)

  # Override
  def num_stages(self):
    return len(self.task_config.stage_list)

  # Override
  def num_steps(self, stage_id):
    return self.task_config.stage_list[stage_id].num_steps

  # Override
  def get_model(self, stage_id, old_model=None):
    """Build model for each stage."""
    num_layers = self.task_config.stage_list[stage_id].num_layers
    encoder_type = self._model_config.encoder.type
    params = self._model_config.replace(
        encoder={encoder_type: {'num_layers': num_layers}})
    model = self.build_model(params)

    # Run the model once, to make sure that all layers are built.
    # Otherwise, not all weights will be copied.
    _ = model(model.inputs)

    if stage_id > 0 and old_model is not None:
      logging.info('Stage %d copying weights.', stage_id)
      self._copy_weights_to_new_model(old_model=old_model, new_model=model)
    return model

  # Override
  def get_optimizer(self, stage_id):
    """Build optimizer for each stage."""
    params = self._optimizer_config.replace(
        learning_rate={
            'polynomial': {
                'decay_steps':
                    self.task_config.stage_list[stage_id].decay_steps,
                'initial_learning_rate':
                    self.task_config.stage_list[stage_id]
                    .initial_learning_rate,
                'end_learning_rate':
                    self.task_config.stage_list[stage_id].end_learning_rate,
                'power': 1,
                'cycle': False,
            }
        },
        warmup={
            'polynomial': {
                'warmup_steps':
                    self.task_config.stage_list[stage_id].warmup_steps,
                'power': 1,
            }
        })
    opt_factory = optimization.OptimizerFactory(params)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
    return optimizer

  # overrides policies.ProgressivePolicy
  def get_train_dataset(self, stage_id):
    del stage_id
    if self._the_only_train_dataset is None:
      strategy = tf.distribute.get_strategy()
      self._the_only_train_dataset = orbit.utils.make_distributed_dataset(
          strategy, self.build_inputs, self.task_config.train_data)
    return self._the_only_train_dataset

  # overrides policies.ProgressivePolicy
  def get_eval_dataset(self, stage_id):
    del stage_id
    if self._the_only_eval_dataset is None:
      strategy = tf.distribute.get_strategy()
      self._the_only_eval_dataset = orbit.utils.make_distributed_dataset(
          strategy, self.build_inputs, self.task_config.validation_data)
    return self._the_only_eval_dataset

  def _copy_weights_to_new_model(self, old_model, new_model):
    """Copy model weights from the previous stage to the next.

    Args:
      old_model: nlp.modeling.models.bert_pretrainer.BertPretrainerV2. Model
        of the previous stage.
      new_model: nlp.modeling.models.bert_pretrainer.BertPretrainerV2. Model
        of the next stage.
    """
    # Copy weights of the embedding layers.
    # pylint: disable=protected-access
    # When using `encoder_scaffold`, there may be `_embedding_network`.
    if (hasattr(new_model.encoder_network, '_embedding_network') and
        hasattr(old_model.encoder_network, '_embedding_network') and
        new_model.encoder_network._embedding_network is not None):
      new_model.encoder_network._embedding_network.set_weights(
          old_model.encoder_network._embedding_network.get_weights())
    else:
      new_model.encoder_network._embedding_layer.set_weights(
          old_model.encoder_network._embedding_layer.get_weights())
      new_model.encoder_network._position_embedding_layer.set_weights(
          old_model.encoder_network._position_embedding_layer.get_weights())
      new_model.encoder_network._type_embedding_layer.set_weights(
          old_model.encoder_network._type_embedding_layer.get_weights())
      new_model.encoder_network._embedding_norm_layer.set_weights(
          old_model.encoder_network._embedding_norm_layer.get_weights())
    if (hasattr(new_model.encoder_network, '_embedding_projection') and
        hasattr(old_model.encoder_network, '_embedding_projection')):
      if old_model.encoder_network._embedding_projection is not None:
        new_model.encoder_network._embedding_projection.set_weights(
            old_model.encoder_network._embedding_projection.get_weights())
    # pylint: enable=protected-access

    # Copy weights of the transformer layers.
    # The model can be EncoderScaffold or TransformerEncoder.
    if hasattr(old_model.encoder_network, 'hidden_layers'):
      old_layer_group = old_model.encoder_network.hidden_layers
    elif hasattr(old_model.encoder_network, 'transformer_layers'):
      old_layer_group = old_model.encoder_network.transformer_layers
    else:
      raise ValueError('Unrecognized encoder network: {}'.format(
          old_model.encoder_network))
    if hasattr(new_model.encoder_network, 'hidden_layers'):
      new_layer_group = new_model.encoder_network.hidden_layers
    elif hasattr(new_model.encoder_network, 'transformer_layers'):
      new_layer_group = new_model.encoder_network.transformer_layers
    else:
      raise ValueError('Unrecognized encoder network: {}'.format(
          new_model.encoder_network))

    for new_layer_idx in range(len(new_layer_group)):
      old_layer_idx = new_layer_idx % len(old_layer_group)
      new_layer_group[new_layer_idx].set_weights(
          old_layer_group[old_layer_idx].get_weights())
      if old_layer_idx != new_layer_idx:
        if hasattr(new_layer_group[new_layer_idx], 'reset_rezero'):
          # Reset ReZero's alpha to 0.
          new_layer_group[new_layer_idx].reset_rezero()

    # Copy weights of the final layer norm (if needed).
    # pylint: disable=protected-access
    if (hasattr(new_model.encoder_network, '_output_layer_norm') and
        hasattr(old_model.encoder_network, '_output_layer_norm')):
      new_model.encoder_network._output_layer_norm.set_weights(
          old_model.encoder_network._output_layer_norm.get_weights())
    # pylint: enable=protected-access

    # Copy weights of the pooler layer.
    new_model.encoder_network.pooler_layer.set_weights(
        old_model.encoder_network.pooler_layer.get_weights())

    # Copy weights of the classification head.
    for idx in range(len(new_model.classification_heads)):
      new_model.classification_heads[idx].set_weights(
          old_model.classification_heads[idx].get_weights())

    # Copy weights of the masked_lm layer.
    new_model.masked_lm.set_weights(old_model.masked_lm.get_weights())
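The stacking rule in _copy_weights_to_new_model is the modulo mapping old_layer_idx = new_layer_idx % len(old_layer_group): the old stack is tiled to initialize the deeper one. A worked example of the index map when growing from 3 to 6 layers:

# Worked example of the stacking index map: growing 3 -> 6 layers copies the
# old stack twice; layers copied to a *different* index get their ReZero
# alpha reset, per the code above.
old_layers, new_layers = 3, 6
mapping = [(new_idx, new_idx % old_layers) for new_idx in range(new_layers)]
print(mapping)  # [(0, 0), (1, 1), (2, 2), (3, 0), (4, 1), (5, 2)]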
official/nlp/tasks/progressive_masked_lm_test.py
0 → 100644
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for google.nlp.progressive_masked_lm."""
# Import libraries
from absl.testing import parameterized
import gin
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import config_definitions as cfg
from official.modeling.progressive import trainer as prog_trainer_lib
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.data import pretrain_dataloader
from official.nlp.tasks import progressive_masked_lm


def all_strategy_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.default_strategy,
          strategy_combinations.cloud_tpu_strategy,
          strategy_combinations.one_device_strategy_gpu,
      ],
      mode="eager",
  )


class ProgressiveMaskedLMTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(ProgressiveMaskedLMTest, self).setUp()
    self.task_config = progressive_masked_lm.ProgMaskedLMConfig(
        model=bert.PretrainerConfig(
            encoder=encoders.EncoderConfig(
                bert=encoders.BertEncoderConfig(vocab_size=30522,
                                                num_layers=2)),
            cls_heads=[
                bert.ClsHeadConfig(
                    inner_dim=10, num_classes=2, name="next_sentence")
            ]),
        train_data=pretrain_dataloader.BertPretrainDataConfig(
            input_path="dummy",
            max_predictions_per_seq=20,
            seq_length=128,
            global_batch_size=1),
        validation_data=pretrain_dataloader.BertPretrainDataConfig(
            input_path="dummy",
            max_predictions_per_seq=20,
            seq_length=128,
            global_batch_size=1),
        stage_list=[
            progressive_masked_lm.StackingStageConfig(
                num_layers=1, num_steps=4),
            progressive_masked_lm.StackingStageConfig(
                num_layers=2, num_steps=8),
        ],
    )
    self.exp_config = cfg.ExperimentConfig(
        task=self.task_config,
        trainer=prog_trainer_lib.ProgressiveTrainerConfig())

  @combinations.generate(all_strategy_combinations())
  def test_num_stages(self, distribution):
    with distribution.scope():
      prog_masked_lm = progressive_masked_lm.ProgressiveMaskedLM(
          self.task_config)
      self.assertEqual(prog_masked_lm.num_stages(), 2)
      self.assertEqual(prog_masked_lm.num_steps(0), 4)
      self.assertEqual(prog_masked_lm.num_steps(1), 8)

  @combinations.generate(all_strategy_combinations())
  def test_weight_copying(self, distribution):
    with distribution.scope():
      prog_masked_lm = progressive_masked_lm.ProgressiveMaskedLM(
          self.task_config)
      old_model = prog_masked_lm.get_model(stage_id=0)
      for w in old_model.trainable_weights:
        w.assign(tf.zeros_like(w) + 0.12345)
      new_model = prog_masked_lm.get_model(stage_id=1, old_model=old_model)
      for w in new_model.trainable_weights:
        self.assertAllClose(w, tf.zeros_like(w) + 0.12345)

    gin.parse_config_files_and_bindings(
        None, "encoders.build_encoder.encoder_cls = @EncoderScaffold")
    with distribution.scope():
      prog_masked_lm = progressive_masked_lm.ProgressiveMaskedLM(
          self.task_config)
      old_model = prog_masked_lm.get_model(stage_id=0)
      for w in old_model.trainable_weights:
        w.assign(tf.zeros_like(w) + 0.12345)
      new_model = prog_masked_lm.get_model(stage_id=1, old_model=old_model)
      for w in new_model.trainable_weights:
        self.assertAllClose(w, tf.zeros_like(w) + 0.12345)


if __name__ == "__main__":
  tf.test.main()
official/nlp/tasks/progressive_translation.py
0 → 100644
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Translation task with progressive training."""
from
typing
import
List
# Import libraries
from
absl
import
logging
import
dataclasses
import
orbit
import
tensorflow
as
tf
from
official.core
import
config_definitions
as
cfg
from
official.core
import
task_factory
from
official.modeling
import
optimization
from
official.modeling.hyperparams
import
base_config
from
official.modeling.progressive
import
policies
from
official.nlp.modeling
import
models
from
official.nlp.tasks
import
translation
@
dataclasses
.
dataclass
class
StackingStageConfig
(
base_config
.
Config
):
num_encoder_layers
:
int
=
0
num_decoder_layers
:
int
=
0
num_steps
:
int
=
0
warmup_steps
:
int
=
10000
initial_learning_rate
:
float
=
0.0625
power
:
float
=
-
0.5
@
dataclasses
.
dataclass
class
ProgTranslationConfig
(
translation
.
TranslationConfig
):
"""The progressive model config."""
model
:
translation
.
ModelConfig
=
translation
.
ModelConfig
(
encoder
=
translation
.
EncDecoder
(
num_attention_heads
=
16
,
intermediate_size
=
4096
),
decoder
=
translation
.
EncDecoder
(
num_attention_heads
=
16
,
intermediate_size
=
4096
),
embedding_width
=
1024
,
padded_decode
=
True
,
decode_max_length
=
100
)
optimizer_config
:
optimization
.
OptimizationConfig
=
(
optimization
.
OptimizationConfig
({
'optimizer'
:
{
'type'
:
'adam'
,
'adam'
:
{
'beta_2'
:
0.997
,
'epsilon'
:
1e-9
,
},
},
'learning_rate'
:
{
'type'
:
'power'
,
'power'
:
{
'initial_learning_rate'
:
0.0625
,
'power'
:
-
0.5
,
}
},
'warmup'
:
{
'type'
:
'linear'
,
'linear'
:
{
'warmup_steps'
:
16000
,
'warmup_learning_rate'
:
0.0
}
}
}))
stage_list
:
List
[
StackingStageConfig
]
=
dataclasses
.
field
(
default_factory
=
lambda
:
[
# pylint: disable=g-long-lambda
StackingStageConfig
(
num_encoder_layers
=
3
,
num_decoder_layers
=
3
,
num_steps
=
20000
,
warmup_steps
=
5000
,
initial_learning_rate
=
0.0625
),
StackingStageConfig
(
num_encoder_layers
=
6
,
num_decoder_layers
=
6
,
num_steps
=
20000
,
warmup_steps
=
5000
,
initial_learning_rate
=
0.0625
),
StackingStageConfig
(
num_encoder_layers
=
12
,
num_decoder_layers
=
12
,
num_steps
=
100000
,
warmup_steps
=
5000
,
initial_learning_rate
=
0.0625
)])
@
task_factory
.
register_task_cls
(
ProgTranslationConfig
)
class
ProgressiveTranslationTask
(
policies
.
ProgressivePolicy
,
translation
.
TranslationTask
):
"""Masked Language Model that supports progressive training.
Inherate from the TranslationTask class to build model datasets etc.
"""
  def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
    translation.TranslationTask.__init__(
        self, params=params, logging_dir=logging_dir)
    self._model_config = params.model
    self._optimizer_config = params.optimizer_config
    self._the_only_train_dataset = None
    self._the_only_eval_dataset = None
    policies.ProgressivePolicy.__init__(self)
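    # ProgressivePolicy.__init__ is called last on purpose: the policy's
    # stage setup invokes get_model()/get_optimizer()/get_*_dataset(), which
    # read the _model_config and _optimizer_config fields assigned above.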
  # Override
  def num_stages(self):
    return len(self.task_config.stage_list)
  # Override
  def num_steps(self, stage_id):
    return self.task_config.stage_list[stage_id].num_steps
  # Override
  def get_model(self, stage_id, old_model=None):
    """Build model for each stage."""
    num_encoder_layers = (
        self.task_config.stage_list[stage_id].num_encoder_layers)
    num_decoder_layers = (
        self.task_config.stage_list[stage_id].num_decoder_layers)
    params = self._model_config.replace(
        encoder={'num_layers': num_encoder_layers},
        decoder={'num_layers': num_decoder_layers})
    model = self.build_model(params)

    # Run the model once, to make sure that all layers are built.
    # Otherwise, not all weights will be copied.
    inputs = next(tf.nest.map_structure(
        iter, self.build_inputs(self.task_config.train_data)))
    model(inputs, training=True)

    if stage_id > 0 and old_model is not None:
      logging.info('Stage %d copying weights.', stage_id)
      self._copy_weights_to_new_model(old_model=old_model, new_model=model)
    return model
  # Override
  def build_model(self, params) -> tf.keras.Model:
    """Creates model architecture."""
    model_cfg = params or self.task_config.model
    encoder_kwargs = model_cfg.encoder.as_dict()
    encoder_layer = models.TransformerEncoder(**encoder_kwargs)
    decoder_kwargs = model_cfg.decoder.as_dict()
    decoder_layer = models.TransformerDecoder(**decoder_kwargs)

    return models.Seq2SeqTransformer(
        vocab_size=self._vocab_size,
        embedding_width=model_cfg.embedding_width,
        dropout_rate=model_cfg.dropout_rate,
        padded_decode=model_cfg.padded_decode,
        decode_max_length=model_cfg.decode_max_length,
        beam_size=model_cfg.beam_size,
        alpha=model_cfg.alpha,
        encoder_layer=encoder_layer,
        decoder_layer=decoder_layer)
  # Override
  def get_optimizer(self, stage_id):
    """Build optimizer for each stage."""
    params = self._optimizer_config.replace(
        warmup={
            'linear': {
                'warmup_steps':
                    self.task_config.stage_list[stage_id].warmup_steps
            },
        },
        learning_rate={
            'power': {
                'initial_learning_rate':
                    self.task_config.stage_list[stage_id].initial_learning_rate
            },
        },
    )
    opt_factory = optimization.OptimizerFactory(params)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
    return optimizer
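  # With the 'power' schedule above, the post-warmup learning rate is assumed
  # to decay as initial_learning_rate * step**power; power = -0.5 gives the
  # usual inverse-square-root Transformer schedule. Illustrative numbers only:
  # with the defaults (0.0625, -0.5), the rate at step 16000 is about
  # 0.0625 / sqrt(16000) ~= 4.9e-4.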
  # overrides policies.ProgressivePolicy
  def get_train_dataset(self, stage_id):
    del stage_id
    if self._the_only_train_dataset is None:
      strategy = tf.distribute.get_strategy()
      self._the_only_train_dataset = orbit.utils.make_distributed_dataset(
          strategy, self.build_inputs, self.task_config.train_data)
    return self._the_only_train_dataset
  # overrides policies.ProgressivePolicy
  def get_eval_dataset(self, stage_id):
    del stage_id
    if self._the_only_eval_dataset is None:
      strategy = tf.distribute.get_strategy()
      self._the_only_eval_dataset = orbit.utils.make_distributed_dataset(
          strategy, self.build_inputs, self.task_config.validation_data)
    return self._the_only_eval_dataset
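  # Both dataset getters ignore stage_id and cache one distributed dataset,
  # so advancing to a deeper stage reuses the existing input pipeline instead
  # of rebuilding it.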
  def _copy_weights_to_new_model(self, old_model, new_model):
    """Copy model weights from the previous stage to the next.

    Args:
      old_model: nlp.modeling.models.Seq2SeqTransformer. Model of the
        previous stage.
      new_model: nlp.modeling.models.Seq2SeqTransformer. Model of the
        next stage.
    """
    new_model.embedding_lookup.set_weights(
        old_model.embedding_lookup.get_weights())
    new_model.position_embedding.set_weights(
        old_model.position_embedding.get_weights())
    new_model.encoder_layer.output_normalization.set_weights(
        old_model.encoder_layer.output_normalization.get_weights())
    new_model.decoder_layer.output_normalization.set_weights(
        old_model.decoder_layer.output_normalization.get_weights())

    old_layer_group = old_model.encoder_layer.encoder_layers
    new_layer_group = new_model.encoder_layer.encoder_layers
    for new_layer_idx in range(len(new_layer_group)):
      old_layer_idx = new_layer_idx % len(old_layer_group)
      new_layer_group[new_layer_idx].set_weights(
          old_layer_group[old_layer_idx].get_weights())

    old_layer_group = old_model.decoder_layer.decoder_layers
    new_layer_group = new_model.decoder_layer.decoder_layers
    for new_layer_idx in range(len(new_layer_group)):
      old_layer_idx = new_layer_idx % len(old_layer_group)
      new_layer_group[new_layer_idx].set_weights(
          old_layer_group[old_layer_idx].get_weights())
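# A note on the copy scheme above: the modulo indexing implements layer
# stacking, reusing the old stage's layers cyclically whenever the new stage
# is deeper. A minimal standalone sketch of the resulting copy pattern
# (hypothetical layer names, not part of this commit):
#
#   old_layers = ["enc0", "enc1", "enc2"]   # a trained 3-layer stage
#   new_layers = [None] * 6                 # a fresh 6-layer stage
#   for new_idx in range(len(new_layers)):
#     new_layers[new_idx] = old_layers[new_idx % len(old_layers)]
#   # -> ['enc0', 'enc1', 'enc2', 'enc0', 'enc1', 'enc2']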
official/nlp/tasks/progressive_translation_test.py
0 → 100644
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for google.nlp.progressive_translation."""
import os

from absl.testing import parameterized
import tensorflow as tf

from sentencepiece import SentencePieceTrainer
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import config_definitions as cfg
from official.modeling.progressive import trainer as prog_trainer_lib
from official.nlp.data import wmt_dataloader
from official.nlp.tasks import progressive_translation
from official.nlp.tasks import translation
def all_strategy_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.default_strategy,
          strategy_combinations.cloud_tpu_strategy,
          strategy_combinations.one_device_strategy_gpu,
      ],
      mode="eager",
  )
def _generate_line_file(filepath, lines):
  with tf.io.gfile.GFile(filepath, "w") as f:
    for l in lines:
      f.write("{}\n".format(l))
def _generate_record_file(filepath, src_lines, tgt_lines):
  writer = tf.io.TFRecordWriter(filepath)
  for src, tgt in zip(src_lines, tgt_lines):
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                "en":
                    tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[src.encode()])),
                "reverse_en":
                    tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[tgt.encode()])),
            }))
    writer.write(example.SerializeToString())
  writer.close()
def _train_sentencepiece(input_path, vocab_size, model_path, eos_id=1):
  argstr = " ".join([
      f"--input={input_path}",
      f"--vocab_size={vocab_size}",
      "--character_coverage=0.995",
      f"--model_prefix={model_path}",
      "--model_type=bpe",
      "--bos_id=-1",
      "--pad_id=0",
      f"--eos_id={eos_id}",
      "--unk_id=2",
  ])
  SentencePieceTrainer.Train(argstr)
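# For example (hypothetical paths), _train_sentencepiece("/tmp/in.txt", 11,
# "/tmp/sp") hands SentencePieceTrainer.Train the flag string
#   --input=/tmp/in.txt --vocab_size=11 --character_coverage=0.995
#   --model_prefix=/tmp/sp --model_type=bpe --bos_id=-1 --pad_id=0
#   --eos_id=1 --unk_id=2
# which writes the /tmp/sp.model file used by the tests below.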
class ProgressiveTranslationTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(ProgressiveTranslationTest, self).setUp()
    self._temp_dir = self.get_temp_dir()
    src_lines = ["abc ede fg", "bbcd ef a g", "de f a a g"]
    tgt_lines = ["dd cc a ef g", "bcd ef a g", "gef cd ba"]
    self._record_input_path = os.path.join(self._temp_dir, "train.record")
    _generate_record_file(self._record_input_path, src_lines, tgt_lines)
    self._sentencepeice_input_path = os.path.join(self._temp_dir, "inputs.txt")
    _generate_line_file(self._sentencepeice_input_path, src_lines + tgt_lines)
    sentencepeice_model_prefix = os.path.join(self._temp_dir, "sp")
    _train_sentencepiece(self._sentencepeice_input_path, 11,
                         sentencepeice_model_prefix)
    self._sentencepeice_model_path = "{}.model".format(
        sentencepeice_model_prefix)
    encdecoder = translation.EncDecoder(
        num_attention_heads=2, intermediate_size=8)
    self.task_config = progressive_translation.ProgTranslationConfig(
        model=translation.ModelConfig(
            encoder=encdecoder,
            decoder=encdecoder,
            embedding_width=8,
            padded_decode=True,
            decode_max_length=100),
        train_data=wmt_dataloader.WMTDataConfig(
            input_path=self._record_input_path,
            is_training=True,
            global_batch_size=24,
            static_batch=True,
            src_lang="en",
            tgt_lang="reverse_en",
            max_seq_length=12),
        validation_data=wmt_dataloader.WMTDataConfig(
            input_path=self._record_input_path,
            is_training=False,
            global_batch_size=2,
            static_batch=True,
            src_lang="en",
            tgt_lang="reverse_en",
            max_seq_length=12),
        sentencepiece_model_path=self._sentencepeice_model_path,
        stage_list=[
            progressive_translation.StackingStageConfig(
                num_encoder_layers=1, num_decoder_layers=1, num_steps=4),
            progressive_translation.StackingStageConfig(
                num_encoder_layers=2, num_decoder_layers=1, num_steps=8),
        ],
    )
    self.exp_config = cfg.ExperimentConfig(
        task=self.task_config,
        trainer=prog_trainer_lib.ProgressiveTrainerConfig())
  @combinations.generate(all_strategy_combinations())
  def test_num_stages(self, distribution):
    with distribution.scope():
      prog_translation = progressive_translation.ProgressiveTranslationTask(
          self.task_config)
      self.assertEqual(prog_translation.num_stages(), 2)
      self.assertEqual(prog_translation.num_steps(0), 4)
      self.assertEqual(prog_translation.num_steps(1), 8)
  @combinations.generate(all_strategy_combinations())
  def test_weight_copying(self, distribution):
    with distribution.scope():
      prog_translation = progressive_translation.ProgressiveTranslationTask(
          self.task_config)
      old_model = prog_translation.get_model(stage_id=0)
      for w in old_model.trainable_weights:
        w.assign(tf.zeros_like(w) + 0.12345)
      new_model = prog_translation.get_model(stage_id=1, old_model=old_model)
      for w in new_model.trainable_weights:
        self.assertAllClose(w, tf.zeros_like(w) + 0.12345)


if __name__ == "__main__":
  tf.test.main()